1 Star 0 Fork 0

spartanbin/ vehicle_trajectory_prediction

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 9.71 KB
一键复制 编辑 原始数据 按行查看 历史
spartanbin 提交于 2022-02-17 17:50 . first commit
import sys
import os
import time
import logging
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
project_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(project_path)
from model import pipNet
from data import highwayTrajDataset
from utils import initLogging, maskedNLL, maskedMSE, maskedNLLTest
class parser_:
    """Attribute-bag of all training options.

    Stands in for an ``argparse`` namespace: downstream code reads every
    option as ``args.<name>``.  Values are hard-coded defaults for the
    HighD experiment.
    """

    def __init__(self):
        # --- General switches --------------------------------------------
        self.use_cuda = True             # move model and batches to GPU
        self.use_planning = True         # enable the planning-coupled module
        self.use_fusion = True           # enable the target-fusion module
        self.train_output_flag = True    # model emits training-style output
        self.tensorboard = True          # write TensorBoard summaries

        # --- Optimization ------------------------------------------------
        self.batch_size = 64
        self.learning_rate = 0.001

        # --- Input / output geometry -------------------------------------
        self.grid_size = [25, 5]         # social occupancy grid (presumably [length, width] cells — confirm in data module)
        self.in_length = 16              # observed history length (frames)
        self.out_length = 25             # prediction horizon (frames)
        self.num_lat_classes = 3         # lateral maneuver classes
        self.num_lon_classes = 2         # longitudinal maneuver classes

        # --- Network hyperparameters -------------------------------------
        self.temporal_embedding_size = 32
        self.encoder_size = 64
        self.decoder_size = 128
        self.soc_conv_depth = 64
        self.soc_conv2_depth = 16
        self.dynamics_encoding_size = 32
        self.social_context_size = 80
        self.fuse_enc_size = 112

        # --- Run / dataset settings --------------------------------------
        self.name = 'pip_HighD'
        self.train_set = project_path + '/processed_data/HighD/train.mat'
        self.val_set = project_path + '/processed_data/HighD/val.mat'
        self.num_workers = 0             # DataLoader worker processes
        self.pretrain_epochs = 1         # epochs trained with MSE loss
        self.train_epochs = 4            # epochs trained with NLL loss
