1 Star 0 Fork 0

CaixyPromise/C++TensorRT人流检测与计数系统

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
export.patch 9.56 KB
一键复制 编辑 原始数据 按行查看 历史
CaixyPromise 提交于 2023-10-03 21:00 . 添加yolo补丁和分支
From 0ab0d40b5874791700720282a20259b2b404c984 Mon Sep 17 00:00:00 2001
From: Tyler Zhu <tylerz@nvidia.com>
Date: Thu, 2 Jun 2022 10:34:21 +0800
Subject: [PATCH] Enable onnx export with decode plugin
---
export.py | 170 ++++++++++++++++++++++++++++++++-----------------
models/yolo.py | 26 ++------
2 files changed, 119 insertions(+), 77 deletions(-)
diff --git a/export.py b/export.py
index 72e170a..f7a5572 100644
--- a/export.py
+++ b/export.py
@@ -111,62 +111,115 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
# YOLOv5 ONNX export
- try:
- check_requirements(('onnx',))
- import onnx
-
- LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
- f = file.with_suffix('.onnx')
-
- torch.onnx.export(
- model,
- im,
- f,
- verbose=False,
- opset_version=opset,
- training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
- do_constant_folding=not train,
- input_names=['images'],
- output_names=['output'],
- dynamic_axes={
- 'images': {
- 0: 'batch',
- 2: 'height',
- 3: 'width'}, # shape(1,3,640,640)
- 'output': {
- 0: 'batch',
- 1: 'anchors'} # shape(1,25200,85)
- } if dynamic else None)
-
- # Checks
- model_onnx = onnx.load(f) # load onnx model
- onnx.checker.check_model(model_onnx) # check onnx model
-
- # Metadata
- d = {'stride': int(max(model.stride)), 'names': model.names}
- for k, v in d.items():
- meta = model_onnx.metadata_props.add()
- meta.key, meta.value = k, str(v)
- onnx.save(model_onnx, f)
+ # try:
+ check_requirements(('onnx',))
+ import onnx
+
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
+ f = file.with_suffix('.onnx')
+ print(train)
+ torch.onnx.export(
+ model,
+ im,
+ f,
+ verbose=False,
+ opset_version=opset,
+ training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
+ do_constant_folding=not train,
+ input_names=['images'],
+ output_names=['p3', 'p4', 'p5'],
+ dynamic_axes={
+ 'images': {
+ 0: 'batch',
+ 2: 'height',
+ 3: 'width'}, # shape(1,3,640,640)
+ 'p3': {
+ 0: 'batch',
+ 2: 'height',
+ 3: 'width'}, # shape(1,25200,4)
+ 'p4': {
+ 0: 'batch',
+ 2: 'height',
+ 3: 'width'},
+ 'p5': {
+ 0: 'batch',
+ 2: 'height',
+ 3: 'width'}
+ } if dynamic else None)
- # Simplify
- if simplify:
- try:
- check_requirements(('onnx-simplifier',))
- import onnxsim
-
- LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
- model_onnx, check = onnxsim.simplify(model_onnx,
- dynamic_input_shape=dynamic,
- input_shapes={'images': list(im.shape)} if dynamic else None)
- assert check, 'assert check failed'
- onnx.save(model_onnx, f)
- except Exception as e:
- LOGGER.info(f'{prefix} simplifier failure: {e}')
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
- return f
- except Exception as e:
- LOGGER.info(f'{prefix} export failure: {e}')
+ # Checks
+ model_onnx = onnx.load(f) # load onnx model
+ onnx.checker.check_model(model_onnx) # check onnx model
+
+ # Simplify
+ if simplify:
+ # try:
+ check_requirements(('onnx-simplifier',))
+ import onnxsim
+
+ LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
+ model_onnx, check = onnxsim.simplify(model_onnx,
+ dynamic_input_shape=dynamic,
+ input_shapes={'images': list(im.shape)} if dynamic else None)
+ assert check, 'assert check failed'
+ onnx.save(model_onnx, f)
+ # except Exception as e:
+ # LOGGER.info(f'{prefix} simplifier failure: {e}')
+
+ # add yolov5_decoding:
+ import onnx_graphsurgeon as onnx_gs
+ import numpy as np
+ yolo_graph = onnx_gs.import_onnx(model_onnx)
+ p3 = yolo_graph.outputs[0]
+ p4 = yolo_graph.outputs[1]
+ p5 = yolo_graph.outputs[2]
+ decode_out_0 = onnx_gs.Variable(
+ "DecodeNumDetection",
+ dtype=np.int32
+ )
+ decode_out_1 = onnx_gs.Variable(
+ "DecodeDetectionBoxes",
+ dtype=np.float32
+ )
+ decode_out_2 = onnx_gs.Variable(
+ "DecodeDetectionScores",
+ dtype=np.float32
+ )
+ decode_out_3 = onnx_gs.Variable(
+ "DecodeDetectionClasses",
+ dtype=np.int32
+ )
+
+ decode_attrs = dict()
+
+ decode_attrs["max_stride"] = int(max(model.stride))
+ decode_attrs["num_classes"] = model.model[-1].nc
+ decode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]
+ decode_attrs["prenms_score_threshold"] = 0.25
+
+ decode_plugin = onnx_gs.Node(
+ op="YoloLayer_TRT",
+ name="YoloLayer",
+ inputs=[p3, p4, p5],
+ outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],
+ attrs=decode_attrs
+ )
+
+ yolo_graph.nodes.append(decode_plugin)
+ yolo_graph.outputs = decode_plugin.outputs
+ yolo_graph.cleanup().toposort()
+ model_onnx = onnx_gs.export_onnx(yolo_graph)
+
+ d = {'stride': int(max(model.stride)), 'names': model.names}
+ for k, v in d.items():
+ meta = model_onnx.metadata_props.add()
+ meta.key, meta.value = k, str(v)
+
+ onnx.save(model_onnx, f)
+ LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
+ return f
+ # except Exception as e:
+ # LOGGER.info(f'{prefix} export failure: {e}')
def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
@@ -488,7 +541,7 @@ def run(
assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
nc, names = model.nc, model.names # number of classes, class names
-
+
# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'
@@ -499,6 +552,7 @@ def run(
im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
# Update model
+ import torch.nn as nn
if half and not coreml and not xml:
im, model = im.half(), model.half() # to FP16
model.train() if train else model.eval() # training mode = no Detect() layer grid construction
@@ -507,7 +561,9 @@ def run(
m.inplace = inplace
m.onnx_dynamic = dynamic
m.export = True
-
+ elif isinstance(m, nn.Upsample):
+ print(m)
+
for _ in range(2):
y = model(im) # dry runs
shape = tuple(y[0].shape) # model output shape
diff --git a/models/yolo.py b/models/yolo.py
index 02660e6..c810745 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -55,29 +55,15 @@ class Detect(nn.Module):
z = [] # inference output
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
- bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
- x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
-
- if not self.training: # inference
- if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
- self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
-
- y = x[i].sigmoid()
- if self.inplace:
- y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy
- y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
- else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
- xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
- xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
- wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
- y = torch.cat((xy, wh, conf), 4)
- z.append(y.view(bs, -1, self.no))
-
- return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
+ y = x[i].sigmoid()
+ z.append(y)
+ return z
def _make_grid(self, nx=20, ny=20, i=0):
d = self.anchors[i].device
- t = self.anchors[i].dtype
+ # t = self.anchors[i].dtype
+ # TODO(tylerz) hard-code data type to int
+ t = torch.int32
shape = 1, self.na, ny, nx, 2 # grid shape
y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
--
2.36.0
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
C++
1
https://gitee.com/caixypromise/learning-tensor-rt_person.git
git@gitee.com:caixypromise/learning-tensor-rt_person.git
caixypromise
learning-tensor-rt_person
C++TensorRT人流检测与计数系统
main

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385