1 Star 0 Fork 1

yhl41001/CenterNet_Pro_Max

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
train_coco_r50.py 3.96 KB
一键复制 编辑 原始数据 按行查看 历史
FagangJin 提交于 2020-03-20 16:03 . try adding custom nuscenes to train
#
# Copyright (c) 2020 jintian.
#
# This file is part of CenterNet_Pro_Max
# (see jinfagang.github.io).
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import os
import torch
from configs.coco.ct_coco_r50_config import config
from typing import Any, Dict, List
import argparse
from models.data import MetadataCatalog
from models.centernet import build_model
from models.evaluation.coco_evaluation import COCOEvaluator
from models.evaluation.pascal_voc_evaluation import PascalVOCDetectionEvaluator
from models.train.trainer import DefaultTrainer
from models.evaluation.evaluator import DatasetEvaluators
from models.train import hooks
from alfred.utils.log import logger
from alfred.utils.log import logger as logging
import importlib
class Trainer(DefaultTrainer):
    """Project trainer that selects its evaluator from dataset metadata."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return the evaluator matching ``dataset_name``'s registered type.

        The type is read from ``MetadataCatalog.get(dataset_name).evaluator_type``:

        * ``"coco"`` or ``"coco_panoptic_seg"`` -> ``COCOEvaluator`` that writes
          to ``output_folder`` (defaults to ``<cfg.OUTPUT_DIR>/inference``) and
          dumps according to ``cfg.GLOBAL.DUMP_TRAIN``.
        * ``"pascal_voc"`` -> ``PascalVOCDetectionEvaluator``.

        Args:
            cfg: the training config; ``OUTPUT_DIR`` and ``GLOBAL.DUMP_TRAIN``
                are read here.
            dataset_name: name registered in ``MetadataCatalog``.
            output_folder: where COCO evaluation results go; ``None`` selects
                the default described above.

        Raises:
            NotImplementedError: the evaluator type is not supported.
        """
        # Resolve the output folder first (before the metadata lookup),
        # exactly as the caller-visible behaviour expects.
        target_dir = (
            os.path.join(cfg.OUTPUT_DIR, "inference")
            if output_folder is None
            else output_folder
        )
        kind = MetadataCatalog.get(dataset_name).evaluator_type

        if kind in ("coco", "coco_panoptic_seg"):
            return COCOEvaluator(
                dataset_name,
                cfg,
                True,
                target_dir,
                dump=cfg.GLOBAL.DUMP_TRAIN,
            )
        if kind == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)

        raise NotImplementedError(
            "no Evaluator for the dataset {} with the type {}".format(
                dataset_name, kind
            )
        )
def train(args):
    """Build the model from the COCO R50 config and run training.

    Command-line overrides in ``args.opts`` are merged into the module-level
    ``config`` object (mutating it in place) before the model is built.

    Args:
        args: namespace produced by ``default_argument_parser().parse_args()``.
            Only ``args.opts`` and ``args.resume`` are read here.
            NOTE(review): ``--eval-only``, ``--num-gpus``, ``--num-machines``,
            ``--machine-rank`` and ``--dist-url`` are parsed but not consumed
            by this function.

    Returns:
        Whatever ``Trainer.train()`` returns.
    """
    config.merge_from_list(args.opts)
    cfg = config
    model = build_model(cfg)

    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() pre-check is redundant (and racy).
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    # Report the output location on every run, not only when the directory
    # had to be created.
    logger.info('output will be saved into: {}'.format(cfg.OUTPUT_DIR))

    trainer = Trainer(cfg, model)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        # Register an EvalHook that runs test-time-augmented evaluation.
        # NOTE(review): test_with_TTA is not defined on Trainer in this file;
        # it is presumably inherited from DefaultTrainer - confirm.
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()
def default_argument_parser():
    """Create (but do not run) the command-line parser for training.

    Recognised arguments, in declaration order:

    * ``--resume``        store_true flag
    * ``--eval-only``     store_true flag
    * ``--num-gpus``      int, default 1
    * ``--num-machines``  int, default 1
    * ``--machine-rank``  int, default 0
    * ``--dist-url``      str, default ``tcp://127.0.0.1:9080``
    * ``opts``            every remaining token, collected as a list

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    ap = argparse.ArgumentParser(description="CenterNet Pro Train")
    add = ap.add_argument

    # Boolean switches.
    add(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    add("--eval-only", action="store_true", help="perform evaluation only")

    # Multi-GPU / multi-machine options.
    add("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    add("--num-machines", type=int, default=1)
    add(
        "--machine-rank",
        type=int,
        default=0,
        help="the rank of this machine (unique per machine)",
    )
    add("--dist-url", default="tcp://127.0.0.1:9080")

    # Everything left over is kept verbatim as config overrides.
    add(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return ap
if __name__ == '__main__':
    # Script entry point: parse the CLI and pass the namespace to train().
    train(default_argument_parser().parse_args())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yhl41001/CenterNet_Pro_Max.git
git@gitee.com:yhl41001/CenterNet_Pro_Max.git
yhl41001
CenterNet_Pro_Max
CenterNet_Pro_Max
master

搜索帮助