# NOTE(review): the original file began with Gitee web-UI scrape residue (a "code
# pull complete, page will refresh" banner and a forced-sync warning). That text was
# not Python source; it is replaced by this comment so the file can parse.
from __future__ import print_function
# 使python2.x的print语法与python3.x的print规则一样
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
from torch.autograd import Variable
import time
import numpy as np
from numpy import *
from data_loader.dataset import train_dataset
from models.u_net import UNet
from models.seg_net import Segnet
from models.fcn import FCN8s, VGGNet
#####################################################################################
# ToDo:是否使用多通道Capsule
#####################################################################################
# from models.capsule import CapsuleNet, CapsuleLoss
from models.multi_capsule import CapsuleNet, CapsuleLoss
from utils.metrics import Evaluator
# ---------------------------------------------------------------------------
# Command-line options, reproducibility seeding, and the training data pipeline.
# Everything here runs at import time (module-level side effects: makedirs,
# seeding, DataLoader construction).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Training a RS_Semantic_Segmentation model')
parser.add_argument('--batch_size', type=int, default=4, help='equivalent to instance normalization with batch_size=1')
parser.add_argument('--input_nc', type=int, default=3)
parser.add_argument('--output_nc', type=int, default=1)
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
# NOTE(review): argparse `type=bool` is a known pitfall — any non-empty string
# (including "False") parses as True; only the default matters in practice.
parser.add_argument('--cuda', type=bool,default=True, help='enables cuda. default=True')
parser.add_argument('--manual_seed', type=int, default=2021, help='manual seed')  # manual random seed
parser.add_argument('--num_workers', type=int, default=0, help='how many threads of cpu to use while loading data')
parser.add_argument('--size_w', type=int, default=256, help='scale image to this size')
parser.add_argument('--size_h', type=int, default=256, help='scale image to this size')
parser.add_argument('--flip', type=int, default=0, help='1 for flipping image randomly, 0 for not')
parser.add_argument('--net', type=str, default='', help='path to pre-trained network')
parser.add_argument('--data_path', default='./data/train', help='path to training images')
parser.add_argument('--val_data_path', default='./data/val', help='path to validation images')
#####################################################################################
# ToDo:choose the model which will be trained
#####################################################################################
#parser.add_argument('--outf', default='./checkpoint/Unet', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/Segnet', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/FCN', help='folder to output images and model checkpoints')
#parser.add_argument('--outf', default='./checkpoint/Capsule', help='folder to output images and model checkpoints')
parser.add_argument('--outf', default='./checkpoint/MultiCapsule', help='folder to output images and model checkpoints')
parser.add_argument('--save_epoch', default=1, help='number of epoch to save parameters')
parser.add_argument('--test_step', default=300, help='number of step to eval model')
parser.add_argument('--log_step', default=1, help='number of step to write log')
parser.add_argument('--num_GPU', default=1, help='number of GPU')
opt = parser.parse_args()
# Create the output/checkpoint directories; ignore "already exists".
try:
    os.makedirs(opt.outf)
    os.makedirs(opt.outf + '/model/')
except OSError:
    pass
# NOTE(review): this branch is dead — --manual_seed defaults to 2021 and
# argparse never produces None for it, so the random fallback never runs.
if opt.manual_seed is None:
    opt.manual_seed = random.randint(1, 10000)
# Seed Python and (CPU) torch RNGs for reproducibility.
random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
# Let cuDNN auto-tune convolution algorithms (fixed input size, so this helps).
cudnn.benchmark = True
# Training dataset and loader (train_dataset is a project-local class;
# presumably yields (image, label, name) triples — see the training loop below).
train_datatset_ = train_dataset(opt.data_path, opt.size_w, opt.size_h, opt.flip)
train_loader = torch.utils.data.DataLoader(dataset=train_datatset_, batch_size=opt.batch_size, shuffle=True,
                                           num_workers=opt.num_workers)
