1 Star 0 Fork 0

zhoub86/learning-paradigms-for-tsp

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
train.py 8.24 KB
一键复制 编辑 原始数据 按行查看 历史
chaitjo 提交于 2019-11-04 11:26 . Added code/files
import os
import time
from tqdm import tqdm
import torch
import math
from torch.utils.data import DataLoader, RandomSampler
from torch.nn import DataParallel
from nets.attention_model import set_decode_type
from utils.log_utils import log_values, log_values_sl
from utils import move_to
def get_inner_model(model):
    """Return the underlying module, unwrapping a ``DataParallel`` container if present.

    :param model: a ``torch.nn.Module``, possibly wrapped in ``DataParallel``.
    :return: ``model.module`` for a ``DataParallel`` wrapper, otherwise ``model`` itself.
    """
    if isinstance(model, DataParallel):
        return model.module
    return model
def validate(model, dataset, opts):
    """Greedily evaluate ``model`` on ``dataset`` and report the mean tour cost.

    The supervised problem variant (``opts.problem == 'tspsl'``) stores inputs as
    dicts, so it is routed through ``rollout_sl``; every other problem uses ``rollout``.

    :param model: model to evaluate.
    :param dataset: validation dataset.
    :param opts: run options (reads ``problem`` plus whatever the rollout needs).
    :return: mean cost over the dataset, as a 0-d tensor.
    """
    print('Validating...')
    evaluate = rollout_sl if opts.problem == 'tspsl' else rollout
    costs = evaluate(model, dataset, opts)

    mean_cost = costs.mean()
    # Standard error of the mean: sample std / sqrt(n).
    std_error = torch.std(costs) / math.sqrt(len(costs))
    print('Validation overall avg_cost: {} +- {}'.format(mean_cost, std_error))
    return mean_cost
def rollout(model, dataset, opts):
    """Run greedy decoding over ``dataset`` and collect per-instance costs.

    Switches the model to greedy decoding and eval mode, batches the dataset with
    ``opts.eval_batch_size`` and evaluates each batch without tracking gradients.

    :param model: model returning ``(cost, _)`` for a batch of inputs.
    :param dataset: dataset whose items are fed directly to the model.
    :param opts: run options (reads ``device``, ``eval_batch_size``, ``no_progress_bar``).
    :return: CPU tensor of costs for all batches, concatenated along dim 0.
    """
    set_decode_type(model, "greedy")
    model.eval()

    loader = DataLoader(dataset, batch_size=opts.eval_batch_size)
    batch_costs = []
    for batch in tqdm(loader, disable=opts.no_progress_bar, ascii=True):
        with torch.no_grad():
            batch_cost, _ = model(move_to(batch, opts.device))
        batch_costs.append(batch_cost.data.cpu())
    return torch.cat(batch_costs, 0)
def rollout_sl(model, dataset, opts):
    """Greedy-decode a supervised-learning dataset and collect per-instance costs.

    Identical to ``rollout`` except that each batch is a dict, and only its
    ``'nodes_coord'`` entry is given to the model.

    :param model: model returning ``(cost, _)`` for a batch of node coordinates.
    :param dataset: dataset whose items are dicts containing ``'nodes_coord'``.
    :param opts: run options (reads ``device``, ``eval_batch_size``, ``no_progress_bar``).
    :return: CPU tensor of costs for all batches, concatenated along dim 0.
    """
    set_decode_type(model, "greedy")
    model.eval()

    loader = DataLoader(dataset, batch_size=opts.eval_batch_size)
    batch_costs = []
    for batch in tqdm(loader, disable=opts.no_progress_bar, ascii=True):
        with torch.no_grad():
            batch_cost, _ = model(move_to(batch['nodes_coord'], opts.device))
        batch_costs.append(batch_cost.data.cpu())
    return torch.cat(batch_costs, 0)
def clip_grad_norms(param_groups, max_norm=math.inf):
    """
    Clip the L2 gradient norm of every parameter group and report the norms.

    :param param_groups: iterable of parameter-group dicts, each with a ``'params'``
        entry (typically ``optimizer.param_groups``).
    :param max_norm: maximum allowed L2 norm per group. A value <= 0 disables
        clipping, but the norms are still computed so they can be logged.
    :return: ``(grad_norms, grad_norms_clipped)``. The first holds the per-group norms
        before clipping. The second holds ``min(norm, max_norm)`` per group, or is the
        same list object as ``grad_norms`` when clipping is disabled.
    """
    clipping_enabled = max_norm > 0
    # clip_grad_norm_ always returns the pre-clipping total norm, so calling it with an
    # infinite limit measures the norms without modifying any gradients.
    limit = max_norm if clipping_enabled else math.inf
    grad_norms = [
        torch.nn.utils.clip_grad_norm_(group['params'], limit, norm_type=2)
        for group in param_groups
    ]
    if not clipping_enabled:
        return grad_norms, grad_norms
    grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms]
    return grad_norms, grad_norms_clipped
