1 Star 0 Fork 0

xiaomenshen123/KAIR

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main_train_msrresnet_gan.py 7.91 KB
一键复制 编辑 原始数据 按行查看 历史
Kai Zhang 提交于 2019-12-31 12:17 . Update main_train_msrresnet_gan.py
import os.path
import math
import argparse
import time
import random
import numpy as np
from collections import OrderedDict
import logging
from torch.utils.data import DataLoader
import torch
from utils import utils_logger
from utils import utils_image as util
from utils import utils_option as option
from data.select_dataset import define_Dataset
from models.select_model import define_Model
'''
# --------------------------------------------
# training code for GAN-based model, such as ESRGAN, DPSRGAN
# --------------------------------------------
# Kai Zhang (cskaizhang@gmail.com)
# github: https://github.com/cszn/KAIR
# --------------------------------------------
# https://github.com/xinntao/BasicSR
# --------------------------------------------
'''
def _build_dataloaders(datasets_opt, logger):
    """Create the train/test data loaders described by the dataset options.

    Args:
        datasets_opt: Mapping of phase name (``'train'`` or ``'test'``) to the
            option dict handed to ``define_Dataset``.
        logger: Logger used to report the size of the training set.

    Returns:
        A ``(train_loader, test_loader)`` tuple. An entry is ``None`` when the
        corresponding phase is absent from ``datasets_opt``.

    Raises:
        NotImplementedError: If a phase other than ``'train'``/``'test'`` is
            present.
    """
    train_loader = None
    test_loader = None
    for phase, dataset_opt in datasets_opt.items():
        if phase == 'train':
            train_set = define_Dataset(dataset_opt)
            train_loader = DataLoader(train_set,
                                      batch_size=dataset_opt['dataloader_batch_size'],
                                      shuffle=dataset_opt['dataloader_shuffle'],
                                      num_workers=dataset_opt['dataloader_num_workers'],
                                      drop_last=True,
                                      pin_memory=True)
            # drop_last=True discards the final incomplete batch, so the real
            # number of iterations per epoch is len(train_loader), not
            # ceil(len(train_set) / batch_size).
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), len(train_loader)))
        elif phase == 'test':
            test_set = define_Dataset(dataset_opt)
            test_loader = DataLoader(test_set, batch_size=1,
                                     shuffle=False, num_workers=1,
                                     drop_last=False, pin_memory=True)
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)
    return train_loader, test_loader


def _run_test(model, test_loader, images_root, border, epoch, current_step, logger):
    """Run the model on every test sample, save the estimates and log the PSNR.

    Args:
        model: Model exposing ``feed_data``, ``test`` and ``current_visuals``.
        test_loader: Loader yielding test samples; each sample provides
            ``'L_path'`` (used to name the output files).
        images_root: Directory under which one sub-directory per test image is
            created to hold the saved estimates.
        border: Border value passed to ``util.calculate_psnr``.
        epoch: Current epoch index (only used in the log message).
        current_step: Current training iteration; part of the saved file name.
        logger: Logger receiving per-image and average PSNR lines.
    """
    psnr_sum = 0.0
    num_images = 0
    for num_images, test_data in enumerate(test_loader, start=1):
        image_name_ext = os.path.basename(test_data['L_path'][0])
        img_name, _ext = os.path.splitext(image_name_ext)

        img_dir = os.path.join(images_root, img_name)
        util.mkdir(img_dir)

        model.feed_data(test_data)
        model.test()

        visuals = model.current_visuals()
        E_img = util.tensor2uint(visuals['E'])
        H_img = util.tensor2uint(visuals['H'])

        # Save the estimated image E, tagged with the current iteration.
        save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
        util.imsave(E_img, save_img_path)

        # PSNR between the estimate E and the reference H.
        current_psnr = util.calculate_psnr(E_img, H_img, border=border)
        logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(num_images, image_name_ext, current_psnr))
        psnr_sum += current_psnr

    if num_images == 0:
        # Nothing was evaluated; avoid dividing by zero below.
        logger.warning('Test set is empty; skipping average PSNR.')
        return

    avg_psnr = psnr_sum / num_images
    logger.info('<epoch:{:3d}, iter:{:8,d}, Average PSNR : {:<.2f}dB\n'.format(epoch, current_step, avg_psnr))


