3 Star 0 Fork 1

mirrors_Tramac/Fast-SCNN-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
demo.py 2.03 KB
一键复制 编辑 原始数据 按行查看 历史
Boris Testov 提交于 2019-05-13 13:55 . Cpu key (--cpu) for demo.py added
import os
import argparse
import torch
from torchvision import transforms
from models.fast_scnn import get_fast_scnn
from PIL import Image
from utils.visualize import get_color_pallete
# Command-line options for single-image segmentation inference.
# Parsed at import time: demo() reads the module-level ``args``.
parser = argparse.ArgumentParser(
    description='Predict segmentation result from a given image')
parser.add_argument('--model', type=str, default='fast_scnn',
                    help='model name (default: fast_scnn)')
parser.add_argument('--dataset', type=str, default='citys',
                    help='dataset name (default: citys)')
# The demo loads pretrained weights from this folder; it does not save to it.
parser.add_argument('--weights-folder', type=str, default='./weights',
                    help='directory to load the pretrained checkpoint from '
                         '(default: ./weights)')
parser.add_argument('--input-pic', type=str,
                    default='./datasets/citys/leftImg8bit/test/berlin/berlin_000000_000019_leftImg8bit.png',
                    help='path to the input picture')
parser.add_argument('--outdir', default='./test_result', type=str,
                    help='directory to save the predicted mask '
                         '(default: ./test_result)')
# store_true already defaults to False, so no separate set_defaults() is needed.
parser.add_argument('--cpu', dest='cpu', action='store_true',
                    help='load the model weights with map_cpu=True '
                         '(for machines without CUDA)')
args = parser.parse_args()
def demo():
    """Segment one image with a pretrained Fast-SCNN and save the colour mask.

    All settings come from the module-level ``args``:

    * ``input_pic``: image file to segment.
    * ``dataset``: selects the pretrained model and the colour palette.
    * ``weights_folder``: passed as ``root`` to ``get_fast_scnn``.
    * ``outdir``: output directory, created if missing.
    * ``cpu``: force CPU execution.

    The mask is written to ``<outdir>/<input file stem>.png``.
    """
    # Honour --cpu. Previously the device depended only on CUDA availability,
    # so --cpu still sent the model and input to the GPU when one existed.
    use_cpu = args.cpu or not torch.cuda.is_available()
    device = torch.device('cpu' if use_cpu else 'cuda')

    # Output folder. exist_ok avoids the check-then-create race.
    os.makedirs(args.outdir, exist_ok=True)

    # Image transform: per-channel mean/std normalisation
    # (the standard ImageNet statistics).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Use a context manager so the file handle is closed.
    # convert() returns a separate, fully loaded image.
    with Image.open(args.input_pic) as img:
        image = img.convert('RGB')
    # unsqueeze(0) adds a leading batch dimension of size 1.
    image = transform(image).unsqueeze(0).to(device)

    # map_cpu is derived from the same decision as the device, so weights load
    # on CPU whenever we run on CPU. NOTE(review): presumably get_fast_scnn
    # maps the checkpoint to the CPU when map_cpu is True; confirm there.
    model = get_fast_scnn(args.dataset, pretrained=True,
                          root=args.weights_folder, map_cpu=use_cpu).to(device)
    print('Finished loading model!')

    model.eval()
    with torch.no_grad():
        outputs = model(image)
    # argmax over dim 1 of outputs[0] (presumably (N, C, H, W) class scores;
    # TODO confirm) yields a label map. squeeze(0) then drops the batch dimension.
    pred = torch.argmax(outputs[0], 1).squeeze(0).cpu().numpy()
    mask = get_color_pallete(pred, args.dataset)

    # Name the output after the input file's stem, always with a .png extension.
    outname = os.path.splitext(os.path.basename(args.input_pic))[0] + '.png'
    mask.save(os.path.join(args.outdir, outname))
if __name__ == "__main__":
    # Run single-image inference only when executed as a script, not on import.
    demo()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mirrors_Tramac/Fast-SCNN-pytorch.git
git@gitee.com:mirrors_Tramac/Fast-SCNN-pytorch.git
mirrors_Tramac
Fast-SCNN-pytorch
Fast-SCNN-pytorch
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385