代码拉取完成,页面将自动刷新
import os
import time
from tqdm import tqdm
import torch
import math
from torch.utils.data import DataLoader, RandomSampler
from torch.nn import DataParallel
from nets.attention_model import set_decode_type
from utils.log_utils import log_values, log_values_sl
from utils import move_to
def get_inner_model(model):
    """Return the module wrapped by a ``DataParallel`` container.

    Models that are not wrapped in ``DataParallel`` are returned unchanged.
    """
    if isinstance(model, DataParallel):
        return model.module
    return model
def validate(model, dataset, opts):
    """Greedily evaluate ``model`` on ``dataset`` and return the mean cost.

    Supervised TSP runs (``opts.problem == 'tspsl'``) are evaluated with
    ``rollout_sl``. Every other problem uses ``rollout``. The mean cost and its
    standard error (std / sqrt(N)) are printed before the mean is returned.
    """
    print('Validating...')
    evaluate = rollout_sl if opts.problem == 'tspsl' else rollout
    costs = evaluate(model, dataset, opts)
    mean_cost = costs.mean()
    std_error = torch.std(costs) / math.sqrt(len(costs))
    print('Validation overall avg_cost: {} +- {}'.format(mean_cost, std_error))
    return mean_cost
def rollout(model, dataset, opts):
    """Run ``model`` greedily over ``dataset`` and return every cost on the CPU.

    The model is switched to greedy decoding and eval mode. The dataset is fed
    through in batches of ``opts.eval_batch_size`` with autograd disabled, and
    the per-batch costs are concatenated along dimension 0.
    """
    set_decode_type(model, "greedy")
    model.eval()

    loader = DataLoader(dataset, batch_size=opts.eval_batch_size)
    per_batch_costs = []
    for batch in tqdm(loader, disable=opts.no_progress_bar, ascii=True):
        with torch.no_grad():
            batch_cost, _ = model(move_to(batch, opts.device))
            per_batch_costs.append(batch_cost.data.cpu())
    return torch.cat(per_batch_costs, 0)
def rollout_sl(model, dataset, opts):
    """Greedy evaluation for supervised-learning datasets.

    This is the same as ``rollout``, except that each batch is a mapping and
    only its ``'nodes_coord'`` entry is passed to the model. Returns all
    per-instance costs, concatenated on the CPU.
    """
    set_decode_type(model, "greedy")
    model.eval()

    loader = DataLoader(dataset, batch_size=opts.eval_batch_size)
    per_batch_costs = []
    for batch in tqdm(loader, disable=opts.no_progress_bar, ascii=True):
        with torch.no_grad():
            coords = move_to(batch['nodes_coord'], opts.device)
            batch_cost, _ = model(coords)
            per_batch_costs.append(batch_cost.data.cpu())
    return torch.cat(per_batch_costs, 0)
def clip_grad_norms(param_groups, max_norm=math.inf):
    """
    Clip the gradients of each parameter group to ``max_norm`` (L2 norm).

    :param param_groups: iterable of parameter-group dicts, each with a
        ``'params'`` entry (e.g. ``optimizer.param_groups``).
    :param max_norm: maximum allowed L2 norm per group. A value ``<= 0``
        disables clipping. The norms are still computed so they can be logged.
    :return: ``(grad_norms, grad_norms_clipped)``. The first list holds the
        per-group gradient norms measured *before* clipping. The second holds
        the same norms capped at ``max_norm``. When clipping is disabled, the
        second element is the very same list object as the first.
    """
    clipping_enabled = max_norm > 0
    # With clipping disabled, clip to +inf: the gradients are left untouched,
    # but clip_grad_norm_ still returns the total norm we want to log.
    threshold = max_norm if clipping_enabled else math.inf
    grad_norms = [
        torch.nn.utils.clip_grad_norm_(group['params'], threshold, norm_type=2)
        for group in param_groups
    ]
    if not clipping_enabled:
        return grad_norms, grad_norms
    grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms]
    return grad_norms, grad_norms_clipped
def train_epoch(model, optimizer, baseline, lr_scheduler, epoch, val_dataset, problem, tb_logger, opts):
    """Run one epoch of baseline-assisted policy-gradient training.

    Sequence of work:
      1. Sample a fresh training set from ``problem`` and wrap it with ``baseline``.
      2. Train batch by batch in sampling-decode mode via ``train_batch``.
      3. Step the learning-rate scheduler.
      4. Save a checkpoint every ``opts.checkpoint_epochs`` epochs (0 disables
         periodic checkpoints) and always on the final epoch.
      5. Validate on ``val_dataset`` and notify the baseline via ``epoch_callback``.
    """
    current_lr = optimizer.param_groups[0]['lr']
    print("Start train epoch {}, lr={} for run {}".format(epoch, current_lr, opts.run_name))
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', current_lr, step)

    # A brand-new training set is drawn every epoch. The baseline wrapper may
    # attach per-sample baseline values (see baseline.unwrap_batch in train_batch).
    fresh_data = problem.make_dataset(
        size=opts.graph_size, num_samples=opts.epoch_size, distribution=opts.data_distribution)
    loader = DataLoader(baseline.wrap_dataset(fresh_data), batch_size=opts.batch_size, num_workers=1)

    model.train()
    set_decode_type(model, "sampling")

    for batch_id, batch in enumerate(tqdm(loader, disable=opts.no_progress_bar, ascii=True)):
        train_batch(model, optimizer, baseline, epoch, batch_id, step, batch, tb_logger, opts)
        step += 1

    # NOTE(review): passing the epoch to scheduler.step() is deprecated in recent
    # PyTorch releases; it is kept as-is so the learning-rate schedule is unchanged.
    lr_scheduler.step(epoch)

    elapsed = time.time() - start_time
    print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(elapsed))))

    should_save = (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) \
        or epoch == opts.n_epochs - 1
    if should_save:
        print('Saving model and state...')
        checkpoint = {
            'model': get_inner_model(model).state_dict(),
            'optimizer': optimizer.state_dict(),
            'rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state_all(),
            'baseline': baseline.state_dict()
        }
        torch.save(checkpoint, os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch)))

    # validate() returns the mean greedy cost; it is logged under the historical key name.
    val_cost = validate(model, val_dataset, opts)
    if not opts.no_tensorboard:
        tb_logger.log_value('val_avg_reward', val_cost, step)

    baseline.epoch_callback(model, epoch)
