1 Star 0 Fork 2

alex/GAN-defect

forked from Joe/GAN-defect 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 2.14 KB
一键复制 编辑 原始数据 按行查看 历史
KEFEI.HU 提交于 2020-04-28 18:20 . using elu activation function
import numpy as np
from collections import OrderedDict
import torch
class AverageMeter(object):
    """Track the latest value, weighted sum, count and running average.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count``; stays at its previous value while ``count``
            is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average.

        Args:
            val: the new measurement (anything supporting ``*`` and ``/``).
            n: weight of the measurement, e.g. a batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty batch before any real update)
        # would otherwise raise ZeroDivisionError; keep the previous average.
        if self.count:
            self.avg = self.sum / self.count

    def getavg(self):
        """Return the current running average."""
        return self.avg
def label_accuracy_score(label_trues, label_preds, n_class):
    """Return segmentation accuracy metrics from paired label maps.

    A confusion matrix is accumulated over every (true, predicted) pair via
    ``_fast_hist`` (rows index the true class, columns the predicted class),
    and the metrics are derived from it.

    Args:
        label_trues: iterable of ground-truth label arrays (each must
            provide ``.flatten()``).
        label_preds: iterable of predicted label arrays, paired with
            ``label_trues`` by position.
        n_class: number of classes.

    Returns:
        A tuple ``(acc, acc_cls, mean_iu, fwavacc)``:
          - overall accuracy
          - mean per-class accuracy (classes with no true pixels ignored)
          - mean IU (classes with an empty union ignored)
          - frequency-weighted IU
        Ratios are NaN when nothing valid was counted.
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)

    correct = np.diag(hist)
    true_totals = hist.sum(axis=1)
    pred_totals = hist.sum(axis=0)
    total = hist.sum()

    # Every ratio below can hit 0/0 (absent class or empty input); the NaNs
    # are intentional and handled by nanmean / the freq > 0 filter, so the
    # floating-point warnings are silenced for all of them consistently.
    with np.errstate(divide='ignore', invalid='ignore'):
        acc = correct.sum() / total
        acc_cls = np.nanmean(correct / true_totals)
        iu = correct / (true_totals + pred_totals - correct)
        mean_iu = np.nanmean(iu)
        freq = true_totals / total
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
    return acc, acc_cls, mean_iu, fwavacc
def _fast_hist(label_true, label_pred, n_class):
    """Build an ``n_class x n_class`` confusion matrix from flat label arrays.

    Entry ``[t, p]`` counts positions whose true label is ``t`` and whose
    predicted label is ``p``. Positions whose true label lies outside
    ``[0, n_class)`` are dropped.

    Args:
        label_true: 1-D NumPy array of ground-truth labels.
        label_pred: 1-D NumPy array of predicted labels, same length.
            NOTE(review): predictions are not range-checked here; they are
            assumed to lie in ``[0, n_class)`` -- confirm against callers.
        n_class: number of classes.

    Returns:
        Integer array of shape ``(n_class, n_class)``.
    """
    mask = (label_true >= 0) & (label_true < n_class)
    # Cast both sides to int: np.bincount rejects float input, and
    # predictions may arrive with a floating dtype just like the targets.
    true_idx = label_true[mask].astype(int)
    pred_idx = label_pred[mask].astype(int)
    # Encode each (true, pred) pair as one flat index, count, then unfold.
    hist = np.bincount(
        n_class * true_idx + pred_idx, minlength=n_class ** 2
    ).reshape(n_class, n_class)
    return hist
def modify_checkpoint(model, checkpoint):
    """Filter a checkpoint down to the entries compatible with ``model``.

    An entry is kept only when its key exists in ``model.state_dict()`` and
    its ``.shape`` equals the shape of the model's entry under that key.
    The input mapping is not modified. The kept keys are printed.

    Args:
        model: object exposing ``state_dict()`` that returns a mapping from
            parameter name to an object with a ``.shape``.
        checkpoint: mapping from parameter name to an object with a
            ``.shape``.

    Returns:
        An ``OrderedDict`` of the compatible entries, in checkpoint order.
    """
    model_state_dict = model.state_dict()
    new_ckpt = OrderedDict()
    for k, v in checkpoint.items():
        # Skip parameters the model does not have at all.
        if k not in model_state_dict:
            continue
        # Skip parameters whose shape differs from the model's.
        if v.shape != model_state_dict[k].shape:
            continue
        new_ckpt[k] = v
    print(new_ckpt.keys())
    return new_ckpt
def elu(x, alpha):
    """Apply the ELU activation: ``x`` if ``x >= 0`` else ``alpha * (exp(x) - 1)``.

    Args:
        x: a ``torch.Tensor`` of any shape (handled element-wise) or a plain
            Python number.
        alpha: scale applied to the negative branch.

    Returns:
        A tensor of the same shape as ``x`` when ``x`` is a tensor, otherwise
        a plain Python number.
    """
    if not torch.is_tensor(x):
        # Plain number: torch.exp does not accept it, so use math instead.
        import math
        return x if x >= 0 else alpha * (math.exp(x) - 1)

    # A Python ``if`` on a multi-element tensor is ambiguous and raises, so
    # select per element. The exponent is clamped to <= 0 so the unselected
    # branch never overflows to inf (which would poison gradients with NaN).
    negative_branch = alpha * (torch.exp(torch.clamp(x, max=0)) - 1)
    return torch.where(x >= 0, x, negative_branch)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/hyx007/GAN-defect.git
git@gitee.com:hyx007/GAN-defect.git
hyx007
GAN-defect
GAN-defect
master

搜索帮助