1 Star 0 Fork 1

monkeyfx/detection

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
tyranrexmain.py 3.23 KB
一键复制 编辑 原始数据 按行查看 历史
JJZHK 提交于 2020-09-05 09:59 . 📝
'''
@Author : JJZHK (myjjzhk@126.com)
@File : tyranrexmain.py
@Project : Detection
@Created : 30th October 2019 2:42:43 pm
@Last Modified : 5th December 2019 4:30:25 pm
@Modified By : JJZHK (myjjzhk@126.com)
@Copyright : 2017 - 2019
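@Description : Command-line entry point that loads a detector configuration and runs TyranrexSolver in the train, eval, or test phase.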
'''
'''
Output feature-map shape of each backbone (bs = batch size, channels, H, W) at each input size:
512: SSD
vgg16 -> bs, 1024, 32, 32
resnet50 -> bs, 1024, 32, 32
resnet152 -> bs, 1024, 32, 32
detnet59 -> bs, 1024, 32, 32
448: YOLOV1
vgg16 -> bs, 1024, 28, 28
resnet50 -> bs, 1024, 28, 28
resnet152 -> bs, 1024, 28, 28
detnet59 -> bs, 1024, 28, 28
mbn1 -> bs, 1024, 14, 14
mbn2 -> bs, 1024, 14, 14
300: SSD
vgg16 -> bs, 1024, 19, 19
resnet50 -> bs, 1024, 19, 19
resnet152 -> bs, 1024, 19, 19
detnet59 -> bs, 1024, 19, 19
mbn1 -> bs, 1024, 10, 10
mbn2 -> bs, 1024, 10, 10
YOLO: You Only Look Once
SSD: Single Shot Multibox Detector
RFB: Receptive Field Block
FSSD: Feature Fusion Single Shot Multibox Detector
'''
from ELib.solver import TyranrexSolver
from jjzhk.config import ZKCFG
import torch
import os
import argparse
def parse_args(argv=None):
    """Parse the command-line options and store them in the global ``args``.

    After parsing, torch's default tensor type is switched to CUDA floats
    when a GPU is available.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. It is also assigned to the
        module-level ``args`` for callers that read it from there.
    """
    parser = argparse.ArgumentParser(
        description='SSD & RFB Project')
    parser.add_argument('-dataroot', default='/Users/JJZHK/data/', type=str,
                        help='directory joined in front of the config BASE.DATA_ROOT')
    parser.add_argument('-model', default='ssd', type=str,
                        help='detector name, e.g. ssd, rfb, fssd')
    parser.add_argument('-datatype', default='voc', type=str,
                        help='dataset type, e.g. voc or coco')
    parser.add_argument('-net', default='vgg16', type=str,
                        help='backbone name, e.g. vgg16, resnet152, mobilenetv1')
    parser.add_argument('-phase', default='train', type=str,
                        help="'train', 'eval', or anything else to run test")
    # Kept as a module global: the __main__ section reads ``args`` directly.
    global args
    args = parser.parse_args(argv)
    # NOTE(review): torch.set_default_tensor_type is deprecated in recent
    # PyTorch releases (torch.set_default_dtype / torch.set_default_device
    # replace it); kept as-is for the torch version this project targets.
    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    return args
def main(model, datatype, net, dataroot, phase):
    """Build the configuration for one model/dataset/backbone and run a phase.

    The base config name is ``<datatype>.yml``. The model config path is
    ``<datatype>/<model>/<net>.yml``. Both are passed to ``ZKCFG`` with
    ``rootpath="cfgs"``; how ZKCFG resolves them against that root is not
    visible here.

    Args:
        model: Detector name, e.g. ``"ssd"``, ``"rfb"``, ``"fssd"``.
        datatype: Dataset type, e.g. ``"voc"`` or ``"coco"``.
        net: Backbone name, e.g. ``"vgg16"``.
        dataroot: Directory joined in front of the config's
            ``BASE.DATA_ROOT``.
        phase: ``"train"`` runs ``solver.train()`` and ``"eval"`` runs
            ``solver.eval()``. Any other value runs ``solver.test()``.
    """
    base_file = "%s.yml" % datatype
    # os.path.join only builds the path string; it never checks that the
    # file exists. Presumably ZKCFG raises if it is missing (TODO confirm).
    config_file = os.path.join(datatype, model, "%s.yml" % net)
    config = ZKCFG(cfgfile=config_file, basefile=base_file, rootpath="cfgs")
    # Prefix the dataset path from the config with the command-line root.
    config.BASE.DATAROOT = dataroot
    config.BASE.DATA_ROOT = os.path.join(config.BASE.DATAROOT, config.BASE.DATA_ROOT)
    solver = TyranrexSolver(config)
    print('model: %s, backbone: %s' % (model, net))
    if phase == 'train':
        solver.train()
    elif phase == 'eval':
        solver.eval()
    else:
        # NOTE(review): unrecognised phases (including typos) fall through
        # to test; kept for compatibility with existing invocations.
        solver.test()
if __name__ == '__main__':
    # Example configurations (all with -dataroot /Users/JJZHK/data/, which is
    # also the default):
    #   python tyranrexmain.py -model fssd -datatype coco -net mobilenetv1 -phase test
    #   python tyranrexmain.py -model rfb -datatype coco -net darknet19 -phase eval
    #   python tyranrexmain.py -model ssd -datatype voc -net resnet152 -phase train
    parse_args()  # populates the module-level ``args``
    main(model=args.model, datatype=args.datatype, net=args.net,
         dataroot=args.dataroot, phase=args.phase)
    # Feature-map shapes (batch, channels, H, W) noted by the author. The
    # context is not shown here; presumably these are the six SSD-300 source
    # layers for a MobileNet backbone (cf. "mbn1 -> bs, 1024, 10, 10" in the
    # module notes). TODO confirm.
    # 1, 512, 19, 19
    # 1, 1024, 10, 10
    # 1, 512, 5, 5
    # 1, 256, 3, 3
    # 1, 256, 2, 2
    # 1, 128, 1, 1
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/monkeyfx/detection.git
git@gitee.com:monkeyfx/detection.git
monkeyfx
detection
detection
master

搜索帮助