1 Star 0 Fork 0

张勇建/C-3-Framework

Create your Gitee Account
Explore and code with more than 12 million developers,Free private repositories !:)
Sign up
Clone or Download
train.py 2.01 KB
Copy Edit Raw Blame History
曾鑫 authored 2019-07-18 21:59 . Update train.py
import os
import numpy as np
import torch
from config import cfg
#------------prepare enviroment------------
seed = cfg.SEED
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
gpus = cfg.GPU_ID
if len(gpus)==1:
torch.cuda.set_device(gpus[0])
torch.backends.cudnn.benchmark = True
#------------prepare data loader------------
data_mode = cfg.DATASET
if data_mode is 'SHHA':
from datasets.SHHA.loading_data import loading_data
from datasets.SHHA.setting import cfg_data
elif data_mode is 'SHHB':
from datasets.SHHB.loading_data import loading_data
from datasets.SHHB.setting import cfg_data
elif data_mode is 'QNRF':
from datasets.QNRF.loading_data import loading_data
from datasets.QNRF.setting import cfg_data
elif data_mode is 'UCF50':
from datasets.UCF50.loading_data import loading_data
from datasets.UCF50.setting import cfg_data
elif data_mode is 'WE':
from datasets.WE.loading_data import loading_data
from datasets.WE.setting import cfg_data
elif data_mode is 'GCC':
from datasets.GCC.loading_data import loading_data
from datasets.GCC.setting import cfg_data
elif data_mode is 'Mall':
from datasets.Mall.loading_data import loading_data
from datasets.Mall.setting import cfg_data
elif data_mode is 'UCSD':
from datasets.UCSD.loading_data import loading_data
from datasets.UCSD.setting import cfg_data
#------------Prepare Trainer------------
net = cfg.NET
if net in ['MCNN', 'AlexNet', 'VGG', 'VGG_DECODER', 'Res50', 'Res101', 'CSRNet','Res101_SFCN']:
from trainer import Trainer
elif net in ['SANet']:
from trainer_for_M2TCC import Trainer # double losses but signle output
elif net in ['CMTL']:
from trainer_for_CMTL import Trainer # double losses and double outputs
elif net in ['PCCNet']:
from trainer_for_M3T3OCC import Trainer
#------------Start Training------------
pwd = os.path.split(os.path.realpath(__file__))[0]
cc_trainer = Trainer(loading_data,cfg_data,pwd)
cc_trainer.forward()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/Petrichor_cyj/C-3-Framework.git
git@gitee.com:Petrichor_cyj/C-3-Framework.git
Petrichor_cyj
C-3-Framework
C-3-Framework
python3.x

Search

23e8dbc6 1850385 7e0993f3 1850385