代码拉取完成,页面将自动刷新
同步操作将从 策略定制/BackgroundMattingV2 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
"""
Export a MattingBase or MattingRefine model to ONNX format.
Example:
python export_onnx.py \
--model-type mattingrefine \
--model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--onnx-opset-version 11 \
--onnx-constant-folding \
--precision float32 \
--output "model.onnx" \
--validate
"""
import argparse
import torch
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    argparse's ``type=bool`` must not be used for this: it calls
    ``bool('False')``, which is ``True`` because any non-empty string is
    truthy, so the option could never be switched off.

    Args:
        value: The raw command-line string (a ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Export ONNX')

# Model selection and checkpoint.
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)

# Refinement options; only forwarded to the model when --model-type is mattingrefine.
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

# ONNX exporter options.
# --onnx-verbose still takes an explicit value (e.g. "--onnx-verbose false"),
# but the value is now actually parsed instead of being coerced with bool().
parser.add_argument('--onnx-verbose', type=_str2bool, default=True)
parser.add_argument('--onnx-opset-version', type=int, default=12)
# Constant folding stays enabled by default and the original flag keeps
# working; the paired --no-... flag is the only way to turn it off.
parser.add_argument('--onnx-constant-folding', dest='onnx_constant_folding', action='store_true', default=True)
parser.add_argument('--no-onnx-constant-folding', dest='onnx_constant_folding', action='store_false')

# Runtime options.
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()
# --------------- Main ---------------

# Build the network selected by --model-type.
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_patch_crop_method='roi_align',
        refine_patch_replace_method='scatter_element',
    )
elif args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)

# Restore the checkpoint weights onto the target device. strict=False means
# keys that do not match between checkpoint and model are not treated as errors.
checkpoint_state = torch.load(args.model_checkpoint, map_location=args.device)
model.load_state_dict(checkpoint_state, strict=False)

# Translate the --precision string into the matching torch dtype.
dtype_by_name = {
    'float32': torch.float32,
    'float16': torch.float16,
}
precision = dtype_by_name[args.precision]

# Switch to inference mode, then cast and move the module (both in place).
model.eval()
model.to(precision)
model.to(args.device)
# Dummy Inputs
# Two random tensors of the same shape are traced through the model; the
# batch, height and width axes are declared dynamic below.
dummy_shape = (2, 3, 1080, 1920)
src = torch.randn(*dummy_shape).to(precision).to(args.device)
bgr = torch.randn(*dummy_shape).to(precision).to(args.device)

# Export ONNX
# Both model types take the same inputs; only the output names differ.
input_names = ['src', 'bgr']
output_names = {
    'mattingbase': ['pha', 'fgr', 'err', 'hid'],
    'mattingrefine': ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm'],
}[args.model_type]

# Mark batch (0), height (2) and width (3) as dynamic on every input and output.
dynamic_axes = {}
for tensor_name in input_names + output_names:
    dynamic_axes[tensor_name] = {0: 'batch', 2: 'height', 3: 'width'}

torch.onnx.export(
    model=model,
    args=(src, bgr),
    f=args.output,
    verbose=args.onnx_verbose,
    opset_version=args.onnx_opset_version,
    do_constant_folding=args.onnx_constant_folding,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)

print(f'ONNX model saved at: {args.output}')
# Validation
# Runs the exported file through ONNX Runtime and compares every output with
# the PyTorch model on the same random inputs. Raises RuntimeError when the
# largest absolute difference is not below 0.001.
if args.validate:
    import math

    import onnxruntime
    import numpy as np

    print('Validating ONNX model.')

    # Test with different inputs: a different batch size and resolution than
    # the dummy tensors used for the export.
    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)

    with torch.no_grad():
        out_torch = model(src, bgr)

    sess = onnxruntime.InferenceSession(args.output)
    out_onnx = sess.run(None, {
        'src': src.cpu().numpy(),
        'bgr': bgr.cpu().numpy(),
    })

    e_max = 0
    for a, b, name in zip(out_torch, out_onnx, output_names):
        b = torch.as_tensor(b)
        e = torch.abs(a.cpu() - b).max()
        e_value = e.item()
        # max(e_max, nan) would keep the old e_max, so a NaN output would
        # silently pass. Let NaN propagate: once e_max is NaN the threshold
        # comparison below is False and validation fails.
        if math.isnan(e_value) or e_value > e_max:
            e_max = e_value
        print(f'"{name}" output differs by maximum of {e}')

    if e_max < 0.001:
        print('Validation passed.')
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception so the failure is reported as intended.
        raise RuntimeError('Validation failed.')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。