1 Star 0 Fork 0

LZY/Intrinsic-Image-Popularity

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 1.21 KB
一键复制 编辑 原始数据 按行查看 历史
dingkeyan 提交于 2020-01-09 20:35 . Update test.py
# -*- coding: utf-8 -*-
import argparse
import torch
import torchvision.models
import torchvision.transforms as transforms
from PIL import Image
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def prepare_image(image):
if image.mode != 'RGB':
image = image.convert("RGB")
Transform = transforms.Compose([
transforms.Resize([224,224]),
transforms.ToTensor(),
])
image = Transform(image)
image = image.unsqueeze(0)
return image.to(device)
def predict(image, model):
image = prepare_image(image)
with torch.no_grad():
preds = model(image)
print(r'Popularity score: %.2f' % preds.item())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--image_path', type=str, default='images/0.jpg')
config = parser.parse_args()
image = Image.open(config.image_path)
model = torchvision.models.resnet50()
# model.avgpool = nn.AdaptiveAvgPool2d(1) # for any size of the input
model.fc = torch.nn.Linear(in_features=2048, out_features=1)
model.load_state_dict(torch.load('model/model-resnet50.pth', map_location=device))
model.eval().to(device)
predict(image, model)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/LZY2006/Intrinsic-Image-Popularity.git
git@gitee.com:LZY2006/Intrinsic-Image-Popularity.git
LZY2006
Intrinsic-Image-Popularity
Intrinsic-Image-Popularity
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385