1 Star 0 Fork 0

yyiOe/AEAD

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
detection.py 2.26 KB
一键复制 编辑 原始数据 按行查看 历史
yyiOe 提交于 2021-06-04 16:18 . mosaic augmentation
import torch
from criterions import SquareError
import numpy as np
import torch.nn.functional as F
from metrics import per_pixel_auc, aupro
from criterions import FeatureMapLoss
@torch.no_grad()
def get_score_map(model, input):
    """Score ``input`` against the model's output and return the map as NumPy.

    The model is run on ``input`` and the result is compared with the input
    using ``SquareError(is_loss=False)``. The resulting tensor is copied to
    the host and converted to a NumPy array. Autograd tracking is disabled.
    """
    error_fn = SquareError(is_loss=False)
    prediction = model(input)
    error_map = error_fn(input, prediction)
    host_map = error_map.cpu()
    return host_map.numpy()
def test_location(model, test_dataloader):
    """Run localization scoring over a dataloader with ``get_score_map``.

    Each batch must unpack into exactly five items. The first is used as the
    image batch and the third as the ground-truth mask; the others are
    ignored here (TODO confirm their meaning against the dataset). Images
    are moved to CUDA before scoring.

    Returns:
        ``(ppauc, n_s, a_s)``: the value of ``per_pixel_auc`` over all score
        maps and masks, followed by the normal/anomalous averages computed
        by ``count_sum``.
    """
    map_chunks = []
    mask_chunks = []
    for images, _, masks, _, _ in test_dataloader:
        map_chunks.append(get_score_map(model, images.cuda()))
        mask_chunks.append(masks)

    all_maps = np.concatenate(map_chunks)
    all_masks = torch.cat(mask_chunks)
    # Same call order as before: count_sum first, then the AUC.
    normal_avg, anomaly_avg = count_sum(all_maps, all_masks)
    pixel_auc = per_pixel_auc(all_maps, all_masks)
    return pixel_auc, normal_avg, anomaly_avg
def count_sum(scores, gts):
    """Average the scores over normal pixels and over anomalous pixels.

    Args:
        scores: per-pixel scores, as a NumPy array or torch tensor.
        gts: ground-truth mask with the same number of elements as
            ``scores``, as a NumPy array or torch tensor. Elements equal to
            1 count as anomalous and elements equal to 0 count as normal.
            Any other value (e.g. 255-valued or soft masks) is excluded from
            both averages instead of corrupting the denominators.

    Returns:
        ``(normality_score, anomaly_score)``: the mean of ``scores`` over the
        normal elements and over the anomalous elements. A score is NaN when
        its element set is empty (e.g. an evaluation set with no anomalies);
        no divide-by-zero warning is emitted.
    """
    if isinstance(scores, torch.Tensor):
        # detach()/cpu() so tensors that require grad or live on a GPU work.
        scores = scores.detach().cpu().numpy()
    if isinstance(gts, torch.Tensor):
        gts = gts.detach().cpu().numpy()
    scores = np.asarray(scores).reshape(-1)
    gts = np.asarray(gts).reshape(-1)

    # Use the same boolean masks to select elements for both the sum and the
    # count. Previously the denominators were gts.sum() / (1 - gts).sum(),
    # which only match the selected counts for strictly binary masks.
    anomalous = gts == 1
    normal = gts == 0
    anomaly_score = scores[anomalous].mean() if anomalous.any() else np.nan
    normality_score = scores[normal].mean() if normal.any() else np.nan
    return normality_score, anomaly_score
@torch.no_grad()
def feature_score_map(input, mask, model):
    """Build a score map by averaging upsampled feature-pair losses.

    ``model(input, mask)`` must return a 2-tuple; its first element is
    ignored. The second element is iterated as pairs. Each pair is compared
    with ``FeatureMapLoss(False)`` and given a channel axis at dim 1. It is
    then bilinearly resized to the spatial size of ``input`` (its last two
    dimensions). The resized maps are averaged, the channel axis is dropped,
    and the result is returned as a NumPy array on the host.

    NOTE(review): assumes the loss returns a (N, H', W') tensor so that the
    unsqueeze yields a 4-D input for bilinear interpolation -- confirm.
    """
    loss_fn = FeatureMapLoss(False)
    target_size = input.shape[-2:]
    _, feature_pairs = model(input, mask)

    upsampled = [
        F.interpolate(
            loss_fn(pair[0], pair[1]).unsqueeze(1),
            size=target_size,
            mode='bilinear',
        )
        for pair in feature_pairs
    ]
    # Built-in sum keeps the original left-to-right accumulation order.
    averaged = sum(upsampled) / len(upsampled)
    return averaged.squeeze(dim=1).detach().cpu().numpy()
def evaluate_localization(model, val_dataloader, need_pro=False):
    """Score a validation dataloader with ``feature_score_map`` and compute metrics.

    Each batch must unpack into exactly six items. The first is the image
    batch (moved to CUDA), the third is the ground-truth mask, and the sixth
    is forwarded unchanged as the ``mask`` argument of
    ``feature_score_map``. The remaining items are ignored here.

    Args:
        model: model passed through to ``feature_score_map``.
        val_dataloader: iterable of six-item batches as described above.
        need_pro: when truthy, also compute ``aupro``.

    Returns:
        ``(auc, n_s, a_s, pro)``:
            auc: the result of ``per_pixel_auc``.
            n_s, a_s: the two values returned by ``count_sum``.
            pro: the result of ``aupro``, or ``None`` when ``need_pro`` is
                falsy.
    """
    collected_maps = []
    collected_gts = []
    for images, _, gt_masks, _, _, model_masks in val_dataloader:
        collected_maps.append(feature_score_map(images.cuda(), model_masks, model))
        collected_gts.append(gt_masks)

    score_maps = np.concatenate(collected_maps)
    gts = torch.cat(collected_gts)
    auc = per_pixel_auc(score_maps, gts)
    n_s, a_s = count_sum(score_maps, gts)
    pro = aupro(score_maps, gts) if need_pro else None
    return auc, n_s, a_s, pro
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yyioe/aead.git
git@gitee.com:yyioe/aead.git
yyioe
aead
AEAD
master

搜索帮助