def train_model():
    """Train the PiP trajectory-prediction network end to end.

    Schedule: ``pretrain_epochs`` epochs with a masked MSE loss (stabilizes
    the regression head), then ``train_epochs`` epochs with a masked NLL
    loss plus BCE losses on the lateral/longitudinal maneuver classifiers.
    Every 100 training batches, logs averaged loss, runs a short validation
    pass, and writes TensorBoard scalars; a checkpoint is saved after each
    epoch and once more at the very end.
    """
    args = parser_()

    ## Logging
    log_path = "./trained_models/{}/".format(args.name)
    os.makedirs(log_path, exist_ok=True)
    initLogging(log_file=log_path+'train.log')
    if args.tensorboard:
        # Separate writers so train and validation curves do not mix.
        logger = SummaryWriter(log_path + 'train-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
        logger_val = SummaryWriter(log_path + 'validation-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
    logging.info("------------- {} -------------".format(args.name))
    logging.info("Batch size : {}".format(args.batch_size))
    logging.info("Learning rate : {}".format(args.learning_rate))
    logging.info("Use Planning Coupled: {}".format(args.use_planning))
    logging.info("Use Target Fusion: {}".format(args.use_fusion))

    ## Initialize network and optimizer
    PiP = pipNet(args)
    if args.use_cuda:
        PiP = PiP.cuda()
    optimizer = torch.optim.Adam(PiP.parameters(), lr=args.learning_rate)
    crossEnt = torch.nn.BCELoss()  # maneuver-class losses (expects probabilities, not logits)

    ## Initialize training parameters
    pretrainEpochs = args.pretrain_epochs
    trainEpochs = args.train_epochs
    batch_size = args.batch_size

    ## Initialize data loaders
    logging.info("Train dataset: {}".format(args.train_set))
    trSet = highwayTrajDataset(path=args.train_set,
                               targ_enc_size=args.social_context_size+args.dynamics_encoding_size,
                               grid_size=args.grid_size,
                               fit_plan_traj=False)
    logging.info("Validation dataset: {}".format(args.val_set))
    # Validation fits the planned trajectory (fit_plan_traj=True) — the
    # evaluation condition differs from training by design here.
    valSet = highwayTrajDataset(path=args.val_set,
                                targ_enc_size=args.social_context_size+args.dynamics_encoding_size,
                                grid_size=args.grid_size,
                                fit_plan_traj=True)
    trDataloader = DataLoader(trSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet, batch_size=batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=valSet.collate_fn)
    logging.info("DataSet Prepared : {} train data, {} validation data\n".format(len(trSet), len(valSet)))
    logging.info("Network structure: {}\n".format(PiP))

    ## Training process
    for epoch_num in range( pretrainEpochs + trainEpochs ):
        if epoch_num == 0:
            logging.info('Pretrain with MSE loss')
        elif epoch_num == pretrainEpochs:
            logging.info('Train with NLL loss')

        ## Variables to track training performance:
        avg_time_tr, avg_loss_tr, avg_loss_val = 0, 0, 0

        ## Training status, reclaim after each epoch
        PiP.train()
        PiP.train_output_flag = True
        for i, data in enumerate(trDataloader):
            st_time = time.time()
            nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, _ = data
            if args.use_cuda:
                nbsHist = nbsHist.cuda()
                nbsMask = nbsMask.cuda()
                planFut = planFut.cuda()
                planMask = planMask.cuda()
                targsHist = targsHist.cuda()
                targsEncMask = targsEncMask.cuda()
                lat_enc = lat_enc.cuda()
                lon_enc = lon_enc.cuda()
                targsFut = targsFut.cuda()
                targsFutMask = targsFutMask.cuda()

            # Forward pass
            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, lat_enc, lon_enc)
            if epoch_num < pretrainEpochs:
                # Pre-train with MSE loss to speed up training
                l = maskedMSE(fut_pred, targsFut, targsFutMask)
            else:
                # Train with NLL loss plus BCE on both maneuver classifiers.
                l = maskedNLL(fut_pred, targsFut, targsFutMask) + crossEnt(lat_pred, lat_enc) + crossEnt(lon_pred, lon_enc)

            # Back-prop and update weights
            optimizer.zero_grad()
            l.backward()
            # Clip total gradient norm at 10 to avoid exploding gradients.
            prev_vec_norm = torch.nn.utils.clip_grad_norm_(PiP.parameters(), 10)
            optimizer.step()

            # Track average train loss and average train time:
            batch_time = time.time()-st_time
            avg_loss_tr += l.item()
            avg_time_tr += batch_time

            # For every 100 batches: record loss, validate model, and plot.
            if i%100 == 99:
                eta = avg_time_tr/100*(len(trSet)/batch_size-i)
                epoch_progress = i * batch_size / len(trSet)
                logging.info(f"Epoch no:{epoch_num+1}"+
                             f" | Epoch progress(%):{epoch_progress*100:.2f}"+
                             f" | Avg train loss:{avg_loss_tr/100:.2f}"+
                             f" | ETA(s):{int(eta)}")
                if args.tensorboard:
                    # x-axis is fractional-epoch * 100 so MSE and NLL phases line up.
                    logger.add_scalar("RMSE" if epoch_num < pretrainEpochs else "NLL", avg_loss_tr / 100, (epoch_progress + epoch_num) * 100)

                ## Validatation during training:
                eval_batch_num = 20  # only a fixed slice of the val set, for speed
                with torch.no_grad():
                    PiP.eval()
                    PiP.train_output_flag = False
                    # NOTE(review): this inner loop reuses the name `i`; it is
                    # restored by the outer enumerate on the next train batch,
                    # but renaming would be clearer.
                    for i, data in enumerate(valDataloader):
                        nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, _ = data
                        if args.use_cuda:
                            nbsHist = nbsHist.cuda()
                            nbsMask = nbsMask.cuda()
                            planFut = planFut.cuda()
                            planMask = planMask.cuda()
                            targsHist = targsHist.cuda()
                            targsEncMask = targsEncMask.cuda()
                            lat_enc = lat_enc.cuda()
                            lon_enc = lon_enc.cuda()
                            targsFut = targsFut.cuda()
                            targsFutMask = targsFutMask.cuda()
                        if epoch_num < pretrainEpochs:
                            # During pre-training with MSE loss, validate with MSE for true maneuver class trajectory
                            PiP.train_output_flag = True
                            fut_pred, _, _ = PiP(nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask,
                                                 lat_enc, lon_enc)
                            l = maskedMSE(fut_pred, targsFut, targsFutMask)
                        else:
                            # During training with NLL loss, validate with NLL over multi-modal distribution
                            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask, targsHist,
                                                               targsEncMask, lat_enc, lon_enc)
                            l = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask, avg_along_time=True)
                        avg_loss_val += l.item()
                        if i==(eval_batch_num-1):
                            if args.tensorboard:
                                logger_val.add_scalar("RMSE" if epoch_num < pretrainEpochs else "NLL", avg_loss_val / eval_batch_num, (epoch_progress + epoch_num) * 100)
                            break

                # Clear statistic
                avg_time_tr, avg_loss_tr, avg_loss_val = 0, 0, 0
                # Revert to train mode after in-process evaluation.
                PiP.train()
                PiP.train_output_flag = True

        ## Save the model after each epoch______________________________________________________________________________
        epoCount = epoch_num + 1
        if epoCount < pretrainEpochs:
            torch.save(PiP.state_dict(), log_path + "{}-pre{}-nll{}.tar".format(args.name, epoCount, 0))
        else:
            torch.save(PiP.state_dict(), log_path + "{}-pre{}-nll{}.tar".format(args.name, pretrainEpochs, epoCount - pretrainEpochs))

    # All epochs finish________________________________________________________________________________________________
    torch.save(PiP.state_dict(), log_path+"{}.tar".format(args.name))
    logging.info("Model saved in trained_models/{}/{}.tar\n".format(args.name, args.name))
# Script entry point: kick off training when run directly.
if __name__ == '__main__':
    train_model()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/spartanbin/vehicle_trajectory_prediction.git
git@gitee.com:spartanbin/vehicle_trajectory_prediction.git
spartanbin
vehicle_trajectory_prediction
vehicle_trajectory_prediction
main

搜索帮助