代码拉取完成,页面将自动刷新
import os
import torch
import shutil
def save_checkpoint(state, is_best, epoch, save_path='./'):
print("=> saving checkpoint '{}'".format(epoch))
torch.save(state, os.path.join(save_path, 'checkpoint.pth.tar'))
if(epoch % 10 == 0):
torch.save(state, os.path.join(save_path, 'checkpoint_%03d.pth.tar' % epoch))
if is_best:
if epoch >= 90:
shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
os.path.join(save_path, 'model_best_in_100_epochs.pth.tar'))
else:
shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
os.path.join(save_path, 'model_best_in_090_epochs.pth.tar'))
def load_checkpoint(args, model, optimizer=None, verbose=True):
checkpoint = torch.load(args.resume)
start_epoch = 0
best_acc = 0
if "epoch" in checkpoint:
start_epoch = checkpoint['epoch']
if "best_acc" in checkpoint:
best_acc = checkpoint['best_acc']
model.load_state_dict(checkpoint['state_dict'], False)
if optimizer is not None and "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(args.device)
if verbose:
print("=> loading checkpoint '{}' (epoch {})"
.format(args.resume, start_epoch))
return model, optimizer, best_acc, start_epoch
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。