代码拉取完成,页面将自动刷新
import torch
from criterions import SquareError
import numpy as np
import torch.nn.functional as F
from metrics import per_pixel_auc, aupro
from criterions import FeatureMapLoss
@torch.no_grad()
def get_score_map(model, input):
    """Return the error map between ``input`` and ``model(input)`` as NumPy.

    The forward pass runs without gradient tracking. The model output is
    compared to the input with ``SquareError(is_loss=False)``; the result is
    copied to host memory and converted to a NumPy array.

    Args:
        model: callable applied to ``input``; its output must be comparable
            to ``input`` by ``SquareError``.
        input: batch tensor, already on the model's device (the caller is
            responsible for moving it).

    Returns:
        numpy.ndarray: whatever ``SquareError(is_loss=False)`` produces —
        presumably an unreduced per-element error; TODO confirm its shape.
    """
    square_error = SquareError(is_loss=False)
    output_batch = model(input)
    return square_error(input, output_batch).cpu().numpy()
def test_location(model, test_dataloader):
    """Evaluate pixel-level localization with reconstruction-error maps.

    Each batch is unpacked as a 5-tuple ``(img, _, gt, _, _)``. ``img`` is
    moved to CUDA and scored with ``get_score_map``; ``gt`` is collected
    unchanged. After the loop, the maps and masks are concatenated along
    the first axis.

    Args:
        model: model accepted by ``get_score_map``.
        test_dataloader: iterable yielding 5-element batches.

    Returns:
        tuple: ``(ppauc, n_s, a_s)``. ``ppauc`` comes from ``per_pixel_auc``;
        ``n_s`` and ``a_s`` are the mean normal-pixel and mean
        anomalous-pixel scores returned by ``count_sum``.
    """
    map_chunks, mask_chunks = [], []
    for img, _, gt, _, _ in test_dataloader:
        map_chunks.append(get_score_map(model, img.cuda()))
        mask_chunks.append(gt)

    all_maps = np.concatenate(map_chunks)
    all_masks = torch.cat(mask_chunks)

    normal_mean, anomaly_mean = count_sum(all_maps, all_masks)
    ppauc = per_pixel_auc(all_maps, all_masks)
    return ppauc, normal_mean, anomaly_mean
def count_sum(scores, gts):
    """Return the mean score over normal pixels and over anomalous pixels.

    Both inputs are flattened, so any shapes with the same number of
    elements are accepted (for example score maps of one shape against masks
    that carry an extra singleton axis). A pixel counts as *normal* where
    ``gts == 0`` and as *anomalous* where ``gts == 1``. Any other label value
    is excluded from BOTH averages, so masks stored as 0/255 must be
    binarized before calling.

    Args:
        scores: NumPy array, torch tensor (any device, may require grad) or
            array-like of per-pixel scores.
        gts: NumPy array, torch tensor or array-like of ground-truth labels
            with the same number of elements as ``scores``.

    Returns:
        tuple: ``(normality_score, anomaly_score)`` as ``numpy.float64``.
        A category with no pixels yields ``nan``.

    Raises:
        ValueError: if ``scores`` and ``gts`` differ in element count.
    """

    def _flatten(values):
        # Tensors may live on the GPU or be part of an autograd graph; move
        # them to host memory first so ``.numpy()`` cannot fail.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    def _masked_mean(mask):
        # Numerator and denominator use the same boolean mask, so
        # non-binary ground-truth values cannot skew the average. Indexing
        # (rather than multiplying by the mask) also keeps nan/inf scores of
        # the other class out of this mean.
        count = np.count_nonzero(mask)
        if count == 0:
            # Same nan result as 0/0, but without the RuntimeWarning.
            return np.float64(np.nan)
        # Accumulate in float64 to limit round-off over many pixels.
        return flat_scores[mask].sum(dtype=np.float64) / count

    flat_scores = _flatten(scores)
    flat_gts = _flatten(gts)
    if flat_scores.size != flat_gts.size:
        raise ValueError(
            "scores and gts must have the same number of elements, "
            f"got {flat_scores.size} and {flat_gts.size}"
        )

    normality_score = _masked_mean(flat_gts == 0)
    anomaly_score = _masked_mean(flat_gts == 1)
    return normality_score, anomaly_score
@torch.no_grad()
def feature_score_map(input, mask, model):
    """Average feature-level discrepancy maps into one full-resolution map.

    ``model(input, mask)`` is expected to return a pair whose second element
    is an iterable of feature pairs. For each pair, the first two entries
    are compared with ``FeatureMapLoss(False)``. A channel axis is added and
    the map is bilinearly resized to the last two dimensions of ``input``.
    The resized maps are averaged with equal weight.

    Args:
        input: image batch; its last two dimensions set the output size.
        mask: forwarded unchanged as the model's second argument.
        model: callable ``model(input, mask) -> (_, feature_outputs)``.

    Returns:
        numpy.ndarray: the averaged map with axis 1 squeezed out, copied to
        host memory. Assumes the criterion returns one ``(B, h, w)`` map per
        pair, giving a ``(B, H, W)`` result — TODO confirm.
    """
    criterion = FeatureMapLoss(False)
    input_shape = input.shape[-2:]
    # The first model output is not used for scoring here.
    _, feature_outputs = model(input, mask)

    upsampled_maps = [
        F.interpolate(
            criterion(pair[0], pair[1]).unsqueeze(1),
            size=input_shape,
            mode='bilinear',
            # Explicitly the default: pins the resampling semantics and
            # silences the align_corners UserWarning older PyTorch emits.
            align_corners=False,
        )
        for pair in feature_outputs
    ]
    mean_map = (sum(upsampled_maps) / len(upsampled_maps)).squeeze(dim=1)
    return mean_map.detach().cpu().numpy()
def evaluate_localization(model, val_dataloader, need_pro=False):
    """Evaluate pixel-level localization using feature-discrepancy maps.

    Each batch is unpacked as a 6-tuple ``(img, _, gt, _, _, ms)``. ``img``
    is moved to CUDA and scored with ``feature_score_map`` together with
    ``ms``. ``ms`` is passed as-is and is not moved to the GPU here;
    NOTE(review): confirm the model handles that. ``gt`` is collected
    unchanged. After the loop, maps and masks are concatenated along the
    first axis.

    Args:
        model: model accepted by ``feature_score_map``.
        val_dataloader: iterable yielding 6-element batches.
        need_pro: when truthy, also compute ``aupro``.

    Returns:
        tuple: ``(auc, n_s, a_s, pro)``. ``auc`` comes from
        ``per_pixel_auc``; ``n_s``/``a_s`` are the mean normal/anomalous
        pixel scores from ``count_sum``; ``pro`` is the ``aupro`` result, or
        ``None`` when ``need_pro`` is falsy.
    """
    collected_maps = []
    collected_gts = []
    for img, _, gt, _, _, ms in val_dataloader:
        collected_maps.append(feature_score_map(img.cuda(), ms, model))
        collected_gts.append(gt)

    stacked_maps = np.concatenate(collected_maps)
    stacked_gts = torch.cat(collected_gts)

    auc = per_pixel_auc(stacked_maps, stacked_gts)
    n_s, a_s = count_sum(stacked_maps, stacked_gts)
    pro = aupro(stacked_maps, stacked_gts) if need_pro else None
    return auc, n_s, a_s, pro
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。