1 Star 0 Fork 0

gvraky/EAST3

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
loss.py 1.98 KB
一键复制 编辑 原始数据 按行查看 历史
dejiasong 提交于 2018-08-29 15:30 . Add files via upload
import torch
from torch.autograd import Variable
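### Ground-truth and predicted maps are assumed to be laid out as (batch, channels, H, W): LossFunc.forward splits the geometry maps along dim 1 (channels-first, not bs * W * H * channels).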
import torch
import torch.nn as nn
def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask, eps=1e-5):
    """Return the dice *loss* (1 - dice coefficient) between two masked maps.

    The three tensors are multiplied element-wise, so they must be
    broadcast-compatible (presumably the same shape -- TODO confirm against
    callers). Elements where ``training_mask`` is 0 contribute nothing.

    :param y_true_cls: ground-truth score map.
    :param y_pred_cls: predicted score map.
    :param training_mask: mask/weight applied to both maps.
    :param eps: smoothing term added to the denominator so the result stays
        finite when both masked maps sum to zero (default 1e-5, the value
        that was previously hard-coded).
    :return: scalar tensor
        ``1 - 2 * sum(t * p * m) / (sum(t * m) + sum(p * m) + eps)``.
    """
    intersection = torch.sum(y_true_cls * y_pred_cls * training_mask)
    union = (torch.sum(y_true_cls * training_mask)
             + torch.sum(y_pred_cls * training_mask)
             + eps)
    return 1. - (2 * intersection / union)
class LossFunc(nn.Module):
    """Detection loss: dice loss on the score map plus an IoU-style box loss
    and an angle loss on the geometry map.

    :param cls_weight: multiplier applied to the dice classification loss
        (default 0.01, the previously hard-coded value, used to scale the
        classification loss to match the IoU loss part).
    :param theta_weight: multiplier applied to the angle loss term
        (default 20, the previously hard-coded value).
    """

    def __init__(self, cls_weight=0.01, theta_weight=20.0):
        super(LossFunc, self).__init__()
        self.cls_weight = cls_weight
        self.theta_weight = theta_weight

    def forward(self, y_true_cls, y_pred_cls, y_true_geo, y_pred_geo, training_mask):
        """Compute the combined loss.

        :param y_true_cls: ground-truth score map.
        :param y_pred_cls: predicted score map.
        :param y_true_geo: ground-truth geometry map; dim 1 must have exactly
            5 channels (d1 top, d2 right, d3 bottom, d4 left, theta), otherwise
            the 5-way unpack of ``torch.split`` fails.
        :param y_pred_geo: predicted geometry map, same layout as y_true_geo.
        :param training_mask: mask applied to both the dice loss and the
            geometry loss.
        :return: scalar tensor, mean masked geometry loss + weighted dice loss.
        """
        classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
        # Scale the classification loss to match the IoU loss part.
        # Out-of-place, so the tensor produced inside the autograd graph is not mutated.
        classification_loss = classification_loss * self.cls_weight

        # Split the 5 geometry channels along dim 1 (chunks of size 1).
        d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = torch.split(y_true_geo, 1, 1)
        d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = torch.split(y_pred_geo, 1, 1)

        area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
        area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
        # Width/height of the overlap: per side, the smaller of the two distances.
        w_intersect = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred)
        h_intersect = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred)
        area_intersect = w_intersect * h_intersect
        area_union = area_gt + area_pred - area_intersect

        # +1 smoothing keeps the log finite when both areas are zero.
        L_AABB = -torch.log((area_intersect + 1.0) / (area_union + 1.0))
        L_theta = 1 - torch.cos(theta_pred - theta_gt)
        L_g = L_AABB + self.theta_weight * L_theta

        # Geometry loss only counts where the ground-truth score map and the
        # training mask are non-zero; the mean is taken over all elements.
        return torch.mean(L_g * y_true_cls * training_mask) + classification_loss
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/gvraky/east3.git
git@gitee.com:gvraky/east3.git
gvraky
east3
EAST3
master

搜索帮助