1 Star 0 Fork 1

stawary/ESPCN_Learning

forked from vegee/ESPCN_Learning 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 4.87 KB
一键复制 编辑 原始数据 按行查看 历史
# -*- coding: UTF-8 -*-
import argparse
import copy
import pandas as pd
import pytorch_ssim
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from data_utils import DatasetFromFolder
from model import Net
from psnrmeter import PSNRMeter
def train(data, target, model, optimizer, loss_fn):
    """Run a single optimisation step on one mini-batch.

    Args:
        data: input batch fed to ``model``. Moved to the GPU when CUDA is
            available.
        target: ground-truth batch compared against the model output. Moved
            to the GPU when CUDA is available.
        model: the network being trained.
        optimizer: optimizer that updates the model parameters in place.
        loss_fn: callable ``loss_fn(prediction, target)`` that must return a
            single-element tensor (``backward()`` and ``item()`` are called
            on it).

    Returns:
        The batch loss as a Python number. It is computed from the weights as
        they were *before* this step's update.
    """
    use_gpu = torch.cuda.is_available()
    inputs = data.cuda() if use_gpu else data
    labels = target.cuda() if use_gpu else target

    batch_loss = loss_fn(model(inputs), labels)

    # Clear gradients left over from the previous step, then backprop and update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()
def valid(val_data, val_target, model):
    """Score the model on one validation batch.

    Autograd is *not* disabled inside this function; callers should put the
    model in eval mode and wrap the call in ``torch.no_grad()`` themselves.

    Args:
        val_data: input batch fed to ``model``. Moved to the GPU when CUDA is
            available.
        val_target: reference batch the prediction is scored against. Moved
            to the GPU when CUDA is available.
        model: the network under evaluation.

    Returns:
        ``(psnr, ssim)``. ``psnr`` is whatever ``PSNRMeter.value()`` reports
        for this single batch (its exact type depends on ``PSNRMeter``).
        ``ssim`` is ``pytorch_ssim.ssim`` evaluated on CPU copies of the
        prediction and target, converted to a Python number.
    """
    if torch.cuda.is_available():
        val_data, val_target = val_data.cuda(), val_target.cuda()
    prediction = model(val_data)

    # A fresh meter per call, so the PSNR covers this batch only.
    meter = PSNRMeter()
    meter.add(prediction, val_target)
    psnr_value = meter.value()

    ssim_value = pytorch_ssim.ssim(prediction.cpu(), val_target.cpu()).item()
    return psnr_value, ssim_value
if __name__ == "__main__":
    import os  # used only here, to create the checkpoint directory

    # The upscale factor set here must match the one used when the dataset
    # was created during preprocessing.
    parser = argparse.ArgumentParser(description='Train Super Resolution')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # Create the checkpoint directory before training. Otherwise torch.save()
    # below fails with FileNotFoundError only after every epoch has run.
    os.makedirs("epochs", exist_ok=True)

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)

    model = Net(upscale_factor=UPSCALE_FACTOR)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    # scheduler.step() is called once per epoch, so the learning rate is
    # multiplied by 0.1 after 30 and again after 80 epochs.
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    # Per-epoch history, written to CSV at the end.
    result_loss = []
    result_psnr = []
    result_ssim = []

    # "Best" epoch = one whose mean validation PSNR *and* SSIM both exceed the
    # current best. Start from the initial weights so that a real state dict
    # is saved even if no epoch qualifies. state_dict must be *called*:
    # deep-copying the bound method would copy the whole module instead.
    best_psnr = 0
    best_ssim = 0
    best_epoch = 1
    best_weight = copy.deepcopy(model.state_dict())

    for epoch in range(NUM_EPOCHS):
        # ---- training pass ----
        model.train()
        train_bar = tqdm(train_loader)
        loss_epoch = 0
        trainbatchnum = 0
        for train_data, train_label in train_bar:
            loss_epoch += train(train_data, train_label, model, optimizer, criterion)
            trainbatchnum += 1
            train_bar.set_description(
                desc='Train[%d/%d] Loss:%.5f' % ((epoch + 1), NUM_EPOCHS, loss_epoch / trainbatchnum))

        # ---- validation pass (no autograd) ----
        model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            psnr_epoch = 0
            ssim_epoch = 0
            valbatch = 0
            for val_data, val_label in val_bar:
                psnr_batch, ssim_batch = valid(val_data, val_label, model)
                psnr_epoch += psnr_batch
                ssim_epoch += ssim_batch
                valbatch += 1
                val_bar.set_description('psnr: %.5f dB, ssim: %.5f' % (psnr_epoch / valbatch, ssim_epoch / valbatch))

        # Epoch figures are means over batches (each batch weighted equally).
        avg_loss = loss_epoch / trainbatchnum
        avg_psnr = psnr_epoch / valbatch
        avg_ssim = ssim_epoch / valbatch

        if avg_psnr > best_psnr and avg_ssim > best_ssim:
            best_epoch = epoch + 1
            best_psnr = avg_psnr
            best_ssim = avg_ssim
            best_weight = copy.deepcopy(model.state_dict())

        result_loss.append(avg_loss)
        result_psnr.append(avg_psnr)
        result_ssim.append(avg_ssim)
        scheduler.step()

    # Also keep the weights from the final epoch. The loss can keep falling
    # while the metrics do not improve, so the last state is saved alongside
    # the best one. It is taken once here instead of being copied every epoch.
    last_weight = copy.deepcopy(model.state_dict())

    data_frame = pd.DataFrame(
        data={'Loss': result_loss, 'ValidSet_PSNR': result_psnr, 'ValidSet_SSIM': result_ssim},
        index=range(1, NUM_EPOCHS + 1))
    data_frame.to_csv('SRF_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
    print("数据保存成功!")

    torch.save(best_weight, "epochs/ESPCN_best({}).pt".format(best_epoch))
    torch.save(last_weight, "epochs/ESPCN_last.pt")
    print("模型保存成功!最佳数据——epoch{}, psnr: {:.5f}dB, ssim: {:.5f}".format(best_epoch, best_psnr, best_ssim))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/stawary/ESPCN_Learning.git
git@gitee.com:stawary/ESPCN_Learning.git
stawary
ESPCN_Learning
ESPCN_Learning
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385