1 Star 0 Fork 1

thb1314/GAN-defect

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
loss.py 860 Bytes
一键复制 编辑 原始数据 按行查看 历史
KEFEI.HU 提交于 2020-04-26 18:40 . add segmentation branch (modified fcn)
from distutils.version import LooseVersion
import torch
import torch.nn.functional as F
import torch.nn as nn
def cross_entropy2d(input, target, weight=None, size_average=True):
# input: (n, c, h, w), target: (n, h, w)
n, c, h, w = input.size()
# log_p: (n, c, h, w)
if LooseVersion(torch.__version__) < LooseVersion('0.3'):
# ==0.2.X
log_p = F.log_softmax(input)
else:
# >=0.3
log_p = F.log_softmax(input, dim=1)
# log_p: (n*h*w, c)
log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
log_p = log_p.view(-1, c)
# target: (n*h*w,)
mask = target >= 0
target = target[mask]
loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
if size_average:
loss /= mask.data.sum()
return loss
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/thb1314/GAN-defect.git
git@gitee.com:thb1314/GAN-defect.git
thb1314
GAN-defect
GAN-defect
master

搜索帮助