def weights_init(m):
    """DCGAN-style layer initializer, meant to be used via ``net.apply(weights_init)``.

    Conv-like layers get weights ~ N(0, 0.02); BatchNorm layers get
    weights ~ N(1, 0.02) and zeroed biases. Other layer types are untouched.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        #####################################################################################
        # ToDo: for capsule layers, do NOT run m.bias.data.fill_(0)
        #####################################################################################
        # m.bias.data.fill_(0)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
#####################################################################################
# ToDo:init the model
#####################################################################################
# net = UNet(opt.input_nc, opt.output_nc)
# net = Segnet(opt.input_nc, opt.output_nc)
# net = FCN8s(pretrained_net=VGGNet(pretrained=False),n_class=opt.output_nc)
# Multi-channel capsule network (project-local); num_parts=5 — presumably the
# number of part capsules, verify against models/multi_capsule.py.
net = CapsuleNet(num_parts=5)
# Resume from a checkpoint if --net was given, otherwise apply random init.
if opt.net != '':
    net.load_state_dict(torch.load(opt.net))
#####################################################################################
# ToDo: FCN must not be initialized with weights_init below
#####################################################################################
else:
    net.apply(weights_init)
if opt.cuda:
    net.cuda()
if opt.num_GPU > 1:
    net = nn.DataParallel(net)
########### LOSS & OPTIMIZER ##########
# criterion = nn.BCELoss()
#####################################################################################
# ToDo:choose the capsule loss function
#####################################################################################
# CapsuleLoss is project-local; takes (prediction, part_map, target) — see the
# training loop below.
criterion = CapsuleLoss(height=256,width=256)
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
########### GLOBAL VARIABLES ###########
# Reusable input/label buffers; the training loop resize_()s and copy_()s each
# mini-batch into them.
initial_image = torch.FloatTensor(opt.batch_size, opt.input_nc, opt.size_w, opt.size_h)
semantic_image = torch.FloatTensor(opt.batch_size, opt.input_nc, opt.size_w, opt.size_h)
# NOTE(review): Variable() is a deprecated no-op on modern PyTorch (>=0.4);
# kept for compatibility with the original code.
initial_image = Variable(initial_image)
semantic_image = Variable(semantic_image)
if opt.cuda:
    initial_image = initial_image.cuda()
    semantic_image = semantic_image.cuda()
def caculate_miou_pa(model, data_path=opt.val_data_path):
    """Evaluate *model* on the dataset at *data_path* and return (MIoU, PA).

    Puts the model in eval mode, runs one pass over the data under
    ``torch.no_grad()`` (no gradients needed, saves memory), accumulates a
    2-class confusion matrix, then restores train mode before returning.

    Note: the (misspelled) public name is kept unchanged for existing callers.
    """
    model.eval()
    # Renamed from `eval` — that shadowed the `eval` builtin.
    evaluator = Evaluator(2)  # binary segmentation: background / foreground
    # batch_size = opt.batch_size
    batch_size = 2
    dataset_ = train_dataset(data_path, opt.size_w, opt.size_h, opt.flip)
    data_loader = torch.utils.data.DataLoader(dataset=dataset_,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=opt.num_workers)
    with torch.no_grad():
        # BUGFIX: the original used `iter(...).next()`, which is Python-2-only
        # and raises AttributeError on Python 3 / modern PyTorch loaders.
        # Iterating the DataLoader directly visits exactly the same batches.
        for initial_image_, semantic_image_, name in data_loader:
            if opt.cuda:
                initial_image_ = initial_image_.cuda()
                semantic_image_ = semantic_image_.cuda()
            #####################################################################################
            # ToDo: mind whether the model is the capsule network (two outputs)
            #####################################################################################
            # semantic_image_pred = model(initial_image_)
            part_map, semantic_image_pred = model(initial_image_)
            # Replicate the single-channel prediction to 3 channels so it matches
            # the 3-channel label tensor element count after flattening.
            semantic_image_pred = torch.cat((semantic_image_pred, semantic_image_pred, semantic_image_pred), dim=1)
            semantic_image_ = semantic_image_.view(-1)
            semantic_image_pred = semantic_image_pred.view(-1)
            # +0.5 then uint8 cast rounds predictions to the nearest class {0, 1}.
            # BUGFIX: use a plain Python float instead of torch.tensor(0.5) — the
            # latter lives on the CPU and mixing it with a CUDA tensor is a
            # device mismatch on older PyTorch versions.
            evaluator.update_matrix(
                semantic_image_.data.cpu().numpy().astype(np.uint8),
                (semantic_image_pred + 0.5).data.cpu().numpy().astype(np.uint8))
    # Mean Intersection-over-Union and Pixel Accuracy from the confusion matrix.
    MIoU = evaluator.Mean_Intersection_over_Union()
    PA = evaluator.Pixel_Accuracy()
    # Restore training mode (re-enables BatchNorm/Dropout behavior) for the caller.
    model.train()
    return MIoU, PA
if __name__ == '__main__':
    # Main training loop: forward, capsule loss, backprop, periodic logging,
    # periodic validation (MIoU/PA), per-epoch checkpointing, and final saves.
    loss_log = []
    MIoU_log = []
    PA_log = []
    #####################################################################################
    # ToDo:setting the dir of log
    #####################################################################################
    # log = open('./checkpoint/Unet/train_Unet_log.txt', 'w')
    # log = open('./checkpoint/Segnet/train_Segnet_log.txt', 'w')
    # log = open('./checkpoint/FCN/train_FCN_log.txt', 'w')
    # log = open('./checkpoint/Capsule/train_Capsule_log.txt', 'w')
    log = open('./checkpoint/MultiCapsule/train_Capsule_log.txt', 'w')
    start = time.time()
    net.train()
    print("start training...")
    for epoch in range(1, opt.niter+1):
        loader = iter(train_loader)
        for i in range(0, train_datatset_.__len__(), opt.batch_size):
            # BUGFIX: `loader.next()` is Python-2-only and raises AttributeError
            # on Python 3 / modern PyTorch; next(loader) works everywhere.
            initial_image_, semantic_image_, name = next(loader)
            # Copy the mini-batch into the pre-allocated (possibly CUDA) buffers.
            initial_image.resize_(initial_image_.size()).copy_(initial_image_)
            semantic_image.resize_(semantic_image_.size()).copy_(semantic_image_)
            # print("semantic_image:", semantic_image.size()) #torch.Size([4, 3, 256, 256])
            # The dataloader already maps annotation values to 0 or 1.
            # print(set(semantic_image.cpu().numpy().reshape(-1)))
            #####################################################################################
            # ToDo: mind which model is being run for inference
            #####################################################################################
            # semantic_image_pred = net(initial_image)
            part_map, semantic_image_pred = net(initial_image)
            #print("semantic_image_pred:",semantic_image_pred.size()) #torch.Size([4, 1, 256, 256])
            # Replicate the single-channel prediction to 3 channels to match the
            # 3-channel label tensor.
            semantic_image_pred = torch.cat((semantic_image_pred, semantic_image_pred, semantic_image_pred), dim=1)
            #print("semantic_image_pred:", semantic_image_pred.size()) #torch.Size([4, 3, 256, 256])
            semantic_image = semantic_image.view(-1)
            semantic_image_pred = semantic_image_pred.view(-1)
            #####################################################################################
            # ToDo:check the type of loss{ BCELoss or CapsuleLoss }
            #####################################################################################
            # loss = criterion(semantic_image_pred, semantic_image)
            loss = criterion(semantic_image_pred, part_map, semantic_image)
            loss_log.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ########### Logging ##########
            if i % opt.log_step == 0:
                print('[%d/%d][%d/%d] Loss: %.4f' %
                      (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item()))
                # BUGFIX: append '\n' — without it every log entry was written
                # onto one endless line in the log file.
                log.write('[%d/%d][%d/%d] Loss: %.4f\n' %
                          (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item()))
            # Periodic validation and sample-image dump.
            if i % opt.test_step == 0:
                MIoU, PA = caculate_miou_pa(net, data_path=opt.val_data_path)
                MIoU_log.append(MIoU)
                PA_log.append(PA)
                print("MIoU:{},PA:{}".format(MIoU, PA))
                vutils.save_image(semantic_image_pred.data.reshape(-1, 3, 256, 256),
                                  opt.outf + '/fake_samples_epoch_%03d_%03d.png' % (epoch, i),
                                  normalize=True)
        # Per-epoch checkpoint.
        if epoch % opt.save_epoch == 0:
            torch.save(net.state_dict(), '%s/model/netG_%s.pth' % (opt.outf, str(epoch)))
    end = time.time()
    # Final model weights and metric histories.
    torch.save(net.state_dict(), '%s/model/netG_final.pth' % opt.outf)
    loss_log = np.array(loss_log)
    np.save('%s/model/loss.npy' % opt.outf, loss_log)
    MIoU_log = np.array(MIoU_log)
    np.save('%s/model/MIoU.npy' % opt.outf, MIoU_log)
    PA_log = np.array(PA_log)
    np.save('%s/model/PA.npy' % opt.outf, PA_log)
    print('Program processed ', end - start, 's, ', (end - start)/60, 'min, ', (end - start)/3600, 'h')
    log.close()
# NOTE(review): the original file ended with a Gitee web-UI content-moderation
# notice (scrape residue, not part of the program); replaced by this comment.