1 Star 0 Fork 0

张勇建/C-3-Framework

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
trainer_for_CMTL.py 12.84 KB
一键复制 编辑 原始数据 按行查看 历史
隐匿人海 提交于 2019-07-02 17:08 . py3 first version
import numpy as np
import torch
from torch import optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from models.M2T2OCC import CrowdCounter
from config import cfg
from misc.utils import *
import pdb
class Trainer():
    """Training / validation driver for the CMTL crowd-counting network.

    The trainer optimises two losses exposed by the network as the pair
    ``self.net.loss``: the first is built from an MSE criterion on the
    density map, the second from a class-weighted BCE criterion on a
    crowd-count class label derived from the ground-truth density map.

    NOTE(review): ``cfg``, ``CrowdCounter``, ``Timer``, ``logger``,
    ``AverageMeter``, ``AverageCategoryMeter``, ``update_model``,
    ``vis_results``, ``print_summary`` and ``print_GCC_summary`` come from
    project modules not visible here; their exact contracts should be
    confirmed against those modules.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Build data loaders, network, optimiser, scheduler and logger.

        Args:
            dataloader: callable returning
                ``(train_loader, val_loader, restore_transform)``.
            cfg_data: dataset configuration; ``LOG_PARA`` is read from it to
                convert density-map sums into people counts.
            pwd: passed through to ``logger`` unchanged.

        Raises:
            ValueError: if ``cfg.NET`` is not ``'CMTL'`` (this trainer only
                knows how to build the CMTL loss functions).
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET

        self.train_loader, self.val_loader, self.restore_transform = dataloader()

        if self.net_name not in ['CMTL']:
            # Previously this fell through to a NameError on loss_1_fn.
            raise ValueError(
                'trainer_for_CMTL only supports NET == "CMTL", got %r' % (self.net_name,))

        # State used to bin ground-truth counts into class labels.
        self.max_gt_count = 0.
        self.min_gt_count = 0x7f7f7f
        self.num_classes = 10
        self.bin_val = 0.

        self.pre_max_min_bin_val()

        ce_weights = torch.from_numpy(self.pre_weights()).float()

        loss_1_fn = torch.nn.MSELoss()
        loss_2_fn = torch.nn.BCELoss(weight=ce_weights)

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name, loss_1_fn, loss_2_fn).cuda()
        self.optimizer = optim.Adam(self.net.CCN.parameters(), lr=cfg.LR, weight_decay=1e-4)
        self.scheduler = StepLR(self.optimizer, step_size=cfg.NUM_EPOCH_LR_DECAY, gamma=cfg.LR_DECAY)

        self.train_record = {'best_mae': 1e20, 'best_mse': 1e20, 'best_model_name': ''}
        self.timer = {'iter time': Timer(), 'train time': Timer(), 'val time': Timer()}

        self.i_tb = 0
        self.epoch = 0

        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        if cfg.RESUME:
            # Restore every piece of state stored in the checkpoint dict.
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, 'exp', resume=cfg.RESUME)

    def _count_to_class_idx(self, count):
        """Map a people count to a class index in ``[0, num_classes - 1]``.

        The index is ``int(count / bin_val)`` clamped to the valid range.
        A non-positive ``bin_val`` (no samples seen, or all counts equal)
        maps everything to class 0 instead of dividing by zero.
        """
        if self.bin_val <= 0:
            return 0
        return max(0, min(int(count / self.bin_val), self.num_classes - 1))

    def pre_max_min_bin_val(self, max_batches=50):
        """Scan the first training batches for the max / min people count.

        Updates ``self.max_gt_count``, ``self.min_gt_count`` and sets
        ``self.bin_val = (max - min) / num_classes``.

        Args:
            max_batches: number of leading batches of ``self.train_loader``
                to inspect (default 50, the original hard-coded limit).
        """
        for i, data in enumerate(self.train_loader, 0):
            if i >= max_batches:
                # Stop instead of loading the rest of the dataset for nothing.
                break
            _, gt_map = data
            for j in range(0, gt_map.size()[0]):
                temp_count = gt_map[j].sum().item() / self.cfg_data.LOG_PARA
                # Independent checks: one sample may update both extremes
                # (e.g. the very first sample).
                if temp_count > self.max_gt_count:
                    self.max_gt_count = temp_count
                if temp_count < self.min_gt_count:
                    self.min_gt_count = temp_count

        print( '[max_gt: %.2f min_gt: %.2f]' % (self.max_gt_count, self.min_gt_count) )
        self.bin_val = (self.max_gt_count - self.min_gt_count)/float(self.num_classes)

    def pre_weights(self, max_batches=100):
        """Compute per-class loss weights from a class histogram.

        Counts how many samples of the first ``max_batches`` training
        batches fall into each count class, then returns
        ``w = (1 - hist / hist.sum())`` normalised to sum to 1, so rarer
        classes get larger weights.

        Args:
            max_batches: number of leading batches to inspect (default 100,
                the original hard-coded limit).

        Returns:
            ``numpy.ndarray`` of shape ``(num_classes,)``. If no sample was
            seen, uniform weights are returned instead of NaNs.
        """
        count_class_hist = np.zeros(self.num_classes)
        for i, data in enumerate(self.train_loader, 0):
            if i >= max_batches:
                break
            _, gt_map = data
            for j in range(0, gt_map.size()[0]):
                temp_count = gt_map[j].sum().item() / self.cfg_data.LOG_PARA
                count_class_hist[self._count_to_class_idx(temp_count)] += 1

        total = count_class_hist.sum()
        if total > 0:
            wts = 1 - count_class_hist / total
            wts = wts / wts.sum()
        else:
            wts = np.full(self.num_classes, 1.0 / self.num_classes)

        print( 'pre_wts:' )
        print( wts )

        return wts

    def online_assign_gt_class_labels(self, gt_map_batch):
        """Build one-hot count-class labels for a batch of density maps.

        Each sample's count is ``gt_map_batch[i].sum() / LOG_PARA``; its
        class is ``int(count / bin_val)`` clamped to the last class (the
        source comment says this follows Viswa's CMTL implementation).

        Args:
            gt_map_batch: tensor whose first dimension is the batch.

        Returns:
            CPU float tensor of shape ``(batch, num_classes)``.
        """
        batch = gt_map_batch.size()[0]
        label = torch.zeros(batch, self.num_classes)

        for i in range(0, batch):
            gt_count = gt_map_batch[i].sum().item() / self.cfg_data.LOG_PARA
            label[i, self._count_to_class_idx(gt_count)] = 1

        return label

    def forward(self):
        """Run the epoch loop: optional LR step, training, periodic validation."""
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            # NOTE(review): the scheduler is stepped before this epoch's
            # optimizer steps; recent PyTorch versions warn about that order.
            # Left as is so the decay schedule is unchanged.
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            print( 'train time: {:.2f}s'.format(self.timer['train time'].diff) )
            print( '=' * 20 )

            # validation
            if epoch % cfg.VAL_FREQ == 0 or epoch > cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                # Strings must be compared by value, not identity.
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                    self.validate_V1()
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print( 'val time: {:.2f}s'.format(self.timer['val time'].diff) )

    def train(self): # training for all datasets
        """Train for one pass over ``self.train_loader``.

        Every ``cfg.PRINT_FREQ`` iterations the total loss and both partial
        losses are written to the summary writer and printed.
        """
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            self.timer['iter time'].tic()
            img, gt_map = data
            img = img.cuda()
            gt_map = gt_map.cuda()
            gt_label = self.online_assign_gt_class_labels(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map, gt_label)
            loss1, loss2 = self.net.loss
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.writer.add_scalar('train_loss1', loss1.item(), self.i_tb)
                self.writer.add_scalar('train_loss2', loss2.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)
                # The learning rate is printed scaled by 1e4.
                print( '[ep %d][it %d][loss %.8f, %.8f, %.8f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), loss1.item(), loss2.item(),
                         self.optimizer.param_groups[0]['lr'] * 10000,
                         self.timer['iter time'].diff) )
                print( '        [cnt: gt: %.1f pred: %.2f]' % (
                    gt_map[0].sum().item() / self.cfg_data.LOG_PARA,
                    pred_map[0].sum().item() / self.cfg_data.LOG_PARA) )

    def _validation_step(self, img, gt_map):
        """Run the network on one validation batch without gradients.

        Returns:
            ``(img_cuda, pred_map_np, gt_map_np, loss1_value)`` where the
            maps are NumPy arrays and ``loss1_value`` is the first element
            of ``self.net.loss`` as a Python float.
        """
        with torch.no_grad():
            img = img.cuda()
            gt_map = gt_map.cuda()
            # The network is called with class labels here exactly as in
            # train(), so both entries of self.net.loss can be produced.
            gt_label = self.online_assign_gt_class_labels(gt_map).cuda()
            pred_map = self.net(img, gt_map, gt_label)
            loss1, _ = self.net.loss
            loss_value = loss1.item()
        return img, pred_map.data.cpu().numpy(), gt_map.data.cpu().numpy(), loss_value

    def validate_V1(self): # validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate on a single loader; logs loss, MAE and root of mean squared error.

        Counts are summed over the whole batch, i.e. one error per batch.
        NOTE(review): this equals a per-image error only for batch size 1 —
        confirm the validation loader's batch size.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map = data
            img, pred_map, gt_map, loss = self._validation_step(img, gt_map)

            pred_cnt = np.sum(pred_map) / self.cfg_data.LOG_PARA
            gt_count = np.sum(gt_map) / self.cfg_data.LOG_PARA

            # Only the first (density-map) loss is tracked for validation.
            losses.update(loss)
            maes.update(abs(gt_count - pred_cnt))
            mses.update((gt_count - pred_cnt) * (gt_count - pred_cnt))
            if vi == 0:
                vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name,
                                         [mae, mse, loss], self.train_record, self.log_txt)
        print_summary(self.exp_name, [mae, mse, loss], self.train_record)

    def validate_V2(self): # validate_V2 for WE
        """Validate on WE: ``self.val_loader`` is iterated as a sequence of loaders.

        Loss and MAE are accumulated per sub-loader in 5-category meters and
        then averaged across categories.
        """
        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        for i_sub, i_loader in enumerate(self.val_loader, 0):
            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data
                img, pred_map, gt_map, loss = self._validation_step(img, gt_map)

                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    losses.update(loss, i_sub)
                    maes.update(abs(gt_count - pred_cnt), i_sub)
                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name,
                                         [mae, 0, loss], self.train_record, self.log_txt)
        print_summary(self.exp_name, [mae, 0, loss], self.train_record)

    def validate_V3(self): # validate_V3 for GCC
        """Validate on GCC, additionally breaking errors down by attribute.

        Each batch carries ``attributes_pt``; per image, entries 0, 1 and 2
        are used as the 'level', 'time' (integer-divided by 3) and 'weather'
        category indices.
        NOTE(review): assumes each attribute entry is an integer-valued
        scalar — confirm against the GCC dataset loader.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {'level': AverageCategoryMeter(9), 'time': AverageCategoryMeter(8), 'weather': AverageCategoryMeter(7)}
        c_mses = {'level': AverageCategoryMeter(9), 'time': AverageCategoryMeter(8), 'weather': AverageCategoryMeter(7)}

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data
            img, pred_map, gt_map, loss = self._validation_step(img, gt_map)

            for i_img in range(pred_map.shape[0]):
                # Per-image counts (the batch-wide sum was wrong for batch > 1).
                pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                s_mae = abs(gt_count - pred_cnt)
                s_mse = (gt_count - pred_cnt) * (gt_count - pred_cnt)

                losses.update(loss)
                maes.update(s_mae)
                mses.update(s_mse)

                attrs = attributes_pt[i_img]
                level_id = int(attrs[0])
                # Floor division: '/' would produce a float index on Python 3.
                time_id = int(attrs[1]) // 3
                weather_id = int(attrs[2])

                c_maes['level'].update(s_mae, level_id)
                c_mses['level'].update(s_mse, level_id)
                c_maes['time'].update(s_mae, time_id)
                c_mses['time'].update(s_mse, time_id)
                c_maes['weather'].update(s_mae, weather_id)
                c_mses['weather'].update(s_mse, weather_id)

            if vi == 0:
                vis_results(self.exp_name, self.epoch, self.writer, self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
                                         self.exp_path, self.exp_name,
                                         [mae, mse, loss], self.train_record, self.log_txt)

        # Replace the per-category meters with root-mean-squared values.
        c_mses['level'] = np.sqrt(c_mses['level'].avg)
        c_mses['time'] = np.sqrt(c_mses['time'].avg)
        c_mses['weather'] = np.sqrt(c_mses['weather'].avg)
        print_GCC_summary(self.exp_name, [mae, mse, loss], self.train_record, c_maes, c_mses)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/Petrichor_cyj/C-3-Framework.git
git@gitee.com:Petrichor_cyj/C-3-Framework.git
Petrichor_cyj
C-3-Framework
C-3-Framework
python3.x

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385