
luotianhang/dog10分类-pytorch

model_reborn.py 2.65 KB
罗天杭 committed on 2021-08-02 01:50 · 20210802
'''
conv bn merge: fold the BatchNorm layers of a trained VGG16-BN classifier
into the preceding convolutions and save the merged weights.
'''
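# In eval mode, BatchNorm computes y = gamma * (x - mu) / sqrt(var + eps) + beta,
# so a conv followed by BN folds into a single conv with
#     W' = gamma / sqrt(var + eps) * W
#     b' = gamma / sqrt(var + eps) * (b - mu) + beta
# This is exactly what merge() below computes, with eps = 1e-5.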
from modellist.VGG16 import VGGNet16
import torch
from collections import OrderedDict
from torchvision.models import vgg16_bn
from torch import nn
# rebuild the trained model: torchvision's vgg16_bn with a 10-way head
# (the 10 dog classes of this repo); 512 * 7 * 7 is the flattened feature
# map size for 224x224 inputs
model = vgg16_bn(pretrained=True).to('cpu')
model.classifier = nn.Sequential(
    nn.Linear(512 * 7 * 7, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(4096, 10),
).to('cpu')
# the training checkpoint is assumed to keep the weights under the
# "model_state_dict" key; map_location moves every tensor to the CPU
checkpoint = torch.load("weights/best.pth", map_location="cpu")
new_model_state_dict = checkpoint["model_state_dict"]
model.load_state_dict(new_model_state_dict)
""" Functions """
def merge(params, name, layer):
    # global variables shared across successive calls
    global weights, bias
    global bn_param
    if layer == 'Convolution':
        # stash weights and bias when we meet a conv layer
        if 'weight' in name:
            weights = params.data
            bias = torch.zeros(weights.size()[0])
        elif 'bias' in name:
            bias = params.data
            bn_param = {}
    elif layer == 'BatchNorm':
        # collect bn params
        bn_param[name.split('.')[-1]] = params.data
        # running_var is the last bn tensor we process for each layer
        if 'running_var' in name:
            # let us merge bn ~
            tmp = bn_param['weight'] / torch.sqrt(bn_param['running_var'] + 1e-5)
            weights = tmp.view(tmp.size()[0], 1, 1, 1) * weights
            bias = tmp * (bias - bn_param['running_mean']) + bn_param['bias']
            return weights, bias
    return None, None
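
# --- Self-check (a minimal sketch added here, not part of the original script):
# fold one Conv2d + BatchNorm2d pair with the same formula and confirm the
# fused convolution reproduces conv -> bn in eval mode. ---
def _check_fold():
    conv = nn.Conv2d(3, 8, 3, padding=1)
    bn = nn.BatchNorm2d(8).eval()
    with torch.no_grad():
        # randomize the running statistics so the check is non-trivial
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2.0)
        tmp = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused = nn.Conv2d(3, 8, 3, padding=1)
        fused.weight.copy_(tmp.view(-1, 1, 1, 1) * conv.weight)
        fused.bias.copy_(tmp * (conv.bias - bn.running_mean) + bn.bias)
        x = torch.randn(1, 3, 32, 32)
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)

_check_fold()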
print("start merging conv and bn")
new_weights=OrderedDict()
inner_product_flag=False
for name,params in new_model_state_dict.items():
if len(params.size()) == 4:
_, _ = merge(params, name, 'Convolution')
prev_layer = name
elif len(params.size()) == 1 and not inner_product_flag:
w, b = merge(params, name, 'BatchNorm')
if w is not None:
new_weights[prev_layer] = w
new_weights[prev_layer.replace('weight', 'bias')] = b
else:
# inner product layer
# if meet inner product layer,
# the next bias weight can be misclassified as 'BatchNorm' layer as len(params.size()) == 1
new_weights[name] = params
inner_product_flag = True
print('Aligning weight names...')
# the merged dict has no BatchNorm keys, so the reference model here must be
# the BN-free deploy network (VGGNet16 from modellist), not the vgg16_bn
# instance above; otherwise the length assert fails
deploy_model = VGGNet16()  # assumed no-arg constructor; adjust to the project's signature
pytorch_net_key_list = list(deploy_model.state_dict().keys())
new_weights_key_list = list(new_weights.keys())
assert len(pytorch_net_key_list) == len(new_weights_key_list)
# rename in insertion order: the i-th merged tensor takes the i-th key of
# the deploy model's state_dict
for index in range(len(pytorch_net_key_list)):
    new_weights[pytorch_net_key_list[index]] = new_weights.pop(new_weights_key_list[index])
SAVE = True
# save new weights
if SAVE:
    torch.save(new_weights, "./weights/merge_best.pth")
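
# Usage sketch (hypothetical; the constructor signature of the project's
# BN-free VGGNet16 is assumed, adjust it to the real definition):
#   deploy = VGGNet16()
#   deploy.load_state_dict(torch.load("./weights/merge_best.pth", map_location="cpu"))
#   deploy.eval()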