# [Web page notice] Code pull complete; the page will refresh automatically.
import argparse
import os
import sys

from torch.backends import cudnn

from config import config, dataset_config, merge_cfg_arg
from dataloder import get_loader
from solver_cycle import Solver_cycleGAN
from solver_makeup import Solver_makeupGAN
def parse_args(argv=None):
    """Define the training command-line interface and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Train GAN')
    # general
    parser.add_argument('--data_path', default='makeup/makeup_final/', type=str, help='training and test data path')
    parser.add_argument('--dataset', default='MAKEUP', type=str, help='dataset name, MAKEUP means two domain, MMAKEUP means multi-domain')
    parser.add_argument('--gpus', default='0', type=str, help='GPU device to train with')
    parser.add_argument('--batch_size', default=1, type=int, help='batch_size')
    parser.add_argument('--vis_step', default=1260, type=int, help='steps between visualization')
    parser.add_argument('--task_name', default='', type=str, help='task name')
    parser.add_argument('--checkpoint', default='', type=str, help='checkpoint to load')
    parser.add_argument('--ndis', default=1, type=int, help='train discriminator steps')
    parser.add_argument('--LR', default=2e-4, type=float, help='Learning rate')
    # NOTE(review): this help text duplicates --epochs; confirm what --decay controls.
    parser.add_argument('--decay', default=0, type=int, help='epochs number for training')
    parser.add_argument('--model', default='makeupGAN', type=str, help='which model to use: cycleGAN/ makeupGAN')
    parser.add_argument('--epochs', default=300, type=int, help='nums of epochs')
    parser.add_argument('--whichG', default='branch', type=str, help='which Generator to choose, normal/branch, branch means two input branches')
    parser.add_argument('--norm', default='SN', type=str, help='normalization of discriminator, SN means spectral normalization, none means no normalization')
    parser.add_argument('--d_repeat', default=3, type=int, help='the repeat Res-block in discriminator')
    parser.add_argument('--g_repeat', default=6, type=int, help='the repeat Res-block in Generator')
    # loss weights
    parser.add_argument('--lambda_cls', default=1.0, type=float, help='the lambda_cls weight')
    # NOTE(review): int unlike the other lambdas, so fractional values are rejected.
    parser.add_argument('--lambda_rec', default=10, type=int, help='lambda_A and lambda_B')
    parser.add_argument('--lambda_his', default=1.0, type=float, help='histogram loss on lips')
    parser.add_argument('--lambda_skin_1', default=0.1, type=float, help='histogram loss on skin equals to lambda_his* lambda_skin')
    parser.add_argument('--lambda_skin_2', default=0.1, type=float, help='histogram loss on skin equals to lambda_his* lambda_skin')
    parser.add_argument('--lambda_eye', default=1.0, type=float, help='histogram loss on eyes equals to lambda_his*lambda_eye')
    parser.add_argument('--content_layer', default='r41', type=str, help='vgg layer we use to output features')
    parser.add_argument('--lambda_vgg', default=5e-3, type=float, help='the param of vgg loss')
    parser.add_argument('--cls_list', default='SYMIX,MAKEMIX', type=str, help='the classes of makeup to train')
    # On/off switches, all defaulting to True. A store_true flag whose default
    # is already True can never be turned off, so each switch also gets a
    # --no_<name> store_false counterpart writing to the same attribute.
    switches = (
        ('direct', 'direct means to add local cosmetic loss at the first, unified training'),
        ('lips', 'whether to finetune lips color'),
        ('skin', 'whether to finetune foundation color'),
        ('eye', 'whether to finetune eye shadow color'),
    )
    for name, help_text in switches:
        parser.add_argument('--' + name, dest=name, action='store_true', default=True, help=help_text)
        parser.add_argument('--no_' + name, dest=name, action='store_false', default=True, help='disable --' + name)
    return parser.parse_args(argv)
def train_net():
    """Build the data loaders and the selected solver, then run training.

    Reads the module-level ``args`` (assigned in the ``__main__`` block),
    ``config`` and ``dataset_config``. ``args.model`` picks the solver class:
    'cycleGAN' -> Solver_cycleGAN, 'makeupGAN' -> Solver_makeupGAN.

    Exits the process with a non-zero status (message on stderr) when
    ``args.model`` names an unsupported model.
    """
    solvers = {
        'cycleGAN': Solver_cycleGAN,
        'makeupGAN': Solver_makeupGAN,
    }
    solver_cls = solvers.get(args.model)
    # Validate the model name before building data loaders so a typo fails
    # fast; sys.exit(str) reports the error with exit status 1 instead of 0.
    if solver_cls is None:
        sys.exit("model that not support")
    # enable cuDNN benchmark mode (lets cuDNN pick convolution algorithms)
    cudnn.benchmark = True
    # per the original note, get_loader returns train & test loaders
    data_loaders = get_loader(dataset_config, config, mode="train")
    solver = solver_cls(data_loaders, config, dataset_config)
    solver.train()
if __name__ == '__main__':
    # `args` and `config` are bound at module level on purpose: train_net()
    # reads them as globals.
    args = parse_args()
    print("Call with args:")
    print(args)
    config = merge_cfg_arg(config, args)
    dataset_config.name = args.dataset
    print("The config is:")
    print(config)
    # Abort with a non-zero exit status when the configured data root is
    # missing; nothing is created here.
    if not os.path.exists(config.data_path):
        sys.exit("No datapath!!")
    if args.data_path:
        # NOTE(review): if merge_cfg_arg copies args.data_path into
        # config.data_path, this join repeats the path — confirm.
        dataset_config.dataset_path = os.path.join(config.data_path, args.data_path)
    train_net()
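# [Web page notice] Content here may be unsuitable for display and has been hidden by the page. You can review and edit it with the editing tools.
# If you confirm the content does not involve inappropriate language, pure advertising/traffic diversion, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or content that violates national laws and regulations, you can click submit to appeal and it will be handled as soon as possible.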