1 Star 0 Fork 1

Asher_Z/ResGCNv1

forked from yfsong0709/ResGCNv1 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 4.81 KB
一键复制 编辑 原始数据 按行查看 历史
YFSong 提交于 2020-10-20 14:03 . initial commit
import os, yaml, argparse
from time import strftime
from src import utils as U
from src.generator import Generator
from src.processor import Processor
from src.visualizer import Visualizer
def main():
# Loading Parameters
parser = init_parameters()
args, _ = parser.parse_known_args()
# Updating Parameters (cmd > yaml > default)
args = update_parameters(parser, args)
# Setting save_dir
save_dir = get_save_dir(args)
U.set_logging(save_dir)
with open('{}/config.yaml'.format(save_dir), 'w') as f:
yaml.dump(vars(args), f)
# Processing
if args.generate_data or args.generate_label:
g = Generator(args)
g.start()
elif args.extract or args.visualization:
if args.extract:
p = Processor(args, save_dir)
p.extract()
if args.visualization:
v = Visualizer(args)
v.start()
else:
p = Processor(args, save_dir)
p.start()
def init_parameters():
parser = argparse.ArgumentParser(description='Skeleton-based Action Recognition')
# Setting
parser.add_argument('--config', '-c', type=str, default='', help='ID of the using config', required=True)
parser.add_argument('--gpus', '-g', type=int, nargs='+', default=[], help='Using GPUs')
parser.add_argument('--seed', '-s', type=int, default=1, help='Random seed')
parser.add_argument('--debug', '-db', default=False, action='store_true', help='Debug')
parser.add_argument('--pretrained_path', '-pp', type=str, default='', help='Path to pretrained models')
parser.add_argument('--work_dir', '-w', type=str, default='', help='Work dir')
parser.add_argument('--no_progress_bar', '-np', default=False, action='store_true', help='Do not show progress bar')
parser.add_argument('--path', '-p', type=str, default='', help='Path to save preprocessed skeleton files')
# Processing
parser.add_argument('--resume', '-r', default=False, action='store_true', help='Resume from checkpoint')
parser.add_argument('--evaluate', '-e', default=False, action='store_true', help='Evaluate')
parser.add_argument('--extract', '-ex', default=False, action='store_true', help='Extract')
parser.add_argument('--visualization', '-v', default=False, action='store_true', help='Visualization')
parser.add_argument('--generate_data', '-gd', default=False, action='store_true', help='Generate skeleton data')
parser.add_argument('--generate_label', '-gl', default=False, action='store_true', help='Only generate label')
# Visualization
parser.add_argument('--visualization_class', '-vc', type=int, default=0, help='Class: 1 ~ 60, 0 means true class')
parser.add_argument('--visualization_sample', '-vs', type=int, default=0, help='Sample: 0 ~ batch_size-1')
parser.add_argument('--visualization_frames', '-vf', type=int, nargs='+', default=[], help='Frame: 0 ~ max_frame-1')
# Dataloader
parser.add_argument('--dataset', '-d', type=str, default='', help='Select dataset')
parser.add_argument('--dataset_args', default=dict(), help='Args for creating dataset')
# Model
parser.add_argument('--model_type', '-mt', type=str, default='', help='Model type')
parser.add_argument('--model_args', default=dict(), help='Args for creating model')
# Optimizer
parser.add_argument('--optimizer', '-o', type=str, default='', help='Initial optimizer')
parser.add_argument('--optimizer_args', default=dict(), help='Args for optimizer')
# LR_Scheduler
parser.add_argument('--lr_scheduler', '-ls', type=str, default='', help='Initial learning rate scheduler')
parser.add_argument('--scheduler_args', default=dict(), help='Args for scheduler')
return parser
def update_parameters(parser, args):
if os.path.exists('./configs/{}.yaml'.format(args.config)):
with open('./configs/{}.yaml'.format(args.config), 'r') as f:
try:
yaml_arg = yaml.load(f, Loader=yaml.FullLoader)
except:
yaml_arg = yaml.load(f)
default_arg = vars(args)
for k in yaml_arg.keys():
if k not in default_arg.keys():
raise ValueError('Do NOT exist this parameter {}'.format(k))
parser.set_defaults(**yaml_arg)
else:
raise ValueError('Do NOT exist this file in \'configs\' folder: {}.yaml!'.format(args.config))
return parser.parse_args()
def get_save_dir(args):
if args.debug or args.evaluate or args.extract or args.visualization or args.generate_data or args.generate_label:
save_dir = '{}/temp'.format(args.work_dir)
else:
ct = strftime('%Y-%m-%d %H-%M-%S')
save_dir = '{}/{}_{}_{}/{}'.format(args.work_dir, args.config, args.model_type, args.dataset, ct)
U.create_folder(save_dir)
return save_dir
if __name__ == '__main__':
os.chdir(os.getcwd())
main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/asher-z/ResGCNv1.git
git@gitee.com:asher-z/ResGCNv1.git
asher-z
ResGCNv1
ResGCNv1
master

搜索帮助