代码拉取完成,页面将自动刷新
#coding=utf-8
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
from data_loader import KFDataset
from models import KFSGNet
from train import config,get_peak_points,get_mse
def demo(img,heatmaps):
"""
:param img: (96,96)
:param heatmaps: ()
:return:
"""
# img = img.reshape(96, 96)
# axis.imshow(img, cmap='gray')
# axis.scatter(y[:, 0], y[:, 1], marker='x', s=10)
pass
def evaluate():
# 加载模型
net = KFSGNet()
net.float().cuda()
net.eval()
if (config['checkout'] != ''):
net.load_state_dict(torch.load(config['checkout']))
dataset = KFDataset(config)
dataset.load()
dataLoader = DataLoader(dataset,1)
for i,(images,_,gts) in enumerate(dataLoader):
images = Variable(images).float().cuda()
pred_heatmaps = net.forward(images)
demo_img = images[0].cpu().data.numpy()[0]
demo_img = (demo_img * 255.).astype(np.uint8)
demo_heatmaps = pred_heatmaps[0].cpu().data.numpy()[np.newaxis,...]
demo_pred_poins = get_peak_points(demo_heatmaps)[0] # (15,2)
plt.imshow(demo_img,cmap='gray')
plt.scatter(demo_pred_poins[:,0],demo_pred_poins[:,1])
plt.show()
# loss = get_mse(demo_pred_poins[np.newaxis,...],gts)
if __name__ == '__main__':
evaluate()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。