# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Yuze (dingyiwei@stu.xmu.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pprint
import shutil
import numpy as np
import random
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
import _init_paths
import models
from utils import ddp_opx
from core.trainer import Trainer
from core.loss import JointsMSELoss
from dataset.dataManager import DataManager
from config import cfg, update_config, get_args_parser
from utils.utils import create_logger, get_optimizer, save_checkpoint, merge_dicts
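# Example launch commands (illustrative only; the exact script path, config file and
# available flags depend on get_args_parser and the configs shipped with this repo):
#   single GPU : python train.py --cfg <experiment.yaml>
#   multi GPU  : torchrun --nproc_per_node=4 train.py --cfg <experiment.yaml>
#   evaluation : python train.py --cfg <experiment.yaml> --eval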
def main():
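    """Train or evaluate a pose/HOI network, optionally under distributed data parallelism.

    Parses the command line, merges it into the global cfg, sets up DDP, cuDNN and
    per-rank seeding, then builds the dataloaders, model, loss and optimizer before
    running the train / validate / checkpoint loop (or a single evaluation pass when
    --eval is given).
    """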
args = get_args_parser()
update_config(cfg, args)
ddp_opx.init_distributed_mode(args)
device = torch.device(args.device)
# cudnn related setting
cudnn.benchmark = cfg.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
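    # Seed every process with a rank-dependent offset: runs stay reproducible while
    # data shuffling and augmentation still differ across DDP workers.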
seed = args.seed + ddp_opx.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# >>>>>>>>>>>>>>>>>>>>>>>>> record log <<<<<<<<<<<<<<<<<<<<<<<<<
writer_dict = None
logger, final_output_dir, tb_log_dir = create_logger(cfg, args.cfg, ddp_opx.get_rank(), 'val' if args.eval else 'train')
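    # Only the main process logs the config, keeps a TensorBoard writer and snapshots
    # the model definition file into the output directory; other ranks leave
    # writer_dict as None.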
if ddp_opx.is_main_process():
logger.info(pprint.pformat(args))
logger.info(cfg)
writer_dict = {
'writer': SummaryWriter(log_dir=tb_log_dir),
'train_global_steps': 0,
'valid_global_steps': 0,
}
# copy model file
this_dir = os.path.dirname(__file__)
shutil.copy2(os.path.join(this_dir, 'lib/models', cfg.MODEL.NAME + '.py'), final_output_dir)
trainer = Trainer(args, cfg)
# >>>>>>>>>>>>>>>>>>>>>>>>> Data <<<<<<<<<<<<<<<<<<<<<<<<<
dataManager = DataManager(args, cfg)
train_loader = dataManager.get_dataloader('train')
valid_loader = dataManager.get_dataloader('val')
begin_epoch = cfg.TRAIN.BEGIN_EPOCH
best_perf = 0.0
best_model = False
last_epoch = -1
criterion = JointsMSELoss(
use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
).cuda()
# >>>>>>>>>>>>>>>>>>>>>>>>> model <<<<<<<<<<<<<<<<<<<<<<<<<
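    # For V-COCO the model builder additionally receives the dataset's object-to-action
    # mapping; for other datasets only cfg is passed.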
if cfg.DATASET.DATASET == 'vcoco':
object_to_target = train_loader.dataset.hoi_data.object_to_action
# object_to_target = list(train_loader.dataset.hoi_data.object_to_action.values())
model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
cfg, object_to_target, is_train=True
)
else:
model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
cfg, is_train=True
)
model.cuda()
    # Build the optimizer over trainable parameters only, so frozen layers are skipped.
    if ddp_opx.is_main_process():
        for name, p in model.named_parameters():
            if p.requires_grad:
                logger.info(f'optimizer finetune {name}')
    optimizer = get_optimizer(cfg, filter(lambda p: p.requires_grad, model.parameters()))
# >>>>>>>>>>>>>>>>>>>>>>>>> eval <<<<<<<<<<<<<<<<<<<<<<<<<
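    # Evaluation-only path: load the weights given by TEST.MODEL_FILE, validate on every
    # rank, gather the per-rank predictions, compute metrics on rank 0 and return.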
if args.eval:
ckpt_state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=torch.device('cpu'))
model.load_state_dict(ckpt_state_dict, strict=False)
one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
all_records = ddp_opx.all_gather(one_record)
if ddp_opx.is_main_process():
logger.info('=> eval model of {}'.format(cfg.TEST.MODEL_FILE))
all_records = merge_dicts(all_records)
trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir)
return
# >>>>>>>>>>>>>>>>>>>>>>>>> Resume <<<<<<<<<<<<<<<<<<<<<<<<<
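    # Auto-resume: if a previous checkpoint exists in the output directory, restore the
    # epoch counter, best score, model / optimizer state and TensorBoard step counters.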
checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
begin_epoch = checkpoint['epoch']
best_perf = checkpoint['perf']
last_epoch = checkpoint['epoch']
if ddp_opx.is_main_process():
logger.info("=> Auto resume loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))
writer_dict['train_global_steps'] = checkpoint['train_global_steps']
writer_dict['valid_global_steps'] = checkpoint['valid_global_steps']
model.load_state_dict(checkpoint['best_state_dict'], strict=False)
optimizer.load_state_dict(checkpoint['optimizer'])
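    # Cosine-annealing schedule over TRAIN.END_EPOCH epochs down to TRAIN.LR_END;
    # last_epoch restores the schedule position when resuming from a checkpoint.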
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, cfg.TRAIN.END_EPOCH, eta_min=cfg.TRAIN.LR_END, last_epoch=last_epoch)
if args.distributed:
        # find_unused_parameters must be True when some parameters are frozen or never
        # used in the forward pass; it can be False otherwise to skip the extra traversal.
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
else:
        # Wrap with DataParallel even on a single GPU so model.module.state_dict()
        # can be used uniformly when saving checkpoints.
model = torch.nn.DataParallel(model, device_ids=[args.gpu])
# >>>>>>>>>>>>>>>>>>>>>>>>> begin to train <<<<<<<<<<<<<<<<<<<<<<<<<
for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
if args.distributed:
train_loader.sampler.set_epoch(epoch)
if ddp_opx.is_main_process():
logger.info("=> current learning rate is {:.6f}".format(lr_scheduler.get_last_lr()[0]))
trainer.train_one_epoch(train_loader, model, criterion, optimizer, epoch,
final_output_dir, writer_dict)
# evaluate on validation set
one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
all_records = ddp_opx.all_gather(one_record)
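        # Every rank validates its own shard; the per-rank records are gathered and
        # merged so that only the main process computes the epoch's metric.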
if ddp_opx.is_main_process():
all_records = merge_dicts(all_records)
perf_indicator = trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir, writer_dict)
lr_scheduler.step()
if ddp_opx.is_main_process() and perf_indicator >= best_perf:
best_perf = perf_indicator
best_model = True
else:
best_model = False
if ddp_opx.is_main_process():
logger.info('=> saving checkpoint to {}'.format(final_output_dir))
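            # 'state_dict' keeps the wrapped (DDP / DataParallel) parameter names, while
            # 'best_state_dict' stores the plain module weights that auto-resume loads back.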
save_checkpoint({
'epoch': epoch + 1,
'model': cfg.MODEL.NAME,
'state_dict': model.state_dict(),
'best_state_dict': model.module.state_dict(),
'perf': perf_indicator,
'optimizer': optimizer.state_dict(),
'train_global_steps': writer_dict['train_global_steps'],
'valid_global_steps': writer_dict['valid_global_steps'],
}, best_model, final_output_dir)
if ddp_opx.is_main_process():
final_model_state_file = os.path.join(
final_output_dir, 'final_state.pth'
)
logger.info('=> saving final model state to {}'.format(
final_model_state_file)
)
torch.save(model.module.state_dict(), final_model_state_file)
writer_dict['writer'].close()
ddp_opx.cleanup()
print("#####\nTraining Done!\n#####")
if __name__ == '__main__':
main()