# [Web page notice, not part of the program] Code pull complete; the page will refresh automatically.
import torch
import argparse
import os
import numpy as np
from torch.backends import cudnn
from model import model
from config.config import cfg, cfg_from_file, cfg_from_list
from prepare_data import *
import sys
import pprint
import platform
import pathlib
def parse_args():
    """Build the training script's command-line parser and parse ``sys.argv``.

    When the script is started without any command-line argument the usage
    text is printed and the process exits with status 1.

    Returns:
        argparse.Namespace: carries the attributes ``weights``, ``resume``,
        ``cfg_file``, ``set_cfgs``, ``method`` and ``exp_name``.
    """
    # One entry per option, in the order they should appear in --help:
    # (flag, destination attribute, help text, remaining add_argument kwargs).
    option_specs = (
        ('--weights', 'weights',
         'initialize with specified model parameters',
         {'default': None, 'type': str}),
        ('--resume', 'resume',
         'initialize with saved solver status',
         {'default': None, 'type': str}),
        ('--cfg', 'cfg_file',
         'optional config file',
         {'default': "./experiments/config/Office-31/CAN/office31_train_amazon2dslr_cfg.yaml",
          'type': str}),
        # REMAINDER: everything after --set is collected verbatim as a list.
        ('--set', 'set_cfgs',
         'set config keys',
         {'default': None, 'nargs': argparse.REMAINDER}),
        # NOTE(review): this default has no matching branch in train() in
        # this file, which raises NotImplementedError for it -- confirm the
        # intended default method.
        ('--method', 'method',
         'set the method to use',
         {'default': 'demoDomainTarget', 'type': str}),
        ('--exp_name', 'exp_name',
         'the experiment name',
         {'default': 'office31_a2', 'type': str}),
    )

    cli = argparse.ArgumentParser(description='Train script.')
    for flag, dest, help_text, extra in option_specs:
        cli.add_argument(flag, dest=dest, help=help_text, **extra)

    # A bare invocation is treated as a request for usage information.
    if len(sys.argv) == 1:
        cli.print_help()
        sys.exit(1)

    return cli.parse_args()
def train(args):
    """Prepare data, network and solver for ``args.method`` and run training.

    The checkpoint to start from is taken from the global ``cfg``:
    ``cfg.RESUME`` (full solver state) takes precedence over ``cfg.WEIGHTS``
    (model parameters plus ``bn_domain_map``); when both are empty strings
    the network is built with ``fx_pretrained=True``.

    Args:
        args: Parsed command-line namespace; only ``args.method`` is read.

    Raises:
        NotImplementedError: if ``args.method`` is not one of ``'CAN'``,
            ``'MMD'``, ``'SingleDomainSource'``, ``'SingleDomainTarget'``
            or ``'demo'``.
    """
    bn_domain_map = {}

    # method-specific setting
    # Solver modules are imported inside the branches so that only the
    # selected method's solver is loaded.
    if args.method == 'CAN':
        from solver.can_solver import CANSolver as Solver
        dataloaders = prepare_data_CAN()
        num_domains_bn = 2

    elif args.method == 'MMD':
        from solver.mmd_solver import MMDSolver as Solver
        dataloaders = prepare_data_MMD()
        num_domains_bn = 2

    elif args.method == 'SingleDomainSource':
        from solver.single_domain_solver import SingleDomainSolver as Solver
        dataloaders = prepare_data_SingleDomainSource()
        num_domains_bn = 1

    elif args.method == 'SingleDomainTarget':
        from solver.single_domain_solver import SingleDomainSolver as Solver
        dataloaders = prepare_data_SingleDomainTarget()
        num_domains_bn = 1

    elif args.method == 'demo':
        from solver.demo_solver import DemoSolver as Solver
        print("args.method:", args.method)
        dataloaders = prepare_data_demo()
        num_domains_bn = 2

    else:
        raise NotImplementedError("Currently don't support the specified method: %s."
                                  % args.method)

    # initialize model
    model_state_dict = None
    fx_pretrained = True
    resume_dict = None

    # Without a visible GPU, remap checkpoint tensors to the CPU so that a
    # checkpoint saved from a CUDA run can still be deserialized. With a GPU
    # present this is None, i.e. torch.load's default behaviour.
    map_location = None if torch.cuda.is_available() else 'cpu'

    if cfg.RESUME != '':
        resume_dict = torch.load(cfg.RESUME, map_location=map_location)
        model_state_dict = resume_dict['model_state_dict']
        fx_pretrained = False
    elif cfg.WEIGHTS != '':
        param_dict = torch.load(cfg.WEIGHTS, map_location=map_location)
        model_state_dict = param_dict['weights']
        bn_domain_map = param_dict['bn_domain_map']
        fx_pretrained = False

    net = model.danet(num_classes=cfg.DATASET.NUM_CLASSES,
                      state_dict=model_state_dict,
                      feature_extractor=cfg.MODEL.FEATURE_EXTRACTOR,
                      frozen=[cfg.TRAIN.STOP_GRAD],
                      fx_pretrained=fx_pretrained,
                      dropout_ratio=cfg.TRAIN.DROPOUT_RATIO,
                      fc_hidden_dims=cfg.MODEL.FC_HIDDEN_DIMS,
                      num_domains_bn=num_domains_bn)

    # net = torch.nn.DataParallel(net)  # multi-GPU training (disabled)
    if torch.cuda.is_available():
        net.cuda()

    # initialize solver
    train_solver = Solver(net, dataloaders, bn_domain_map=bn_domain_map, resume=resume_dict)

    # train
    train_solver.solve()
    print('Finished!')
if __name__ == '__main__':
    # Script entry point: parse the CLI, fold the options into the global
    # ``cfg``, then start training.
    cudnn.benchmark = True

    args = parse_args()

    system_name = platform.system()
    if system_name == 'Windows':
        # NOTE(review): presumably lets objects that were pickled with
        # PosixPath instances be loaded on Windows -- confirm.
        pathlib.PosixPath = pathlib.WindowsPath

    print('Called with args:')
    print(args)

    # Later sources override earlier ones: config file, then --set pairs,
    # then the dedicated command-line flags.
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)

    if args.resume is not None:
        cfg.RESUME = args.resume
    if args.weights is not None:
        # train() reads the checkpoint path from cfg.WEIGHTS. Assigning to
        # cfg.MODEL (as before) replaced the model config node with a string,
        # so --weights was ignored and cfg.MODEL.FEATURE_EXTRACTOR broke.
        cfg.WEIGHTS = args.weights
    if args.exp_name is not None:
        cfg.EXP_NAME = args.exp_name

    print('Using config:')
    pprint.pprint(cfg)

    # Each experiment writes into its own sub-directory of the save dir.
    cfg.SAVE_DIR = os.path.join(cfg.SAVE_DIR, cfg.EXP_NAME)
    print('Output will be saved to %s.' % cfg.SAVE_DIR)

    train(args)
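# [Web page notice, not part of the program] This section may contain content that is unsuitable for display, so the page does not show it. You can review and modify it using the relevant editing functions.
# [Web page notice, not part of the program] If you confirm that the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, or content that violates national laws and regulations, you may click submit to file an appeal, and we will handle it as soon as possible.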