1 Star 0 Fork 0

Joe/DFANet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
criterion.py 2.00 KB
一键复制 编辑 原始数据 按行查看 历史
shenhuxiang 提交于 2019-04-10 16:42 . .
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable
from loss import OhemCrossEntropy2d
import scipy.ndimage as nd
class CriterionDSN(nn.Module):
    """Cross-entropy loss over two supervised outputs of a model.

    The first prediction is weighted 1.0 and the second 0.4; both are
    bilinearly resized to the target's spatial size before the loss is
    computed.

    Args:
        ignore_index: Label value excluded from the cross-entropy loss.
        use_weight: Accepted for interface compatibility; not used here.
        reduce: When falsy, a notice is printed. The underlying criterion
            is always built with ``reduction='mean'`` regardless of this
            flag.
    """

    def __init__(self, ignore_index=255, use_weight=True, reduce=True):
        super(CriterionDSN, self).__init__()
        self.ignore_index = ignore_index
        self.criterion = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            reduction='mean',
        )
        if not reduce:
            print("disabled the reduce.")

    def _resized_loss(self, pred, target, out_size):
        """Resize ``pred`` to ``out_size`` and score it against ``target``."""
        resized = F.interpolate(
            input=pred,
            size=out_size,
            mode='bilinear',
            align_corners=True,
        )
        return self.criterion(resized, target)

    def forward(self, preds, target):
        """Return ``loss(preds[0]) + 0.4 * loss(preds[1])``.

        Args:
            preds: Indexable holding at least two prediction tensors; only
                ``preds[0]`` and ``preds[1]`` are read.
            target: Label tensor whose dims 1 and 2 are taken as the output
                height and width (assumes an (N, H, W) layout - confirm
                against the data loader).
        """
        out_size = (target.size(1), target.size(2))
        main_term = self._resized_loss(preds[0], target, out_size)
        aux_term = self._resized_loss(preds[1], target, out_size)
        return main_term + aux_term * 0.4
class CriterionOhemDSN(nn.Module):
    """Two-term supervised loss: an OHEM criterion plus cross-entropy.

    ``preds[0]`` is scored with ``OhemCrossEntropy2d`` (weight 1.0) and
    ``preds[1]`` with a plain ``CrossEntropyLoss`` (weight 0.4). Both
    predictions are bilinearly resized to the target's spatial size first.

    Args:
        ignore_index: Label value passed to both criteria as the index to
            ignore.
        thresh: Forwarded as the second positional argument of
            ``OhemCrossEntropy2d`` (its meaning is defined in ``loss.py``).
        min_kept: Forwarded as the third positional argument of
            ``OhemCrossEntropy2d`` (its meaning is defined in ``loss.py``).
        use_weight: Accepted for interface compatibility; not used here.
        reduce: If true the cross-entropy term is mean-reduced; otherwise
            it is left unreduced (``reduction='none'``).
    """

    def __init__(self, ignore_index=255, thresh=0.7, min_kept=100000, use_weight=True, reduce=True):
        super(CriterionOhemDSN, self).__init__()
        self.ignore_index = ignore_index
        self.criterion1 = OhemCrossEntropy2d(ignore_index, thresh, min_kept)
        # The legacy `reduce=` keyword of CrossEntropyLoss is deprecated and
        # emits a warning; with the default size_average it maps to
        # 'mean' (True) / 'none' (False), which is spelled out here.
        self.criterion2 = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index,
            reduction='mean' if reduce else 'none',
        )

    def forward(self, preds, target):
        """Return ``criterion1(preds[0]) + 0.4 * criterion2(preds[1])``.

        Args:
            preds: Indexable holding at least two prediction tensors; only
                ``preds[0]`` and ``preds[1]`` are read.
            target: Label tensor whose dims 1 and 2 are taken as the output
                height and width (assumes an (N, H, W) layout - confirm
                against the data loader).

        Note:
            With ``reduce=False`` the second term is not a scalar, so the
            sum relies on broadcasting against whatever ``criterion1``
            returns - verify the resulting shape at the call site.
        """
        h, w = target.size(1), target.size(2)

        # Main output, scored with the OHEM criterion.
        scale_pred = F.interpolate(input=preds[0], size=(h, w), mode='bilinear', align_corners=True)
        loss1 = self.criterion1(scale_pred, target)

        # Second output, scored with plain cross-entropy at weight 0.4.
        scale_pred = F.interpolate(input=preds[1], size=(h, w), mode='bilinear', align_corners=True)
        loss2 = self.criterion2(scale_pred, target)

        return loss1 + loss2 * 0.4
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhou_rx/DFANet.git
git@gitee.com:zhou_rx/DFANet.git
zhou_rx
DFANet
DFANet
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385