def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts):
    """
    Train for one epoch on freshly sampled instances, then checkpoint and validate.

    :param model: policy model (optionally wrapped in ``DataParallel``).
    :param optimizer: torch optimizer; its first param group's lr is printed and logged.
    :param baseline: baseline object used to wrap the dataset, unwrap batches, compute
        baseline values, checkpoint its state and receive an end-of-epoch callback.
    :param lr_scheduler: learning-rate scheduler, stepped once per epoch with an
        explicit epoch index.
    :param epoch: zero-based index of the epoch being trained (the final epoch is
        ``opts.n_epochs - 1``).
    :param val_dataset: dataset passed to ``validate`` after training.
    :param problem: problem class whose ``make_dataset`` generates the training set.
    :param tb_logger: TensorBoard logger; unused when ``opts.no_tensorboard`` is set.
    :param opts: run options.
    """
    print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)

    # Generate new training data for each epoch
    training_dataset = baseline.wrap_dataset(problem.make_dataset(
        size=opts.graph_size, num_samples=opts.epoch_size, distribution=opts.data_distribution))
    training_dataloader = DataLoader(training_dataset, batch_size=opts.batch_size, num_workers=1)

    # Put model in train mode! Sampling decoding is used for the policy-gradient rollouts.
    model.train()
    set_decode_type(model, "sampling")

    for batch_id, batch in enumerate(tqdm(training_dataloader, disable=opts.no_progress_bar, ascii=True)):
        train_batch(
            model,
            optimizer,
            baseline,
            epoch,
            batch_id,
            step,
            batch,
            tb_logger,
            opts
        )
        step += 1

    # Set the learning rate for the NEXT epoch. The scheduler is stepped after this
    # epoch's optimizer.step() calls, so it must receive epoch + 1. Passing `epoch` made
    # epoch e train with the rate scheduled for epoch e - 1 (for an epoch-indexed
    # schedule such as LambdaLR, epochs 0 and 1 both used the initial rate). The
    # explicit index also means the optimizer state saved below already carries the
    # rate for the epoch a resumed run starts with.
    lr_scheduler.step(epoch + 1)

    epoch_duration = time.time() - start_time
    print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration))))

    # Checkpoint every `checkpoint_epochs` epochs (0 disables periodic saving) and always after the last epoch.
    if (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) or epoch == opts.n_epochs - 1:
        print('Saving model and state...')
        torch.save(
            {
                'model': get_inner_model(model).state_dict(),
                'optimizer': optimizer.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all(),
                'baseline': baseline.state_dict()
            },
            os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch))
        )

    # NOTE: validate() returns the mean cost; the 'val_avg_reward' tag name is kept for log compatibility.
    avg_reward = validate(model, val_dataset, opts)

    if not opts.no_tensorboard:
        tb_logger.log_value('val_avg_reward', avg_reward, step)

    baseline.epoch_callback(model, epoch)
def train_batch(
        model,
        optimizer,
        baseline,
        epoch,
        batch_id,
        step,
        batch,
        tb_logger,
        opts
):
    """Perform one REINFORCE update on a single batch and log it periodically.

    If the wrapped batch already carries baseline values, those are used and the
    baseline loss is 0. Otherwise the baseline is evaluated on the fly and may return
    its own loss, which is added to the policy loss.

    :param model: model returning ``(cost, log_likelihood)`` for a batch.
    :param optimizer: optimizer whose param groups are clipped and stepped.
    :param baseline: baseline providing ``unwrap_batch`` and ``eval``.
    :param epoch: current epoch index (for logging).
    :param batch_id: index of the batch within the epoch (for logging).
    :param step: global step counter; values are logged when it is a multiple of ``opts.log_step``.
    :param batch: batch as produced by the baseline-wrapped dataset.
    :param tb_logger: TensorBoard logger forwarded to ``log_values``.
    :param opts: run options (reads ``device``, ``max_grad_norm``, ``log_step``).
    """
    inputs, bl_val = baseline.unwrap_batch(batch)
    inputs = move_to(inputs, opts.device)
    if bl_val is not None:
        bl_val = move_to(bl_val, opts.device)

    cost, log_likelihood = model(inputs)

    if bl_val is None:
        bl_val, bl_loss = baseline.eval(inputs, cost)
    else:
        bl_loss = 0

    advantage = cost - bl_val
    reinforce_loss = (advantage * log_likelihood).mean()
    loss = reinforce_loss + bl_loss

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    optimizer.step()

    should_log = step % int(opts.log_step) == 0
    if should_log:
        log_values(cost, grad_norms, epoch, batch_id, step,
                   log_likelihood, reinforce_loss, bl_loss, tb_logger, opts)
