1 Star 0 Fork 0

yyiOe/PyramidST

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
vis.py 2.61 KB
一键复制 编辑 原始数据 按行查看 历史
yyiOe 提交于 2021-12-24 17:40 . cutpaste
import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from models import STSimMS, FeatWeight
from config import cfg
from datasets import MvTec
from localization import get_ms_featmaps
def denorm(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), inplace=True):
    """Undo per-channel mean/std normalization, i.e. compute ``tensor * std + mean``.

    The default ``mean``/``std`` are the widely used ImageNet statistics; pass
    other values to invert a different normalization.

    Args:
        tensor: normalized tensor. When ``mean``/``std`` are 1-D they are
            reshaped to (C, 1, 1), so the channel axis of ``tensor`` must be
            the third-from-last dimension (e.g. (C, H, W) or (N, C, H, W)).
        mean: per-channel means (sequence or tensor).
        std: per-channel standard deviations (sequence or tensor).
        inplace: if True (default, the original behaviour) ``tensor`` itself is
            modified and returned; if False a copy is denormalized and
            ``tensor`` is left untouched.

    Returns:
        The denormalized tensor.

    Raises:
        ValueError: if any ``std`` entry is zero after conversion to
            ``tensor``'s dtype (such a std could not have been used for the
            forward normalization).
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
    # Broadcast 1-D per-channel stats over the trailing (H, W) dimensions.
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    if not inplace:
        tensor = tensor.clone()
    tensor.mul_(std).add_(mean)
    return tensor
def _save_sample(image, gt_mask, score, dst_path, idx):
    """Write the image, its ground-truth mask and a score overlay for one sample.

    Args:
        image: denormalized image tensor laid out as (C, H, W); it is permuted
            to (H, W, C) for display.
        gt_mask: ground-truth tensor; only ``gt_mask[0]`` is plotted.
        score: anomaly-score tensor; only ``score[0]`` is plotted, as a
            semi-transparent 'jet' overlay on top of ``image``.
        dst_path: output directory (must already exist).
        idx: running sample index used for the zero-padded file names.
    """
    # imshow clips float RGB data to [0, 1] anyway; clamping avoids the warning.
    rgb = image.permute(1, 2, 0).clamp(0, 1)
    prefix = os.path.join(dst_path, '{:0>3d}'.format(idx))

    # Clear the current figure before each plot so artists from previous
    # plots/samples do not accumulate (memory growth) or stack visually.
    plt.clf()
    plt.imshow(rgb)
    plt.savefig(prefix + '_image.jpg')

    plt.clf()
    plt.imshow(gt_mask[0])
    plt.savefig(prefix + '_gt.jpg')

    plt.clf()
    plt.imshow(rgb)
    plt.imshow(score[0], alpha=0.4, cmap=plt.get_cmap('jet'))
    plt.savefig(prefix + '_score.jpg')


def main(class_name='bottle', ckpt_path='checkpoints/ms/resnet50_mssearch/mvtec', out_root='vis'):
    """Render anomaly-score visualizations for the MVTec test split of one class.

    Loads ``<ckpt_path>/<class_name>/epoch_best.pth`` into an ``STSimMS``
    model, runs every test batch through ``get_ms_featmaps`` + ``FeatWeight``,
    and saves ``NNN_image.jpg``, ``NNN_gt.jpg`` and ``NNN_score.jpg`` per
    sample under ``<out_root>/<class_name>``.
    """
    model_file = os.path.join(ckpt_path, class_name, 'epoch_best.pth')
    modeldata = torch.load(model_file)
    num_scales = len(cfg.img_size)
    model = STSimMS('resnet50', num_scales,
                    torch.ones((num_scales, cfg.target_layer_num), dtype=torch.bool)).cuda()
    model.simulator.load_state_dict(modeldata['model_state'])
    featweight = FeatWeight(num_scales * cfg.target_layer_num, [1, 1, 1, 1]).cuda()
    # Inference only: switch to eval mode so any BatchNorm/Dropout layers do not
    # use batch statistics or drop activations (both objects are nn.Modules,
    # as their .cuda() calls imply).
    model.eval()
    featweight.eval()

    dataset = MvTec('test', class_name, cfg.img_size, val_list=[])
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

    dst_path = os.path.join(out_root, class_name)
    # makedirs also creates out_root if missing (os.mkdir would raise).
    os.makedirs(dst_path, exist_ok=True)

    idx = 0
    with torch.no_grad():
        for batch in dataloader:
            img, _, gt, _ = batch
            feat_maps = featweight(get_ms_featmaps(img, model))
            # img[0] is the first entry of the batched image input; its
            # leading dimension is the batch size.
            for i in range(img[0].shape[0]):
                # clone() so denorm's in-place update does not alter the batch
                # tensor (.cpu() on a CPU tensor returns the same storage).
                image = denorm(img[0][i].detach().cpu().clone())
                _save_sample(image, gt[i].cpu(), feat_maps[i].cpu(), dst_path, idx)
                # Advance the counter so every sample gets unique file names.
                idx += 1
    plt.close('all')


if __name__ == '__main__':
    main()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/yyioe/pyramid-st.git
git@gitee.com:yyioe/pyramid-st.git
yyioe
pyramid-st
PyramidST
master

搜索帮助