代码拉取完成,页面将自动刷新
import torch
from sklearn.metrics import auc, roc_auc_score, roc_curve
import numpy as np
import cv2
def per_pixel_auc(input, label, max_fpr=None):
    """Compute the pixel-level ROC AUC between scores and ground-truth labels.

    Both arrays are flattened, so every pixel is treated as one independent
    sample; their total element counts must therefore match.

    Args:
        input: Per-pixel scores (numpy array or torch tensor), passed to
            ``roc_auc_score`` as ``y_score``.
        label: Ground-truth labels with the same number of elements as
            ``input`` (numpy array or torch tensor), passed as ``y_true``.
        max_fpr: Forwarded unchanged to ``sklearn.metrics.roc_auc_score``;
            ``None`` computes the full AUC.

    Returns:
        The value returned by ``roc_auc_score``.
    """
    # ``detach`` allows tensors that still track gradients to be converted;
    # ``Tensor.numpy()`` refuses them otherwise.
    if isinstance(input, torch.Tensor):
        input = input.detach().cpu().numpy()
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().numpy()
    return roc_auc_score(label.reshape(-1), input.reshape(-1), max_fpr=max_fpr)
def gts_with_connectedComponents(gts):
    """Run OpenCV connected-component labelling on every mask of a batch.

    Each ``gts[i][0]`` is cast to ``uint8`` and passed to
    ``cv2.connectedComponents``.

    Args:
        gts: Numpy array indexed as ``gts[i][0]`` to get the i-th 2-D mask --
            presumably shaped (batch, 1, H, W); TODO confirm with callers.

    Returns:
        A tuple ``(val, result)``. ``val`` stacks the label count OpenCV
        reported for each image. ``result`` stacks the per-image label maps
        along a new leading axis.
    """
    masks = gts.astype(np.uint8)
    counts, label_maps = [], []
    for idx in range(masks.shape[0]):
        n_labels, label_map = cv2.connectedComponents(masks[idx][0])
        label_maps.append(label_map)
        counts.append(n_labels)
    stacked_maps = np.stack(label_maps, axis=0)
    stacked_counts = np.stack(counts, axis=0)
    return stacked_counts, stacked_maps
def all_map_pro(maps, gts):
    """Return the mean per-region overlap of binary maps with labelled regions.

    For every image ``i`` and every foreground label ``v >= 1`` present in
    ``gts[i]``, the overlap is the fraction of that region's pixels that are
    non-zero in ``maps[i]``. The result is the unweighted mean over all regions
    of all images, so each region counts equally regardless of its size.

    Args:
        maps: Numpy batch of prediction maps; any non-zero value counts as
            positive. ``maps[i]`` must broadcast against ``gts[i]``.
        gts: Batch of non-negative integer label maps (for example the second
            output of ``gts_with_connectedComponents``); label 0 is treated as
            background and ignored.

    Returns:
        The mean overlap as a numpy float, or NaN when no image contains any
        foreground pixel.
    """
    overlaps = []
    for i in range(maps.shape[0]):
        label_map = np.asarray(gts[i])
        # Pixel count of every label id (index 0 = background) in one pass,
        # instead of one full-image comparison per region.
        region_sizes = np.bincount(label_map.ravel())
        # Broadcast the same way np.logical_and(map, component) does, so e.g.
        # a (1, H, W) map still pairs with an (H, W) label map.
        hit_mask, hit_labels = np.broadcast_arrays(maps[i] != 0, label_map)
        region_hits = np.bincount(hit_labels[hit_mask], minlength=region_sizes.size)
        # Skip the background and label ids with no pixels; the latter would
        # give 0/0 = NaN and poison the mean.
        present = region_sizes[1:] > 0
        overlaps.append(region_hits[1:][present] / region_sizes[1:][present])
    all_overlaps = np.concatenate(overlaps) if overlaps else np.empty(0)
    if all_overlaps.size == 0:
        # No foreground region anywhere: the mean is undefined.
        return np.float64(np.nan)
    return all_overlaps.mean()
def aupro(preds, gts, max_fpr=0.3):
    """Approximate the normalized area under the PRO curve up to ``max_fpr``.

    Steps:

    1. Take candidate thresholds from ``sklearn.metrics.roc_curve`` on the
       flattened ``preds`` and ``gts``.
    2. Keep only thresholds whose pixel false-positive rate is at most
       ``max_fpr``.
    3. Stride them by ``len(fpr) // 1000`` (at least 1), so long curves are
       thinned and short curves keep every point.
    4. At each kept threshold, binarize ``preds`` and score it with
       ``all_map_pro`` against the connected components of ``gts``.
    5. Integrate the PRO values over FPR with ``sklearn.metrics.auc`` and
       divide by ``max_fpr``.

    Args:
        preds: Per-pixel scores (numpy array or torch tensor) with as many
            elements as ``gts``; presumably (batch, H, W) or (batch, 1, H, W)
            -- TODO confirm with callers.
        gts: Ground-truth masks (numpy array or torch tensor) indexed as
            ``gts[i][0]``, as required by ``gts_with_connectedComponents``.
        max_fpr: Upper FPR limit of the integration and the normalization
            constant. Defaults to 0.3, the previously hard-coded value.

    Returns:
        The area under the sampled PRO-vs-FPR curve divided by ``max_fpr``.

    Raises:
        ValueError: If ``max_fpr`` is not in ``(0, 1]``.
    """
    if not 0 < max_fpr <= 1:
        raise ValueError(f"max_fpr must be in (0, 1], got {max_fpr!r}")
    # ``detach`` allows tensors that still track gradients to be converted.
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    if isinstance(gts, torch.Tensor):
        gts = gts.detach().cpu().numpy()

    fpr, _, thresholds = roc_curve(gts.reshape(-1), preds.reshape(-1))
    keep = fpr <= max_fpr
    fpr = fpr[keep]
    thresholds = thresholds[keep]

    # Thin the curve to keep the per-threshold PRO loop affordable. The stride
    # must be >= 1: with fewer than 1000 points, ``len // 1000`` is 0 and the
    # previous np.arange call raised.
    stride = max(1, len(fpr) // 1000)
    fpr = fpr[::stride]
    thresholds = thresholds[::stride]

    _, labels = gts_with_connectedComponents(gts)
    # A boolean map is enough: all_map_pro only tests pixels for non-zero.
    pros = [all_map_pro(preds >= th, labels) for th in thresholds]
    return auc(fpr, pros) / max_fpr
if __name__ == '__main__':
    # Smoke test for all_map_pro. Image 0 holds four disjoint 10x10 regions;
    # image 1 is empty. The prediction covers half of the top-left region and
    # a quarter of the bottom-right one, so the printed mean overlap should be
    # (0.5 + 0 + 0 + 0.25) / 4 = 0.1875.
    dummy_gt = np.zeros((2, 1, 35, 35), dtype=np.int32)
    dummy_gt[0, 0, 5:15, 5:15] = 1
    dummy_gt[0, 0, 5:15, 20:30] = 1
    dummy_gt[0, 0, 20:30, 5:15] = 1
    dummy_gt[0, 0, 20:30, 20:30] = 1
    # Named pred_map (not ``map``) so the builtin is not rebound at module scope.
    pred_map = np.zeros((2, 35, 35), dtype=np.int32)
    pred_map[0, 5:15, 5:10] = 1
    pred_map[0, 20:25, 25:30] = 1
    _, com_gt = gts_with_connectedComponents(dummy_gt)
    print(all_map_pro(pred_map, com_gt))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。