代码拉取完成,页面将自动刷新
同步操作将从 vegee/ESPCN_Learning 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
# -*- coding: UTF-8 -*-
import argparse
import copy
import pandas as pd
import pytorch_ssim
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from data_utils import DatasetFromFolder
from model import Net
from psnrmeter import PSNRMeter
def train(data, target, model, optimizer, loss_fn):
    """Run a single optimisation step on one batch and return the batch loss.

    Args:
        data: input batch fed to ``model``.
        target: ground-truth batch compared against the model output.
        model: the network being trained.
        optimizer: optimizer whose ``step`` updates ``model``'s parameters.
        loss_fn: callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Returns:
        The loss for this batch as a Python number (via ``Tensor.item()``),
        computed before the parameter update is applied.
    """
    # Inputs are moved to the default CUDA device whenever one exists;
    # the model is expected to already live there (the caller does this).
    on_gpu = torch.cuda.is_available()
    inputs = data.cuda() if on_gpu else data
    labels = target.cuda() if on_gpu else target

    batch_loss = loss_fn(model(inputs), labels)

    # Clear stale gradients, backpropagate, then apply the update.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()
def valid(val_data, val_target, model):
    """Evaluate ``model`` on one validation batch and return ``(psnr, ssim)``.

    Args:
        val_data: input batch fed to ``model``.
        val_target: ground-truth batch the prediction is scored against.
        model: the network under evaluation.

    Returns:
        A tuple ``(psnr, ssim)``: ``psnr`` is whatever ``PSNRMeter.value()``
        reports for this batch (presumably a float in dB — TODO confirm), and
        ``ssim`` is ``pytorch_ssim.ssim(...)`` converted with ``.item()``.

    Note:
        This function does not disable gradient tracking itself; callers that
        want inference without autograd must wrap it in ``torch.no_grad()``.
    """
    if torch.cuda.is_available():
        val_data, val_target = val_data.cuda(), val_target.cuda()
    prediction = model(val_data)

    # A fresh meter per call, so the PSNR reflects this batch only.
    meter = PSNRMeter()
    meter.add(prediction, val_target)
    psnr = meter.value()

    # SSIM is computed on CPU copies of both tensors.
    ssim = pytorch_ssim.ssim(prediction.cpu(), val_target.cpu()).item()
    return psnr, ssim
if __name__ == "__main__":
    import os

    # The upscale factor set here must match the one used when the dataset
    # was created during preprocessing.
    parser = argparse.ArgumentParser(description='Train Super Resolution')
    parser.add_argument('--upscale_factor', default=3, type=int, help='super resolution upscale factor')
    parser.add_argument('--num_epochs', default=100, type=int, help='super resolution epochs number')
    opt = parser.parse_args()

    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    train_set = DatasetFromFolder('data/train', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                  target_transform=transforms.ToTensor())
    val_set = DatasetFromFolder('data/val', upscale_factor=UPSCALE_FACTOR, input_transform=transforms.ToTensor(),
                                target_transform=transforms.ToTensor())
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=64, shuffle=False)

    model = Net(upscale_factor=UPSCALE_FACTOR)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()
    print('# parameters:', sum(param.numel() for param in model.parameters()))

    optimizer = optim.Adam(model.parameters(), lr=1e-2)
    # Multiply the learning rate by 0.1 once the epoch counter reaches 30 and again at 80.
    scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    result_loss = []
    result_psnr = []
    result_ssim = []
    best_psnr = 0
    best_ssim = 0
    best_epoch = 1
    # state_dict must be *called*: deep-copying the bound method
    # (``model.state_dict`` without parentheses) copies a method object, not
    # the parameter tensors, and would be saved as such if no epoch qualifies
    # as "best".
    best_weight = copy.deepcopy(model.state_dict())

    for epoch in range(NUM_EPOCHS):
        # ---- training pass ----
        model = model.train()
        train_bar = tqdm(train_loader)
        loss_epoch = 0
        trainbatchnum = 0
        for train_data, train_label in train_bar:
            loss_bat = train(train_data, train_label, model, optimizer, criterion)
            loss_epoch += loss_bat
            trainbatchnum += 1
            train_bar.set_description(
                desc='Train[%d/%d] Loss:%.5f' % ((epoch + 1), NUM_EPOCHS, loss_epoch / trainbatchnum))

        # ---- validation pass (no autograd) ----
        model = model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            psnr_epoch = 0
            ssim_epoch = 0
            valbatch = 0
            for val_data, val_label in val_bar:
                psnr_batch, ssim_batch = valid(val_data, val_label, model)
                psnr_epoch += psnr_batch
                ssim_epoch += ssim_batch
                valbatch += 1
                val_bar.set_description('psnr: %.5f dB, ssim: %.5f' % (psnr_epoch / valbatch, ssim_epoch / valbatch))

        # Per-epoch means; guard against an empty loader instead of raising
        # ZeroDivisionError.
        mean_loss = loss_epoch / trainbatchnum if trainbatchnum else 0.0
        mean_psnr = psnr_epoch / valbatch if valbatch else 0.0
        mean_ssim = ssim_epoch / valbatch if valbatch else 0.0

        # An epoch counts as "best" only if it improves PSNR and SSIM together.
        if mean_psnr > best_psnr and mean_ssim > best_ssim:
            best_epoch = epoch + 1
            best_psnr = mean_psnr
            best_ssim = mean_ssim
            best_weight = copy.deepcopy(model.state_dict())

        result_loss.append(mean_loss)
        result_psnr.append(mean_psnr)
        result_ssim.append(mean_ssim)
        scheduler.step()

    # Also keep the final-epoch weights: the loss can keep decreasing even
    # when the validation metrics are no longer at their best.
    last_weight = copy.deepcopy(model.state_dict())

    data_frame = pd.DataFrame(
        data={'Loss': result_loss, 'ValidSet_PSNR': result_psnr, 'ValidSet_SSIM': result_ssim},
        index=range(1, NUM_EPOCHS + 1))
    data_frame.to_csv('SRF_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
    print("数据保存成功!")

    # torch.save does not create missing directories.
    os.makedirs("epochs", exist_ok=True)
    torch.save(best_weight, "epochs/ESPCN_best({}).pt".format(best_epoch))
    torch.save(last_weight, "epochs/ESPCN_last.pt")
    print("模型保存成功!最佳数据——epoch{}, psnr: {:.5f}dB, ssim: {:.5f}".format(best_epoch, best_psnr, best_ssim))
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。