1 Star 0 Fork 0

Kxvz/RobustVideoMatting

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
inference_speed_test.py 1.70 KB
一键复制 编辑 原始数据 按行查看 历史
Peter Lin 提交于 2021-08-30 14:01 . Code release
"""
python inference_speed_test.py \
--model-variant mobilenetv3 \
--resolution 1920 1080 \
--downsample-ratio 0.25 \
--precision float32
"""
import argparse
import torch
from tqdm import tqdm
from model.model import MattingNetwork
# Enable cuDNN benchmark mode (PyTorch `torch.backends.cudnn.benchmark`): when
# True, cuDNN benchmarks several convolution algorithms and picks the fastest.
# The speed test below feeds a tensor of one fixed shape on every iteration.
torch.backends.cudnn.benchmark = True
class InferenceSpeedTest:
    """Benchmark forward-pass throughput of ``MattingNetwork`` on random input.

    All work happens in ``__init__``:
    1. parse the command-line options;
    2. build the model, move it to CUDA in the requested dtype, then script
       and freeze it;
    3. run a fixed number of forward passes under a ``tqdm`` progress bar,
       whose iterations/second readout is the measurement.
    """

    # Single source of truth for the accepted ``--precision`` values and the
    # torch dtype each one maps to (used by both parse_args and init_model).
    _DTYPES = {'float32': torch.float32, 'float16': torch.float16}

    def __init__(self, argv=None):
        """Parse options, build the model and run the timing loop.

        Args:
            argv: Optional list of command-line arguments, excluding the
                program name. ``None`` (the default) makes argparse read
                ``sys.argv``, which is what running the script does.
        """
        self.parse_args(argv)
        self.init_model()
        self.loop()

    def parse_args(self, argv=None):
        """Parse benchmark options into ``self.args``.

        Args:
            argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

        Raises:
            SystemExit: From argparse on missing or invalid options. This
                includes an unknown ``--precision``, which used to surface
                later as a bare ``KeyError`` in ``init_model``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('--model-variant', type=str, required=True)
        # Two ints, width then height: ``loop`` unpacks them as ``w, h``.
        parser.add_argument('--resolution', type=int, required=True, nargs=2)
        parser.add_argument('--downsample-ratio', type=float, required=True)
        parser.add_argument('--precision', type=str, default='float32',
                            choices=list(self._DTYPES))
        # NOTE(review): this flag is parsed but never read anywhere in this
        # file; presumably it was meant to change how the model is built --
        # confirm and wire it up, or drop it.
        parser.add_argument('--disable-refiner', action='store_true')
        self.args = parser.parse_args(argv)

    def init_model(self):
        """Build the network on CUDA in the requested dtype, then script and freeze it."""
        self.device = 'cuda'
        self.precision = self._DTYPES[self.args.precision]
        self.model = MattingNetwork(self.args.model_variant)
        # ``torch.jit.freeze`` requires a module in eval mode, hence ``.eval()``.
        self.model = self.model.to(device=self.device, dtype=self.precision).eval()
        self.model = torch.jit.script(self.model)
        self.model = torch.jit.freeze(self.model)

    def loop(self):
        """Run 1000 forward passes on one random frame and report the rate via tqdm.

        The first two outputs of each call (``fgr``, ``pha``) are discarded;
        the remaining outputs are fed back as extra arguments to the next
        call, starting from four ``None`` values.
        """
        w, h = self.args.resolution
        src = torch.randn((1, 3, h, w), device=self.device, dtype=self.precision)
        with torch.no_grad():
            rec = None, None, None, None
            for _ in tqdm(range(1000)):
                fgr, pha, *rec = self.model(src, *rec, self.args.downsample_ratio)
                # CUDA work is launched asynchronously; wait for it so the
                # progress bar times completed passes, not just launches.
                torch.cuda.synchronize()
def _main():
    """Script entry point: constructing InferenceSpeedTest runs the whole benchmark."""
    InferenceSpeedTest()


if __name__ == '__main__':
    _main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/Kxvz/RobustVideoMatting.git
git@gitee.com:Kxvz/RobustVideoMatting.git
Kxvz
RobustVideoMatting
RobustVideoMatting
master

搜索帮助