1 Star 0 Fork 0

KunCheng-He/Light-SERNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
optim_utils.py 1.24 KB
一键复制 编辑 原始数据 按行查看 历史
KunCheng-He 提交于 2022-09-08 20:39 . 添加学习率的衰减方案
import math
import warnings
from torch.optim.lr_scheduler import _LRScheduler
class ExpLR(_LRScheduler):
""" Pytorch 没有实现 lr = lr * exp(gamma) 的衰减方式
所以继承 StepLR 类改写一下
inputs:
optimizer: 模型所用的优化器
step_size: 更新学习率的步长
gamma: 自行定义的超参数
"""
def __init__(self, optimizer, step_size, gamma, last_epoch=-1, verbose=False):
self.step_size = step_size
self.gamma = gamma
super(ExpLR, self).__init__(optimizer, last_epoch, verbose)
# 参考 StepLR 实现
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * math.exp(self.gamma) for group in self.optimizer.param_groups]
# 参考 StepLR 实现
def _get_closed_form_lr(self):
print(self.base_lrs, self.gamma)
return [base_lr * math.exp(self.gamma) for base_lr in self.base_lrs]
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/byack/sernet.git
git@gitee.com:byack/sernet.git
byack
sernet
Light-SERNet
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385