1 Star 0 Fork 0

ideaoverflow/CRAFT-Reimplementation

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
mseloss.py 2.15 KB
一键复制 编辑 原始数据 按行查看 历史
严海 提交于 2019-09-10 11:27 . update
import numpy as np
import torch
import torch.nn as nn
class Maploss(nn.Module):
    """Masked per-pixel MSE loss with online hard-negative mining.

    Applied separately to the character heatmap (``gh``) and the affinity
    heatmap (``gah``). For each image, pixels whose ground-truth score is at
    least ``positive_threshold`` are positives and pixels below it are
    negatives. The image's loss is the mean loss over its positives plus the
    mean of its hardest negatives, keeping at most ``neg_pos_ratio``
    negatives per positive pixel. An image with no positive pixel instead
    contributes the mean of its ``num_hard_negatives`` largest pixel losses.

    Args:
        use_gpu: Ignored; accepted for backward compatibility.
        positive_threshold: Ground-truth score at or above which a pixel
            counts as positive (default 0.1).
        neg_pos_ratio: Maximum number of mined negatives per positive pixel
            (default 3).
        num_hard_negatives: Number of top-loss pixels averaged for an image
            with no positive pixel, capped at its pixel count (default 500).
    """

    def __init__(self, use_gpu=True, positive_threshold=0.1, neg_pos_ratio=3,
                 num_hard_negatives=500):
        super(Maploss, self).__init__()
        self.positive_threshold = positive_threshold
        self.neg_pos_ratio = neg_pos_ratio
        self.num_hard_negatives = num_hard_negatives

    def single_image_loss(self, pre_loss, loss_label):
        """Sum the hard-negative-mined loss of every image in the batch.

        Args:
            pre_loss: Per-pixel loss tensor whose first dimension is the
                batch; trailing dimensions are flattened per image.
            loss_label: Ground-truth scores with the same number of elements
                per image as ``pre_loss``; used only to split positives from
                negatives.

        Returns:
            A 0-dim tensor holding the SUM (not the mean) over images of the
            per-image losses described in the class docstring.
        """
        batch_size = pre_loss.shape[0]
        # reshape (rather than view) also accepts non-contiguous inputs.
        pre_loss = pre_loss.reshape(batch_size, -1)
        loss_label = loss_label.reshape(batch_size, -1)
        # Plain zero with pre_loss's dtype/device; adding the grad-carrying
        # per-image terms below attaches the result to the autograd graph.
        sum_loss = pre_loss.new_zeros(())
        for img_loss, img_label in zip(pre_loss, loss_label):
            pos_loss = img_loss[img_label >= self.positive_threshold]
            # Compared explicitly instead of negating the positive mask so
            # NaN labels stay excluded from both sets.
            neg_loss = img_loss[img_label < self.positive_threshold]
            num_pos = pos_loss.numel()
            if num_pos > 0:
                sum_loss = sum_loss + pos_loss.mean()
                num_neg = min(neg_loss.numel(),
                              int(self.neg_pos_ratio * num_pos))
                if num_neg == 0:
                    # Nothing to mine (e.g. every pixel is positive); the
                    # mean of an empty tensor would be NaN.
                    continue
                if num_neg == neg_loss.numel():
                    # Few enough negatives: use all of them.
                    sum_loss = sum_loss + neg_loss.mean()
                else:
                    # Keep only the hardest (largest-loss) negatives.
                    sum_loss = sum_loss + torch.topk(neg_loss, num_neg)[0].mean()
            else:
                # No positives: penalise the top-loss pixels, never asking
                # topk for more elements than the image has.
                k = min(int(self.num_hard_negatives), img_loss.numel())
                if k > 0:
                    sum_loss = sum_loss + torch.topk(img_loss, k)[0].mean()
        return sum_loss

    def forward(self, gh_label, gah_label, p_gh, p_gah, mask):
        """Return the character + affinity loss, each averaged over the batch.

        Args:
            gh_label: Ground-truth character heatmap.
            gah_label: Ground-truth affinity heatmap.
            p_gh: Predicted character heatmap; must match ``gh_label`` size.
            p_gah: Predicted affinity heatmap; must match ``gah_label`` size.
            mask: Per-pixel weight multiplied into both squared-error maps;
                must broadcast against them. Presumably 0 marks ignored
                pixels — confirm with the caller.

        Returns:
            A 0-dim tensor: (character loss sum + affinity loss sum) divided
            by the batch size.

        Raises:
            ValueError: If a prediction's size differs from its label's.
        """
        if p_gh.size() != gh_label.size() or p_gah.size() != gah_label.size():
            raise ValueError(
                "prediction/label size mismatch: p_gh %s vs gh_label %s, "
                "p_gah %s vs gah_label %s"
                % (tuple(p_gh.size()), tuple(gh_label.size()),
                   tuple(p_gah.size()), tuple(gah_label.size())))
        # reduction='none' replaces the deprecated reduce=False /
        # size_average=False pair and keeps the per-pixel loss map.
        loss_g = nn.functional.mse_loss(p_gh, gh_label, reduction='none') * mask
        loss_a = nn.functional.mse_loss(p_gah, gah_label, reduction='none') * mask
        char_loss = self.single_image_loss(loss_g, gh_label)
        affi_loss = self.single_image_loss(loss_a, gah_label)
        return char_loss / loss_g.shape[0] + affi_loss / loss_a.shape[0]
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ideaoverflow/CRAFT-Reimplementation.git
git@gitee.com:ideaoverflow/CRAFT-Reimplementation.git
ideaoverflow
CRAFT-Reimplementation
CRAFT-Reimplementation
master

搜索帮助