1 Star 0 Fork 29

玖沐星辰/2023_pytorch110_classification_42

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
predict.py 3.01 KB
一键复制 编辑 原始数据 按行查看 历史
肆十二 提交于 2022-12-06 18:00 . 6
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :cls_template
@File :predict.py
@Author :ChenmingSong
@Date :2022/1/5 16:23
@Description:用来推理数据集
'''
import torch
# from train_resnet import SelfNet
from train import SELFMODEL
import os
import os.path as osp
import shutil
import torch.nn as nn
from PIL import Image
from torchutils import get_torch_transforms
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
model_path = "../checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99501_weights.pth" # todo 模型路径
classes_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'] # todo 类名
img_size = 224 # todo 图片大小
model_name = "resnet50d" # todo 模型名称
num_classes = len(classes_names) # todo 类别数目
def predict_batch(model_path, target_dir, save_dir):
data_transforms = get_torch_transforms(img_size=img_size)
valid_transforms = data_transforms['val']
# 加载网络
model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)
# model = nn.DataParallel(model)
weights = torch.load(model_path)
model.load_state_dict(weights)
model.eval()
model.to(device)
# 读取图片
image_names = os.listdir(target_dir)
for i, image_name in enumerate(image_names):
image_path = osp.join(target_dir, image_name)
img = Image.open(image_path)
img = valid_transforms(img)
img = img.unsqueeze(0)
img = img.to(device)
output = model(img)
label_id = torch.argmax(output).item()
predict_name = classes_names[label_id]
save_path = osp.join(save_dir, predict_name)
if not osp.isdir(save_path):
os.makedirs(save_path)
shutil.copy(image_path, save_path)
print(f"{i + 1}: {image_name} result {predict_name}")
def predict_single(model_path, image_path):
data_transforms = get_torch_transforms(img_size=img_size)
# train_transforms = data_transforms['train']
valid_transforms = data_transforms['val']
# 加载网络
model = SELFMODEL(model_name=model_name, out_features=num_classes, pretrained=False)
# model = nn.DataParallel(model)
weights = torch.load(model_path)
model.load_state_dict(weights)
model.eval()
model.to(device)
# 读取图片
img = Image.open(image_path)
img = valid_transforms(img)
img = img.unsqueeze(0)
img = img.to(device)
output = model(img)
label_id = torch.argmax(output).item()
predict_name = classes_names[label_id]
print(f"{image_path}'s result is {predict_name}")
if __name__ == '__main__':
# 批量预测函数
predict_batch(model_path=model_path,
target_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini",
save_dir="D:/upppppppppp/cls/cls_torch_tem/images/test_imgs/mini_result")
# 单张图片预测函数
# predict_single(model_path=model_path, image_path="images/test_imgs/506659320_6fac46551e.jpg")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/jiumuxingchen/2023_pytorch110_classification_42.git
git@gitee.com:jiumuxingchen/2023_pytorch110_classification_42.git
jiumuxingchen
2023_pytorch110_classification_42
2023_pytorch110_classification_42
master

搜索帮助