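# NOTE: scrape residue from the hosting web page, not part of the program ("Code pull complete; the page will refresh automatically.").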
import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from models import STSimMS, FeatWeight
from config import cfg
from datasets import MvTec
from localization import get_ms_featmaps
def denorm(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
           inplace=True):
    """Undo per-channel normalisation, computing ``tensor * std + mean``.

    The defaults are the standard torchvision ImageNet channel statistics,
    so calling ``denorm(t)`` behaves exactly as before.

    Args:
        tensor: Normalised tensor with channels at dim -3, e.g. ``(C, H, W)``
            or ``(N, C, H, W)``. 1-D ``mean``/``std`` are reshaped to
            ``(C, 1, 1)`` so they broadcast over the trailing spatial dims.
        mean: Per-channel (sequence/1-D tensor) or scalar mean to add back.
        std: Per-channel (sequence/1-D tensor) or scalar std to multiply by.
        inplace: If True (default, original behaviour), ``tensor`` is
            modified and returned. If False, a new tensor is returned and
            ``tensor`` is left untouched.

    Returns:
        The de-normalised tensor (the same object as ``tensor`` when
        ``inplace`` is True).

    Raises:
        ValueError: If any std entry is zero after casting to
            ``tensor.dtype`` (e.g. the default stds truncate to 0 for
            integer dtypes).
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # A zero std means the normalisation being undone was not invertible;
    # the check (and message) mirror torchvision's normalize.
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
    # Per-channel vectors are aligned with dim -3; scalars broadcast as-is.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    if inplace:
        return tensor.mul_(std).add_(mean)
    return tensor.mul(std).add(mean)
if __name__ == '__main__':
    # Script entry point: load a trained STSimMS checkpoint for one MVTec class
    # and save, per test image, the image, its ground-truth mask and a
    # score-map overlay as JPEGs under vis/<class_name>/.
    class_name = 'bottle'
    ckpt_path = 'checkpoints/ms/resnet50_mssearch/mvtec'
    model_file = os.path.join(ckpt_path, class_name, 'epoch_best.pth')
    # NOTE(review): torch.load unpickles the file by default; only load trusted checkpoints.
    modeldata = torch.load(model_file)
    # The all-True (num_scales, target_layer_num) mask presumably enables every
    # scale/layer pair — TODO confirm against models.STSimMS.
    model = STSimMS('resnet50', len(cfg.img_size),
                    torch.ones((len(cfg.img_size), cfg.target_layer_num), dtype=torch.bool)).cuda()
    # Only the `simulator` submodule's weights are restored from the checkpoint.
    model.simulator.load_state_dict(modeldata['model_state'])
    # NOTE(review): featweight is freshly constructed and no checkpoint state is
    # loaded into it here; the meaning of the hard-coded [1, 1, 1, 1] argument
    # is defined in models.FeatWeight — TODO confirm it matches the feature count.
    featweight = FeatWeight(len(cfg.img_size) * cfg.target_layer_num, [1, 1, 1, 1]).cuda()
    dataset = MvTec('test', class_name, cfg.img_size, val_list=[])
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
    dst_path = 'vis/' + class_name
    # NOTE(review): os.mkdir creates only the last path component, so this
    # raises if the 'vis' directory does not already exist.
    if not os.path.exists(dst_path):
        os.mkdir(dst_path)
    # Output-file counter. It is not incremented anywhere in the visible code;
    # the increment presumably lived in content hidden below — TODO confirm.
    idx = 0
    with torch.no_grad():
        for bi, batch in enumerate(dataloader):
            # Only the images and ground-truth masks of each batch are used.
            img, _, gt, _ = batch
            feat_maps = get_ms_featmaps(img, model)
            feat_maps = featweight(feat_maps)
            # img is indexed as a sequence (presumably one tensor per scale);
            # only img[0] is visualised — TODO confirm.
            for i in range(img[0].shape[0]):
                # denorm mutates its argument; if img[0] is already on the CPU,
                # .cpu() returns the same tensor, so the batch itself is modified.
                image = img[0][i].cpu()
                image = denorm(image)
                # plt.subplots_adjust(left=0, right=0, top=0, bottom=0)
                plt.imshow(image.permute(1, 2, 0))
                plt.savefig(dst_path + '/{:0>3d}_image.jpg'.format(idx))
                # plt.subplots_adjust(left=0, right=0, top=0, bottom=0)
                grt = gt[i].cpu()
                # plt.subplots_adjust(left=0, right=0, top=0, bottom=0)
                # NOTE(review): nothing in the visible code clears the figure
                # between saves, so each imshow draws on the current axes.
                plt.imshow(grt[0])
                plt.savefig(dst_path + '/{:0>3d}_gt.jpg'.format(idx))
                score = feat_maps[i].cpu()
                # plt.subplots_adjust(left=0, right=0, top=0, bottom=0)
                # Overlay channel 0 of the score map (jet colormap, 40% opacity) on the image.
                plt.imshow(image.permute(1, 2, 0))
                plt.imshow(score[0], alpha=0.4, cmap=plt.get_cmap('jet'))
                plt.savefig(dst_path + '/{:0>3d}_score.jpg'.format(idx))
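# NOTE: the hosting site withheld the content at this point behind a moderation notice ("This content may be unsuitable for display and is not shown; you can review and edit it, or submit an appeal if you believe it is appropriate."). Any code that originally followed here is missing from this copy.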