代码拉取完成,页面将自动刷新
import math
import warnings
from torch.optim.lr_scheduler import _LRScheduler
class ExpLR(_LRScheduler):
""" Pytorch 没有实现 lr = lr * exp(gamma) 的衰减方式
所以继承 StepLR 类改写一下
inputs:
optimizer: 模型所用的优化器
step_size: 更新学习率的步长
gamma: 自行定义的超参数
"""
def __init__(self, optimizer, step_size, gamma, last_epoch=-1, verbose=False):
self.step_size = step_size
self.gamma = gamma
super(ExpLR, self).__init__(optimizer, last_epoch, verbose)
# 参考 StepLR 实现
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.", UserWarning)
if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * math.exp(self.gamma) for group in self.optimizer.param_groups]
# 参考 StepLR 实现
def _get_closed_form_lr(self):
print(self.base_lrs, self.gamma)
return [base_lr * math.exp(self.gamma) for base_lr in self.base_lrs]
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。