1 Star 0 Fork 0

ylyhappy/dl

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
predict.py 1.09 KB
一键复制 编辑 原始数据 按行查看 历史
ylyhappy 提交于 2022-01-01 04:32 . lenet
# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
# 数据预处理
transform = transforms.Compose(
[transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('test.jpg')
im = transform(im) # [C, H, W]
im = torch.unsqueeze(im, dim=0) # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width]
# 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))
# 预测
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# with torch.no_grad():
# outputs = net(im)
# predict = torch.max(outputs, dim=1)[1].data.numpy()
with torch.no_grad():
outputs = net(im)
predict = torch.softmax(outputs, dim=1)
print(predict)
# print(classes[int(predict)])
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ylyhappy/dl.git
git@gitee.com:ylyhappy/dl.git
ylyhappy
dl
dl
master

搜索帮助