def train_epoch_sl(model, optimizer, lr_scheduler, epoch, train_dataset, val_dataset, problem, tb_logger, opts):
    """
    Train for one supervised epoch over a fixed dataset, then checkpoint and validate.

    :param model: model (optionally wrapped in ``DataParallel``) that supports
        ``supervised_mode=True`` with target tours.
    :param optimizer: torch optimizer; its first param group's lr is printed and logged.
    :param lr_scheduler: learning-rate scheduler, stepped once per epoch with an
        explicit epoch index.
    :param epoch: zero-based index of the epoch being trained (the final epoch is
        ``opts.n_epochs - 1``).
    :param train_dataset: fixed training set; items are dicts with ``'nodes_coord'``
        and ``'tour_nodes'``.
    :param val_dataset: dataset passed to ``validate`` after training.
    :param problem: unused here; kept so the signature parallels ``train_epoch``.
    :param tb_logger: TensorBoard logger; unused when ``opts.no_tensorboard`` is set.
    :param opts: run options.
    """
    print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], opts.run_name))
    # NOTE(review): assumes each epoch has opts.epoch_size // opts.batch_size batches;
    # this may not match len(train_dataset) — verify against how run.py builds the dataset.
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', optimizer.param_groups[0]['lr'], step)

    # Reuse the fixed training set, visiting it in a fresh random order each epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=opts.batch_size, num_workers=1, sampler=RandomSampler(train_dataset))

    # Put model in train mode!
    model.train()
    set_decode_type(model, "greedy")

    for batch_id, batch in enumerate(tqdm(train_dataloader, disable=opts.no_progress_bar, ascii=True)):
        train_batch_sl(
            model,
            optimizer,
            epoch,
            batch_id,
            step,
            batch,
            tb_logger,
            opts
        )
        step += 1

    # Set the learning rate for the NEXT epoch. The scheduler is stepped after this
    # epoch's optimizer.step() calls, so it must receive epoch + 1. Passing `epoch` made
    # epoch e train with the rate scheduled for epoch e - 1 (for an epoch-indexed
    # schedule such as LambdaLR, epochs 0 and 1 both used the initial rate).
    lr_scheduler.step(epoch + 1)

    epoch_duration = time.time() - start_time
    print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(epoch_duration))))

    # Checkpoint every `checkpoint_epochs` epochs (0 disables periodic saving) and always after the last epoch.
    if (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) or epoch == opts.n_epochs - 1:
        print('Saving model and state...')
        torch.save(
            {
                'model': get_inner_model(model).state_dict(),
                'optimizer': optimizer.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state_all()
            },
            os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch))
        )

    # NOTE: validate() returns the mean cost; the 'val_avg_reward' tag name is kept for log compatibility.
    avg_reward = validate(model, val_dataset, opts)

    if not opts.no_tensorboard:
        tb_logger.log_value('val_avg_reward', avg_reward, step)
def train_batch_sl(
        model,
        optimizer,
        epoch,
        batch_id,
        step,
        batch,
        tb_logger,
        opts
):
    """Perform one supervised update on a batch of instances with target tours.

    :param model: model that, called with ``supervised_mode=True`` and ``targets``,
        returns ``(cost, loss)``.
    :param optimizer: optimizer whose param groups are clipped and stepped.
    :param epoch: current epoch index (for logging).
    :param batch_id: index of the batch within the epoch (for logging).
    :param step: global step counter; values are logged when it is a multiple of ``opts.log_step``.
    :param batch: dict with ``'nodes_coord'`` (inputs) and ``'tour_nodes'`` (targets).
    :param tb_logger: TensorBoard logger forwarded to ``log_values_sl``.
    :param opts: run options (reads ``device``, ``max_grad_norm``, ``log_step``).
    """
    coords = move_to(batch['nodes_coord'], opts.device)
    target_tours = move_to(batch['tour_nodes'], opts.device)

    cost, raw_loss = model(coords, supervised_mode=True, targets=target_tours)
    # Reduce to a scalar; presumably raw_loss holds one entry per GPU replica under DataParallel.
    loss = raw_loss.mean()

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    optimizer.step()

    should_log = step % int(opts.log_step) == 0
    if should_log:
        log_values_sl(cost, grad_norms, epoch, batch_id, step,
                      loss, tb_logger, opts)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zhoub86/learning-paradigms-for-tsp.git
git@gitee.com:zhoub86/learning-paradigms-for-tsp.git
zhoub86
learning-paradigms-for-tsp
learning-paradigms-for-tsp
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385