代码拉取完成,页面将自动刷新
'''
conv bn merge
'''
from modellist.VGG16 import VGGNet16
import torch
from collections import OrderedDict
from torchvision.models import vgg16_bn
from torch import nn
# Build the BN network whose weights will be folded: torchvision's VGG16-BN
# with the classifier head replaced by a 10-way output layer, all on CPU.
# No ImageNet download is needed (the old `pretrained=True` was wasted work):
# load_state_dict() below is strict, so every parameter and buffer of the
# model is overwritten by the checkpoint anyway.
model = vgg16_bn().to('cpu')
model.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 10),
).to('cpu')

# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only machine (without it torch.load tries to restore CUDA tensors).
checkpoint = torch.load("weights/best.pth", map_location="cpu")
# assumes the checkpoint is a dict with the weights under "model_state_dict"
# (the only key read here) — TODO confirm against the training script.
new_model_state_dict = {
    key: tensor.to("cpu") for key, tensor in checkpoint["model_state_dict"].items()
}
model.load_state_dict(new_model_state_dict)

# ----------------------------------------------------------------------------
# Functions
# ----------------------------------------------------------------------------
def merge(params, name, layer, eps=1e-5):
    """Accumulate conv/BN parameters and fold a BatchNorm into its conv.

    Meant to be called once per state-dict entry, in state-dict order: first
    the conv's ``weight`` (and ``bias``, if any) with ``layer='Convolution'``,
    then the following BN's ``weight``, ``bias``, ``running_mean`` and
    ``running_var`` with ``layer='BatchNorm'``. State is carried between calls
    in the module-level globals ``weights``, ``bias`` and ``bn_param``.

    Args:
        params: tensor for the current state-dict entry.
        name: its state-dict key, e.g. ``"features.0.weight"``. For
            ``'Convolution'`` only whether it contains "weight"/"bias" matters;
            for ``'BatchNorm'`` the last dotted component is used as the key.
        layer: ``'Convolution'`` or ``'BatchNorm'``; any other value is ignored.
        eps: value added to the running variance before the square root.
            Should match the BN layer's ``eps`` (1e-5 is PyTorch's default).

    Returns:
        ``(weights, bias)`` of the fused conv once the BN ``running_var``
        entry has been processed, otherwise ``(None, None)``.
    """
    global weights, bias
    global bn_param
    if layer == 'Convolution':
        if 'weight' in name:
            weights = params.data
            # Default bias for convs declared without one; replaced if the
            # conv's own bias entry follows. new_zeros keeps dtype/device in
            # line with the weights so the fold below does not mix devices.
            bias = weights.new_zeros(weights.size(0))
        elif 'bias' in name:
            bias = params.data
        # A new conv starts a new conv+BN pair.
        bn_param = {}
    elif layer == 'BatchNorm':
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last BN entry the fold needs (PyTorch may emit a
        # num_batches_tracked buffer after it; callers should skip that one).
        if 'running_var' in name:
            # y = gamma * (conv(x) + b - mean) / sqrt(var + eps) + beta
            #   = conv'(x) with W' = scale * W, b' = scale * (b - mean) + beta
            scale = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + eps)
            weights = scale.view(-1, 1, 1, 1) * weights
            bias = scale * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias
    return None, None
print("start merging conv and bn")
# Fused conv weights/biases plus untouched linear-layer entries, in state-dict order.
new_weights = OrderedDict()
# Set once the first non-conv/non-BN (linear) entry is seen; after that every
# entry is copied verbatim, because linear biases are 1-D like BN params.
inner_product_flag = False
# State-dict key of the most recent conv weight (e.g. "features.0.weight").
prev_layer = None
# NOTE(review): assumes every conv is immediately followed by a BatchNorm; a
# conv without one would never be written to new_weights.
for name, params in new_model_state_dict.items():
    if name.endswith('num_batches_tracked'):
        # BN step counter (0-dim buffer): not part of the fold. Letting it
        # through would hit the linear-layer branch and set
        # inner_product_flag, disabling merging for every later layer.
        continue
    layer_prefix = name.rsplit('.', 1)[0]
    if params.dim() == 4:
        merge(params, name, 'Convolution')
        prev_layer = name
    elif (not inner_product_flag and prev_layer is not None
          and layer_prefix == prev_layer.rsplit('.', 1)[0]):
        # The conv's own bias: 1-D like the BN params, but it must go to the
        # Convolution branch or it would be clobbered by the BN bias and
        # dropped from the fold.
        merge(params, name, 'Convolution')
    elif params.dim() == 1 and not inner_product_flag:
        w, b = merge(params, name, 'BatchNorm')
        if w is not None:
            new_weights[prev_layer] = w
            new_weights[prev_layer.rsplit('.', 1)[0] + '.bias'] = b
    else:
        # inner product (linear) layer: copied as-is; from here on 1-D
        # entries are linear biases, not BatchNorm parameters.
        new_weights[name] = params
        inner_product_flag = True
print('Aligning weight names...')
# Folding removes every BatchNorm entry, so the merged dict is necessarily
# shorter than model.state_dict() (the BN model) and cannot be aligned with
# it. Align positionally against the BN-free layout instead: `model` with its
# BatchNorm2d layers dropped from `features` (so conv indices are renumbered,
# presumably matching torchvision's plain vgg16() naming — verify), plus the
# unchanged classifier.
# NOTE(review): VGGNet16 is imported but unused; if the merged weights are
# meant for that model, build the key list from its state_dict() instead.
plain_features = nn.Sequential(
    *(module for module in model.features if not isinstance(module, nn.BatchNorm2d))
)
pytorch_net_key_list = (
    ['features.' + key for key in plain_features.state_dict()]
    + ['classifier.' + key for key in model.classifier.state_dict()]
)
new_weights_key_list = list(new_weights.keys())
if len(pytorch_net_key_list) != len(new_weights_key_list):
    raise ValueError(
        f"cannot align {len(new_weights_key_list)} merged tensors with "
        f"{len(pytorch_net_key_list)} target keys"
    )
# Rebuild rather than pop/insert in place, so a target key can never collide
# with a not-yet-renamed source key.
new_weights = OrderedDict(zip(pytorch_net_key_list, new_weights.values()))
# Write the merged, renamed weights out as a plain state dict.
MERGED_WEIGHTS_PATH = "./weights/merge_best.pth"
SAVE = True  # set to False to run the merge without writing a file
if SAVE:
    torch.save(new_weights, MERGED_WEIGHTS_PATH)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。