1 Star 0 Fork 0

Xhao/Pytorch-Unet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
predict.py 744 Bytes
一键复制 编辑 原始数据 按行查看 历史
Xhao 提交于 2023-01-15 19:24 . init
from PIL import Image
from utils import detect_image
import torch
from model import Unet_vgg
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
name_classes = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair",
"cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
model = Unet_vgg(num_classes=21, pretrained=True).to(device).eval()
checkpoint = torch.load('vgg_pretrain.pth')
model.load_state_dict(checkpoint)
img_path = './VOCdevkit/VOC2007/JPEGImages/2007_009052.jpg'
image = Image.open(img_path)
r_image = detect_image(model, image, device=device)
# r_image.show()
r_image.save('./predict.png')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/firslov/unet.git
git@gitee.com:firslov/unet.git
firslov
unet
Pytorch-Unet
master

搜索帮助