def train_batch(model, optimizer, baseline, epoch, batch_id, step, batch, tb_logger, opts):
    """Apply one baseline-subtracted policy-gradient update to ``model``.

    The loss is ``mean((cost - baseline_value) * log_likelihood)`` plus any loss
    the baseline itself reports. Gradients are clipped to ``opts.max_grad_norm``
    before the optimizer step. Statistics are logged every ``opts.log_step`` steps.
    """
    x, bl_val = baseline.unwrap_batch(batch)
    x = move_to(x, opts.device)
    if bl_val is not None:
        bl_val = move_to(bl_val, opts.device)

    cost, log_likelihood = model(x)

    if bl_val is None:
        # No precomputed baseline values came with the batch, so evaluate the
        # baseline now. Per the original note, only a critic baseline yields a
        # nonzero loss here.
        bl_val, bl_loss = baseline.eval(x, cost)
    else:
        bl_loss = 0

    advantage = cost - bl_val
    reinforce_loss = (advantage * log_likelihood).mean()
    loss = reinforce_loss + bl_loss

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    optimizer.step()

    if step % int(opts.log_step) == 0:
        log_values(cost, grad_norms, epoch, batch_id, step,
                   log_likelihood, reinforce_loss, bl_loss, tb_logger, opts)
def train_epoch_sl(model, optimizer, lr_scheduler, epoch, train_dataset, val_dataset, problem, tb_logger, opts):
    """Run one epoch of supervised training, then checkpoint and validate.

    Unlike ``train_epoch``, the training set is fixed (``train_dataset``) and is
    only reshuffled each epoch through a ``RandomSampler``. ``problem`` is accepted
    for signature parity with ``train_epoch`` but is not used here. Checkpoints
    follow the same schedule as ``train_epoch``, but without baseline state.
    """
    current_lr = optimizer.param_groups[0]['lr']
    print("Start train epoch {}, lr={} for run {}".format(epoch, current_lr, opts.run_name))
    step = epoch * (opts.epoch_size // opts.batch_size)
    start_time = time.time()

    if not opts.no_tensorboard:
        tb_logger.log_value('learnrate_pg0', current_lr, step)

    # The same dataset every epoch, visited in a fresh random order.
    loader = DataLoader(
        train_dataset,
        batch_size=opts.batch_size,
        num_workers=1,
        sampler=RandomSampler(train_dataset),
    )

    # Train mode, with greedy decoding.
    model.train()
    set_decode_type(model, "greedy")

    for batch_id, batch in enumerate(tqdm(loader, disable=opts.no_progress_bar, ascii=True)):
        train_batch_sl(model, optimizer, epoch, batch_id, step, batch, tb_logger, opts)
        step += 1

    # NOTE(review): passing the epoch to scheduler.step() is deprecated in recent
    # PyTorch releases; it is kept as-is so the learning-rate schedule is unchanged.
    lr_scheduler.step(epoch)

    elapsed = time.time() - start_time
    print("Finished epoch {}, took {} s".format(epoch, time.strftime('%H:%M:%S', time.gmtime(elapsed))))

    should_save = (opts.checkpoint_epochs != 0 and epoch % opts.checkpoint_epochs == 0) \
        or epoch == opts.n_epochs - 1
    if should_save:
        print('Saving model and state...')
        checkpoint = {
            'model': get_inner_model(model).state_dict(),
            'optimizer': optimizer.state_dict(),
            'rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state_all()
        }
        torch.save(checkpoint, os.path.join(opts.save_dir, 'epoch-{}.pt'.format(epoch)))

    # validate() returns the mean greedy cost; it is logged under the historical key name.
    val_cost = validate(model, val_dataset, opts)
    if not opts.no_tensorboard:
        tb_logger.log_value('val_avg_reward', val_cost, step)
def train_batch_sl(model, optimizer, epoch, batch_id, step, batch, tb_logger, opts):
    """Apply one supervised update using the batch's target tours.

    ``batch`` is indexed by ``'nodes_coord'`` (the model input) and
    ``'tour_nodes'`` (passed as ``targets``). The model is called in
    ``supervised_mode`` and returns ``(cost, loss)``. The loss is averaged,
    backpropagated, clipped to ``opts.max_grad_norm``, and applied.
    """
    coords = move_to(batch['nodes_coord'], opts.device)
    target_tours = move_to(batch['tour_nodes'], opts.device)

    cost, loss = model(coords, supervised_mode=True, targets=target_tours)
    # Reduce to a scalar; per the original note this averages per-GPU losses
    # when the model is replicated across devices.
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    optimizer.step()

    if step % int(opts.log_step) == 0:
        log_values_sl(cost, grad_norms, epoch, batch_id, step,
                      loss, tb_logger, opts)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。