
CJLU2021/table-seg

This repository has not declared an open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.
my_train_table_segmentation.py (11.66 KB)
syshensyshen committed on 2021-07-27 21:30 · table segment
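# Training script for binary table segmentation. main() trains a DeepLabV3
# model; main_pannet() trains a PAN-style model. Both combine a balanced
# hard-example-mined BCE loss with soft-Dice (and IoU) terms.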
import math
import os
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

print(sys.path)
print(os.getcwd())

from datasets.table.dataloader import LoadTableImageAndLabels
from models.assembly.segmentation_table import Segmentation_Model
from models.assembly.deeplab import DeepLabV3
from models.assembly.my_pan_model import PanModel
from dice_loss import dice_coeff
def save_tensor(tensor, i, e):
    # Dump the first mask of the batch as an 8-bit grayscale preview image.
    np_array = tensor[0].detach().cpu().numpy().transpose(1, 2, 0)
    # mx, mn = np_array.max(), np_array.min()
    # arr = (np_array - mn) / (mx - mn) * 255
    np_array = np.array(np_array * 255, np.uint8)[:, :, 0]
    # np_array = cv2.resize(np_array, (320, 200))
    cv2.imwrite('results/' + str(e) + '-' + str(i) + '.jpg', np_array)
def _iou(pred, target, size_average=True):
    # Mean soft-IoU loss over the batch; `size_average` is kept for API
    # compatibility but is unused.
    b = pred.shape[0]
    IoU = 0.0
    for i in range(0, b):
        # Compute the soft IoU of the foreground.
        Iand1 = torch.sum(target[i, :, :, :] * pred[i, :, :, :])
        Ior1 = torch.sum(target[i, :, :, :]) + torch.sum(pred[i, :, :, :]) - Iand1
        IoU1 = Iand1 / Ior1
        # The IoU loss of one sample is (1 - IoU1).
        IoU = IoU + (1 - IoU1)
    return IoU / b
class IOU(torch.nn.Module):
    def __init__(self, size_average=True):
        super(IOU, self).__init__()
        self.size_average = size_average

    def forward(self, pred, target):
        return _iou(pred, target, self.size_average)
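# A minimal sanity check for the IoU loss above (toy tensors, not part of the
# original training flow): a perfect prediction gives a loss of 0, and zero
# overlap gives a loss of 1.
def _demo_iou_loss():
    pred = torch.zeros(1, 1, 4, 4)
    pred[0, 0, 1:3, 1:3] = 1.0
    target = pred.clone()
    print(IOU()(pred, target).item())                   # 0.0: perfect match
    print(IOU()(pred, torch.zeros_like(pred)).item())   # 1.0: no overlap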
class BCEFocalLoss(torch.nn.Module):
    def __init__(self, gamma=2, alpha=0.6, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        # Clamp both ends so neither log() term can produce -inf (the original
        # in-place edit only guarded the low end).
        _input = _input.clamp(min=1e-4, max=1 - 1e-4)
        alpha = self.alpha
        loss = - alpha * (1 - _input) ** self.gamma * target * torch.log(_input) - \
               (1 - alpha) * _input ** self.gamma * (1 - target) * torch.log(1 - _input)
        if self.reduction == 'elementwise_mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
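# A toy check of the focal loss above (illustrative only, not in the training
# flow): confident wrong predictions are penalised far more heavily than
# confident correct ones.
def _demo_focal_loss():
    target = torch.tensor([1.0, 0.0])
    good = torch.tensor([0.9, 0.1])   # confident and correct -> small loss
    bad = torch.tensor([0.1, 0.9])    # confident and wrong -> large loss
    fl = BCEFocalLoss()
    print(fl(good, target).item(), fl(bad, target).item())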
def balance_mask(score, label, mask):
    # Online hard-example mining: keep every positive pixel, plus the hardest
    # (highest-scoring) negatives. `mask` marks negatives taken from a
    # few-pixel dilation around the positive regions.
    pos_num = label[label > 0.5].numel()  # number of positive elements
    selected_mask = torch.zeros_like(label)
    if pos_num == 0:
        selected_mask = torch.ones_like(label)
        return selected_mask, selected_mask.numel()
    selected_mask[label > 0.5] = 1.0
    # Alternative: fix the positive:negative ratio at 1:3.
    # neg_num = label[label <= 0.5].numel()
    # neg_num = (int)(min(pos_num * 3, neg_num))
    # Here: take as many negatives as the dilated mask contains.
    neg_num = mask[mask > 0.5].numel()
    if neg_num == 0:
        return selected_mask, pos_num
    neg_score = score[label <= 0.5]
    neg_num = min(neg_num, neg_score.numel())  # guard against over-counting
    neg_score_sorted = torch.sort(-neg_score)
    threshold = -neg_score_sorted[0][neg_num - 1]
    selected_mask[score >= threshold] = 1.0
    return selected_mask, pos_num + neg_num
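# A toy walk-through of balance_mask (illustrative values): one positive
# pixel plus a two-pixel dilated negative band, so three pixels are selected.
def _demo_balance_mask():
    label = torch.tensor([[0., 1., 0., 0.]])
    mask = torch.tensor([[1., 0., 1., 0.]])    # dilated ring around the positive
    score = torch.tensor([[0.8, 0.9, 0.6, 0.1]])
    selected, n = balance_mask(score, label, mask)
    print(selected, n)  # [[1., 1., 1., 0.]], 3: positive + 2 hardest negatives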
def pos_hard_mining(outputs, targets, mask):
    # Approach 1: multiply the per-pixel BCE by the mask, then take the mean.
    # pos_loss = loss.mul(mask).mean()
    balan_mask, total_num = balance_mask(outputs, targets, mask)
    loss = F.binary_cross_entropy(outputs, targets, reduction='none')
    # Approach 2 (used here): scale, apply the balanced OHEM mask, and
    # normalise by the number of selected pixels.
    loss = ((loss * 3).mul(balan_mask).sum() / total_num).mean()
    # loss = (3*((loss * 2).mul(balan_mask).sum() / total_num).mean()+loss.mean()).mean()
    return loss
def hard_mining(outputs, targets):
    # Keep the hardest half of all pixels (largest per-pixel BCE) and average.
    loss = F.binary_cross_entropy(outputs, targets, reduction='none')
    _, topk_loss_inds = loss.reshape(-1).topk(loss.reshape(-1).numel() // 2)
    return loss.reshape(-1)[topk_loss_inds].mean()
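# Toy check for hard_mining (illustrative): with 4 pixels it averages the BCE
# of the 2 worst-predicted ones, so the badly wrong pixel at index 0 dominates.
def _demo_hard_mining():
    outputs = torch.tensor([[0.1, 0.9, 0.8, 0.7]])
    targets = torch.tensor([[1.0, 1.0, 1.0, 1.0]])
    print(hard_mining(outputs, targets).item())  # mean of the top-2 BCE values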
class SoftDiceLoss(torch.nn.Module):
    # Soft Dice loss: 1 minus the mean Dice coefficient over the batch, with
    # additive smoothing.
    def __init__(self, weight=None, size_average=True):
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        num = targets.size(0)
        smooth = 1
        # logits[logits < 0.5] = 0.0
        # probs = F.sigmoid(logits)
        m1 = logits.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)
        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score
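# A small worked example of the soft Dice loss (toy tensors): without
# smoothing, Dice here is 2*1/(2+2) = 0.5; with smooth=1 it is
# 2*(1+1)/(2+2+1) = 0.8, so the loss is 0.2.
def _demo_soft_dice():
    pred = torch.tensor([[1., 1., 0., 0.]])
    target = torch.tensor([[1., 0., 1., 0.]])
    print(SoftDiceLoss()(pred, target).item())  # 0.2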
class TverskyLoss(torch.nn.Module):
    # Tversky loss (https://zhuanlan.zhihu.com/p/103426335): generalises Dice
    # by weighting false negatives against false positives with alpha.
    def __init__(self, weight=None, size_average=True):
        super(TverskyLoss, self).__init__()

    def forward(self, logits, targets):
        num = targets.size(0)
        smooth = 1
        m1 = logits.view(num, -1)
        m2 = targets.view(num, -1)
        true_pos = (m2 * m1).sum(1)
        false_neg = (m2 * (1 - m1)).sum(1)
        false_pos = ((1 - m2) * m1).sum(1)
        alpha = 0.7
        score = (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)
        # Return 1 - Tversky index so that minimising the loss maximises overlap.
        return 1 - score.sum() / num
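# With alpha=0.7 the Tversky loss penalises false negatives more heavily than
# false positives, favouring recall. Toy check (illustrative): one missed
# positive costs ~0.26, one spurious positive only ~0.09.
def _demo_tversky():
    target = torch.tensor([[1., 1., 0., 0.]])
    miss = torch.tensor([[1., 0., 0., 0.]])    # one false negative
    extra = torch.tensor([[1., 1., 1., 0.]])   # one false positive
    tv = TverskyLoss()
    print(tv(miss, target).item(), tv(extra, target).item())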
# Training entry point for the DeepLabV3 model.
def main():
    # data_fd = '/mnt/data/ocr_data/table_data/merge_data_mask_v1'
    data_fd = '/mnt/data/ocr_data/table_data/contract_bank_merge_860_860_v1'
    # data_fd = '/mnt/data/xp/datasets/table/images'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # device = torch.device('cpu')
    dataset = LoadTableImageAndLabels(data_fd)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=24, num_workers=0, shuffle=True, pin_memory=True)
    model = DeepLabV3(1).to(device)
    # model.load_state_dict(torch.load('./ckpt/hard_mining_v1.3_190.pth'))
    loss_fn = BCEFocalLoss()
    loss_softdice = SoftDiceLoss()
    loss_tversky = TverskyLoss()
    epochs = 20000000  # effectively "train until stopped manually"
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
    # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # scheduler.last_epoch = epochs - 1  # do not move
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[64, 128, 168], gamma=0.1)
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[3, 7, 13], gamma=0.1)
    for epoch in range(0, epochs):
        model.train()
        print('epoch: ', epoch)
        for i, (imgs, targets, mask) in enumerate(dataloader):
            imgs = imgs.to(device).float()
            outputs = model(imgs).sigmoid()
            loss1 = pos_hard_mining(outputs, targets.to(device), mask.to(device))
            loss2 = loss_softdice(outputs, targets.to(device))
            loss = (loss1 + loss2 * 3).mean()
            # Earlier epoch-dependent variants (plain BCE, hard_mining,
            # dice_coeff) were tried here; see the experiment log below.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if i % 10 == 0:
                save_tensor(outputs, i, epoch)
                print('loss is:', loss.cpu().item())
        scheduler.step()
        if epoch and epoch % 10 == 0:
            torch.save(model.state_dict(), './ckpt/hard_mining_v1.9_' + str(epoch) + '.pth')
# Experiment log:
# 1.4 take negatives from a few-pixel dilation around the positives
# 1.5 zero out all prediction values below 0.5
# 1.6 replace the loss entirely with SoftDiceLoss
# 1.7 replace the loss entirely with TverskyLoss
# 1.8 loss = 1:3-balanced BCE * 0.8 + SoftDiceLoss * 0.2
# 1.9 loss = 1:3-balanced BCE * 0.5 + SoftDiceLoss * 0.5
# 2.0 loss = 1:3-balanced BCE * 3 + SoftDiceLoss
# 2.1 loss = 1:3-balanced BCE + SoftDiceLoss * 3
# Training entry point for the PAN-style model.
def main_pannet():
    data_fd = '/mnt/data/ocr_data/table_data/table_labelme_v1'
    # data_fd = '/mnt/data/xp/datasets/table/images'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # device = torch.device('cpu')
    dataset = LoadTableImageAndLabels(data_fd)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=10, shuffle=True, pin_memory=True)
    # model = DeepLabV3(1).to(device)
    model = PanModel().to(device)
    # model.load_state_dict(torch.load('./ckpt/pan_net_resnet50_gc_860_v1.5_140.pth'))
    loss_fn = BCEFocalLoss()
    loss_softdice = SoftDiceLoss()
    loss_tversky = TverskyLoss()
    loss_iou = IOU()
    epochs = 20000000  # effectively "train until stopped manually"
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
    # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # scheduler.last_epoch = epochs - 1  # do not move
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 100], gamma=0.1)
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[3, 7, 13], gamma=0.1)
    for epoch in range(0, epochs):
        model.train()
        print('epoch: ', epoch)
        for i, (imgs, targets, mask) in enumerate(dataloader):
            imgs = imgs.to(device).float()
            outputs = model(imgs).sigmoid()
            loss1 = pos_hard_mining(outputs, targets.to(device), mask.to(device))
            loss2 = loss_softdice(outputs, targets.to(device))
            loss3 = loss_iou(outputs, targets.to(device))
            # loss = (loss1 + loss2 * 0).mean()  # hard-example mining only
            loss = (loss1 * 3 + loss2 * 2 + loss3).mean()  # OHEM BCE + dice + IoU
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if i % 10 == 0:
                save_tensor(outputs, i, epoch)
                print('loss is:', loss.cpu().item())
        scheduler.step()
        if epoch and epoch % 10 == 0:
            torch.save(model.state_dict(), './ckpt/pan_net_resnet50_gc_860_v1.5.1_' + str(epoch) + '.pth')
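# A minimal inference sketch (not part of the original script): load a trained
# PAN checkpoint, run one image through it, and binarise the sigmoid output.
# The checkpoint path, the /255 normalisation, and the 0.5 threshold are
# assumptions; the real preprocessing lives in LoadTableImageAndLabels.
def _demo_inference(image_path, ckpt_path='./ckpt/pan_net_resnet50_gc_860_v1.5.1_10.pth'):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = PanModel().to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    img = cv2.imread(image_path)  # HWC, BGR, uint8
    x = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float().to(device) / 255.0
    with torch.no_grad():
        prob = model(x).sigmoid()
    return (prob > 0.5).float()  # binary table mask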
if __name__ == '__main__':
    main_pannet()