def main(json_path='options/train_msrresnet_gan.json'):
    """Train a GAN-based model from a JSON option file.

    The option file path defaults to ``json_path`` and can be overridden on the
    command line with ``-opt``. Training resumes from the most recent G/D
    checkpoints found in ``opt['path']['models']``, then loops over the
    training loader, periodically logging, saving checkpoints and evaluating
    PSNR on the test set.

    Args:
        json_path: Default path of the option JSON file.

    Raises:
        ValueError: If the options define no ``'train'`` dataset.
        NotImplementedError: If the options contain an unknown dataset phase.
    """
    # ----------------------------------------
    # Step 1: prepare options
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default=json_path, help='Path to option JSON file.')

    opt = option.parse(parser.parse_args().opt, is_train=True)
    # Create every configured output directory; 'pretrained' entries are
    # skipped because they point at checkpoint files rather than directories
    # to create.
    util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key))

    # Resume from the latest generator/discriminator checkpoints.
    init_iterG, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G')
    init_iterD, init_path_D = option.find_last_checkpoint(opt['path']['models'], net_type='D')
    opt['path']['pretrained_netG'] = init_path_G
    opt['path']['pretrained_netD'] = init_path_D
    current_step = max(init_iterG, init_iterD)

    # The scale factor doubles as the border passed to the PSNR computation.
    border = opt['scale']

    # Persist the resolved options, then make missing keys read as None.
    option.save(opt)
    opt = option.dict_to_nonedict(opt)

    # Configure the logger.
    logger_name = 'train'
    utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log'))
    logger = logging.getLogger(logger_name)
    logger.info(option.dict2str(opt))

    # Seed every random number generator in use.
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ----------------------------------------
    # Step 2: create the data loaders
    # ----------------------------------------
    train_loader, test_loader = _build_dataloaders(opt['datasets'], logger)
    if train_loader is None:
        # Fail early with a clear message instead of a NameError later on.
        raise ValueError("No 'train' dataset is configured in opt['datasets'].")
    if test_loader is None:
        logger.warning("No 'test' dataset is configured; periodic testing is disabled.")

    # ----------------------------------------
    # Step 3: initialize the model
    # ----------------------------------------
    model = define_Model(opt)
    model.init_train()
    logger.info(model.info_network())
    logger.info(model.info_params())

    # ----------------------------------------
    # Step 4: main training loop
    # ----------------------------------------
    train_opt = opt['train']
    print_every = train_opt['checkpoint_print']
    save_every = train_opt['checkpoint_save']
    test_every = train_opt['checkpoint_test']
    images_root = opt['path']['images']

    for epoch in range(1000000):  # keep running
        for train_data in train_loader:
            current_step += 1

            # 1) update learning rate
            model.update_learning_rate(current_step)

            # 2) feed patch pairs
            model.feed_data(train_data)

            # 3) optimize parameters
            model.optimize_parameters(current_step)

            # 4) training information
            if current_step % print_every == 0:
                logs = model.current_log()  # such as loss
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(epoch, current_step, model.current_learning_rate())
                # Merge the logged quantities into the message.
                message += ''.join('{:s}: {:.3e} '.format(k, v) for k, v in logs.items())
                logger.info(message)

            # 5) save model
            if current_step % save_every == 0:
                logger.info('Saving the model.')
                model.save(current_step)

            # 6) testing
            if test_loader is not None and current_step % test_every == 0:
                _run_test(model, test_loader, images_root, border, epoch, current_step, logger)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
# Script entry point: start training only when this file is run directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/xiaomenshen123/KAIR.git
git@gitee.com:xiaomenshen123/KAIR.git
xiaomenshen123
KAIR
KAIR
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385