3 Star 4 Fork 1

Gitee 极速下载/SimAM

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
此仓库是为了提升国内下载速度的镜像仓库,每日同步一次。 原始仓库: https://github.com/ZjjConan/SimAM
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
checkpoint.py 1.51 KB
一键复制 编辑 原始数据 按行查看 历史
ZjjConan 提交于 2021-05-31 12:50 . fix err in checkpoint.py
import os
import torch
import shutil
def save_checkpoint(state, is_best, epoch, save_path='./'):
print("=> saving checkpoint '{}'".format(epoch))
torch.save(state, os.path.join(save_path, 'checkpoint.pth.tar'))
if(epoch % 10 == 0):
torch.save(state, os.path.join(save_path, 'checkpoint_%03d.pth.tar' % epoch))
if is_best:
if epoch >= 90:
shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
os.path.join(save_path, 'model_best_in_100_epochs.pth.tar'))
else:
shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
os.path.join(save_path, 'model_best_in_090_epochs.pth.tar'))
def load_checkpoint(args, model, optimizer=None, verbose=True):
checkpoint = torch.load(args.resume)
start_epoch = 0
best_acc = 0
if "epoch" in checkpoint:
start_epoch = checkpoint['epoch']
if "best_acc" in checkpoint:
best_acc = checkpoint['best_acc']
model.load_state_dict(checkpoint['state_dict'], False)
if optimizer is not None and "optimizer" in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(args.device)
if verbose:
print("=> loading checkpoint '{}' (epoch {})"
.format(args.resume, start_epoch))
return model, optimizer, best_acc, start_epoch
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/mirrors/SimAM.git
git@gitee.com:mirrors/SimAM.git
mirrors
SimAM
SimAM
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385