# main.py
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Yuze (dingyiwei@stu.xmu.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pprint
import shutil
import numpy as np
import random
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
import _init_paths
import models
from utils import ddp_opx
from core.trainer import Trainer
from core.loss import JointsMSELoss
from dataset.dataManager import DataManager
from config import cfg, update_config, get_args_parser
from utils.utils import create_logger, get_optimizer, save_checkpoint, merge_dicts
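

# Entry point for (distributed) training and evaluation. The helpers used below
# (ddp_opx for DDP setup and cross-rank gathering, Trainer, DataManager, and the
# model factories under lib/models) live elsewhere in this repository; comments
# about their behaviour here are inferred from how they are called.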
def main():
    args = get_args_parser()
    update_config(cfg, args)
    ddp_opx.init_distributed_mode(args)
    device = torch.device(args.device)

    # cudnn related settings
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # offset the seed by the process rank so each worker draws different random numbers
    seed = args.seed + ddp_opx.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # >>>>>>>>>>>>>>>>>>>>>>>>> record log <<<<<<<<<<<<<<<<<<<<<<<<<
    writer_dict = None
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, ddp_opx.get_rank(), 'val' if args.eval else 'train')
    if ddp_opx.is_main_process():
        logger.info(pprint.pformat(args))
        logger.info(cfg)
        writer_dict = {
            'writer': SummaryWriter(log_dir=tb_log_dir),
            'train_global_steps': 0,
            'valid_global_steps': 0,
        }
        # copy model file
        this_dir = os.path.dirname(__file__)
        shutil.copy2(os.path.join(this_dir, 'lib/models', cfg.MODEL.NAME + '.py'), final_output_dir)
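
    # Trainer wraps the per-epoch train/validate loops and DataManager builds the
    # dataloaders (presumably with a DistributedSampler when args.distributed is
    # set, given the sampler.set_epoch call in the training loop below).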
    trainer = Trainer(args, cfg)

    # >>>>>>>>>>>>>>>>>>>>>>>>> Data <<<<<<<<<<<<<<<<<<<<<<<<<
    dataManager = DataManager(args, cfg)
    train_loader = dataManager.get_dataloader('train')
    valid_loader = dataManager.get_dataloader('val')

    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    best_perf = 0.0
    best_model = False
    last_epoch = -1
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()
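    # JointsMSELoss (defined in core.loss) presumably follows the Simple Baselines /
    # HRNet convention: an MSE over predicted vs. ground-truth heatmaps, optionally
    # weighted per joint by the target weights when cfg.LOSS.USE_TARGET_WEIGHT is set.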

    # >>>>>>>>>>>>>>>>>>>>>>>>> model <<<<<<<<<<<<<<<<<<<<<<<<<
    if cfg.DATASET.DATASET == 'vcoco':
        object_to_target = train_loader.dataset.hoi_data.object_to_action
        # object_to_target = list(train_loader.dataset.hoi_data.object_to_action.values())
        model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
            cfg, object_to_target, is_train=True
        )
    else:
        model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
            cfg, is_train=True
        )
    model.cuda()

    # one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
    # all_records = ddp_opx.all_gather(one_record)
    # if ddp_opx.is_main_process():
    #     logger.info('=> eval model of {}'.format(cfg.TEST.MODEL_FILE))
    #     all_records = merge_dicts(all_records)
    #     trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir)
    # return
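    # The commented-out block above performs a one-off evaluation before any
    # training; the same gather-and-merge logic is reachable through the
    # args.eval branch below.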

    # log which parameters are trainable, then build the optimizer over them only
    for name, p in model.named_parameters():
        if p.requires_grad:
            print(f'optimizer finetune {name}')
    optimizer = get_optimizer(cfg, filter(lambda p: p.requires_grad, model.parameters()))

    # >>>>>>>>>>>>>>>>>>>>>>>>> eval <<<<<<<<<<<<<<<<<<<<<<<<<
    if args.eval:
        ckpt_state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=torch.device('cpu'))
        model.load_state_dict(ckpt_state_dict, strict=False)
        one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
        all_records = ddp_opx.all_gather(one_record)
        if ddp_opx.is_main_process():
            logger.info('=> eval model of {}'.format(cfg.TEST.MODEL_FILE))
            all_records = merge_dicts(all_records)
            trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir)
        return
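
    # Evaluation (above and at the end of each epoch below) is sharded: every rank
    # validates its own portion of the data, ddp_opx.all_gather collects the
    # per-rank result dicts, and rank 0 merges them before computing keypoint metrics.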

    # >>>>>>>>>>>>>>>>>>>>>>>>> Resume <<<<<<<<<<<<<<<<<<<<<<<<<
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        if ddp_opx.is_main_process():
            logger.info("=> Auto resume loaded checkpoint '{}' (epoch {})".format(checkpoint_file, checkpoint['epoch']))
            writer_dict['train_global_steps'] = checkpoint['train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint['valid_global_steps']
        model.load_state_dict(checkpoint['best_state_dict'], strict=False)
        optimizer.load_state_dict(checkpoint['optimizer'])

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, cfg.TRAIN.END_EPOCH, eta_min=cfg.TRAIN.LR_END, last_epoch=last_epoch)

    if args.distributed:
        # find_unused_parameters=True is required when some parameters do not receive
        # gradients in every forward pass; set it to False (cheaper) when they all do
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        # single-GPU fallback: wrap in DataParallel so model.module.state_dict() below still works
        model = torch.nn.DataParallel(model, device_ids=[args.gpu])
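
    # find_unused_parameters=True lets DDP cope with parameters that get no gradient
    # in a given iteration (e.g. frozen or unused branches), at the cost of an extra
    # graph traversal per step; both wrappers expose .module, which the checkpoint
    # code below relies on.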

    # >>>>>>>>>>>>>>>>>>>>>>>>> begin to train <<<<<<<<<<<<<<<<<<<<<<<<<
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
        if ddp_opx.is_main_process():
            logger.info("=> current learning rate is {:.6f}".format(lr_scheduler.get_last_lr()[0]))

        trainer.train_one_epoch(train_loader, model, criterion, optimizer, epoch,
                                final_output_dir, writer_dict)

        # evaluate on validation set
        one_record = trainer.validate(valid_loader, dataManager.valid_dataset, model, final_output_dir)
        all_records = ddp_opx.all_gather(one_record)
        if ddp_opx.is_main_process():
            all_records = merge_dicts(all_records)
            perf_indicator = trainer.evaluate_kpts(all_records, dataManager.valid_dataset, final_output_dir, writer_dict)

        lr_scheduler.step()

        if ddp_opx.is_main_process() and perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if ddp_opx.is_main_process():
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
                'train_global_steps': writer_dict['train_global_steps'],
                'valid_global_steps': writer_dict['valid_global_steps'],
            }, best_model, final_output_dir)

    if ddp_opx.is_main_process():
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> saving final model state to {}'.format(
            final_model_state_file)
        )
        torch.save(model.module.state_dict(), final_model_state_file)
        writer_dict['writer'].close()

    ddp_opx.cleanup()
    print("#####\nTraining Done!\n#####")


if __name__ == '__main__':
    main()
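
# ------------------------------------------------------------------------------
# Usage sketch (assumptions: ddp_opx.init_distributed_mode follows the common
# torchvision-reference pattern of reading the RANK/WORLD_SIZE/LOCAL_RANK
# environment variables that torchrun sets, and the exact CLI flags are defined
# by get_args_parser in config/, not shown here):
#
#   torchrun --nproc_per_node=4 main.py --cfg experiments/<your_config>.yaml
#
# Evaluation only (expects cfg.TEST.MODEL_FILE to point at a trained checkpoint):
#
#   torchrun --nproc_per_node=1 main.py --cfg experiments/<your_config>.yaml --eval
# ------------------------------------------------------------------------------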