1 Star 0 Fork 2

13606799717/YoloV3物体检测

forked from cangye/YoloV3物体检测 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
fusemodel.mobilenet.py 3.70 KB
一键复制 编辑 原始数据 按行查看 历史
YUZIYE 提交于 2022-02-18 10:24 . add mobilenet
from torch.quantization import fuse_modules
from utils.model2 import YoloModel, ConvBNReLU, QInvertedResidual, YoloLayer1
import torch
import tqdm
class YoloModelFuse(YoloModel):
    """``YoloModel`` with module fusion and in-graph YOLO box decoding, for export.

    ``fuse_model`` fuses Conv/BN/ReLU groups (via ``torch.quantization``).
    ``forward`` returns the decoded predictions of all three detection heads
    concatenated along dim 1, with 85 values per prediction.
    """

    def fuse_model(self):
        """Fuse fusable submodules in place.

        Every module whose exact type is ``ConvBNReLU`` has its children
        ``'0'``, ``'1'``, ``'2'`` fused; every ``QInvertedResidual`` delegates
        to its own ``fuse_model``. The export script calls this after
        ``eval()``.
        """
        for m in self.modules():
            # Exact-type checks (not isinstance) so subclasses are not fused
            # with the ConvBNReLU child layout by accident.
            if type(m) is ConvBNReLU:
                fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) is QInvertedResidual:
                m.fuse_model()

    def _decode_head(self, y: torch.Tensor, anchors: torch.Tensor, stride: int,
                     batch: int, ny: int, nx: int) -> torch.Tensor:
        """Decode one head's raw output into a (batch, 3 * ny * nx, 85) tensor.

        Args:
            y: raw head output holding batch * 3 * 85 * ny * nx values.
            anchors: (1, 3, 1, 1, 2) float tensor of anchor (w, h) pairs.
            stride: input pixels per grid cell for this head.
            batch: batch size.
            ny, nx: grid height and width of this head.

        Returns:
            Per prediction: [0:2] xy = (sigmoid + cell index) * stride,
            [2:4] wh = exp * anchor, [4:] sigmoid of the remaining values.
        """
        # Split channels into (anchor, attribute), move attributes last:
        # (batch, 3, ny, nx, 85).
        y = y.reshape([batch, 3, 85, ny, nx]).permute(0, 1, 3, 4, 2)
        yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx))
        grid = torch.stack([xv, yv], 2).reshape((1, 1, ny, nx, 2)).float().to(y.device)
        y[..., 0:2] = (y[..., 0:2].sigmoid() + grid) * stride  # xy
        y[..., 2:4] = torch.exp(y[..., 2:4]) * anchors  # wh
        y[..., 4:] = y[..., 4:].sigmoid()
        return y.reshape(batch, -1, 85)

    def forward(self, x):
        """Run the backbone and the three YOLO heads, and decode their outputs.

        Args:
            x: input image batch. Assumed to be (B, 3, H, W) with H and W
               divisible by 32 (the export uses 1x3x416x416) — TODO confirm
               against YoloModel.

        Returns:
            Tensor of shape (B, N, 85) with
            N = 3 * ((H/8)*(W/8) + (H/16)*(W/16) + (H/32)*(W/32)),
            i.e. 10647 for a 416x416 input. Rows are ordered stride-8 head,
            then stride-16, then stride-32.
        """
        # Derived from the input instead of hard-coding 1 / 52 / 26 / 13, so
        # batches > 1 and other input sizes work. Identical for 1x3x416x416.
        batch = x.shape[0]
        height = x.shape[2]
        width = x.shape[3]

        h0 = self.base0(x)
        h1 = self.base1(h0)
        h2 = self.base2(h1)
        y2, cat1 = self.yolo2(h2)
        h1 = torch.cat([h1, cat1], dim=1)
        y1, cat0 = self.yolo1(h1)
        h0 = torch.cat([h0, cat0], dim=1)
        y0 = self.yolo0(h0)

        # Anchor (w, h) pairs per head, shaped to broadcast over (B, 3, ny, nx, 2).
        anch0 = torch.tensor([[10, 13], [16, 30], [33, 23]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)
        anch1 = torch.tensor([[30, 61], [62, 45], [59, 119]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)
        anch2 = torch.tensor([[116, 90], [156, 198], [373, 326]], dtype=torch.float32, device=x.device).view(1, -1, 1, 1, 2)

        y0 = self._decode_head(y0, anch0, 8, batch, height // 8, width // 8)
        y1 = self._decode_head(y1, anch1, 16, batch, height // 16, width // 16)
        y2 = self._decode_head(y2, anch2, 32, batch, height // 32, width // 32)
        return torch.cat([y0, y1, y2], dim=1)
device = torch.device("cpu")
model = YoloModelFuse()
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("ckpt/23.pt", map_location=device))
# Switch to inference mode before fusing and exporting.
model.eval()
model.fuse_model()
# TorchScript export of the fused model.
torch.jit.save(torch.jit.script(model), "ckpt/mobilenet.jit")

# ONNX export, traced with a dummy input of shape (1, 3, 416, 416).
input_names = ["image"]
output_names = ["output"]
dummy_input = torch.randn([1, 3, 416, 416])
# `use_external_data_format` is intentionally not passed: False was already
# its default, and the argument was deprecated and later removed from
# torch.onnx.export, so passing it raises TypeError on current PyTorch.
torch.onnx.export(
    model,
    dummy_input,
    "ckpt/mobilenet.yolo.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    opset_version=12,  # original author: opset has to be set to 12
    do_constant_folding=True,
)
#from onnxruntime.transformers.onnx_model import OnnxModel
#import onnx
#import onnxoptimizer
#
#def has_same_value(val_one,val_two):
# if val_one.raw_data == val_two.raw_data:
# return True
# else:
# return False
## 有些重复的算子需要去除
#path = f"ckpt/mobilenet.yolo.onnx"
#output_path = f"ckpt/mobilenet2.yolo.onnx"
#onnx_model = onnx.load(path)
#passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
#optimized_model = onnxoptimizer.optimize(onnx_model)
#
#onnx.save(optimized_model, output_path)
#
#
#model = YoloLayer1(256, 256, True)
#model.eval()
#torch.save(model.state_dict(), "ckpt/yolo2.pt")
#torch.jit.save(torch.jit.script(model), "ckpt/yolo.pt")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhaoshifu111/yolo-v3-object-detection.git
git@gitee.com:zhaoshifu111/yolo-v3-object-detection.git
zhaoshifu111
yolo-v3-object-detection
YoloV3物体检测
master

搜索帮助