1 Star 0 Fork 1

琦琦/智能垃圾桶

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
predict.py 1.69 KB
一键复制 编辑 原始数据 按行查看 历史
4B_pi 提交于 2021-03-16 21:10 . 增加了按钮控制
import torch
from model.model import resnet34
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
def predict_garbage(img):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
# 加载图片
plt.imshow(img)
# 对图片进行预处理
img = data_transform(img)
# 返回一个新的张量,对输入的既定位置插入维度 1
# [N, C, H, W]
img = torch.unsqueeze(img, dim=0)
try:
json_file = open('utils/garbage_sample_classify.json', 'r')
class_indict = json.load(json_file)
except Exception as ex:
print(ex)
exit(-1)
# 创建模型
model = resnet34(num_classes=4)
# 加载权重
model_weight_path = 'model/garbage_classify_model.pth'
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
output = torch.squeeze(model(img))
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].numpy())
plt.show()
plt.pause(2) # 间隔的秒数:6s
plt.close()
return int(predict_cla)
if __name__ == '__main__':
img = Image.open("test_image/img_15654.jpg")
garbage_class = predict_garbage(img)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/qiqidolikelhy/intelligent-trash-can.git
git@gitee.com:qiqidolikelhy/intelligent-trash-can.git
qiqidolikelhy
intelligent-trash-can
智能垃圾桶
master

搜索帮助