# @Author: yican, yelanlan
# @Date: 2020-06-16 20:43:36
# @Last Modified by: yican
# @Last Modified time: 2020-06-30 10:09:04
# Standard libraries
import math
# Third party libraries
from torch.optim import lr_scheduler


class WarmRestart(lr_scheduler.CosineAnnealingLR):
    """Stochastic Gradient Descent with Warm Restarts (SGDR): https://arxiv.org/abs/1608.03983.

    Sets the learning rate of each parameter group using a cosine annealing schedule and
    restarts the schedule every T_max epochs, multiplying T_max by T_mult after each restart.
    When last_epoch=-1, the initial learning rate is set to lr.

    Note: scheduler.step(epoch) is not supported; always call step() with epoch=None.
    """

    def __init__(self, optimizer, T_max=10, T_mult=2, eta_min=0, last_epoch=-1):
        """Implements SGDR.

        Parameters
        ----------
        T_max : int
            Maximum number of epochs in the first cosine cycle.
        T_mult : int
            Multiplicative factor applied to T_max after each restart. Default: 2.
        eta_min : float
            Minimum learning rate. Default: 0.
        last_epoch : int
            The index of the last epoch. Default: -1.
        """
        self.T_mult = T_mult
        super().__init__(optimizer, T_max, eta_min, last_epoch)

    def get_lr(self):
        # When the current cycle ends, restart the schedule and lengthen the next cycle.
        if self.last_epoch == self.T_max:
            self.last_epoch = 0
            self.T_max *= self.T_mult
        # Closed-form cosine annealing between base_lr and eta_min.
        return [
            self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
            for base_lr in self.base_lrs
        ]
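

# Usage sketch for WarmRestart (illustrative only; `torch` being imported, `model`,
# and `train_one_epoch` are assumptions, not part of this module):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = WarmRestart(optimizer, T_max=10, T_mult=2, eta_min=1e-5)
#     for epoch in range(num_epochs):
#         train_one_epoch(model, optimizer)  # hypothetical training step
#         scheduler.step()                   # call with no epoch argument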


def warm_restart(scheduler, T_mult=2):
    """Warm restart policy for a CosineAnnealingLR scheduler.

    Implements the restart step of Stochastic Gradient Descent with Warm
    Restarts (SGDR): https://arxiv.org/abs/1608.03983.

    Parameters
    ----------
    T_mult : int
        Multiplicative factor applied to T_max at each restart. Default: 2.

    Examples
    --------
    >>> # some other operations (note the order of operations)
    >>> scheduler.step()
    >>> scheduler = warm_restart(scheduler, T_mult=2)
    >>> optimizer.step()
    """
    # At the end of a cosine cycle, reset the epoch counter and lengthen the next cycle.
    if scheduler.last_epoch == scheduler.T_max:
        scheduler.last_epoch = -1
        scheduler.T_max *= T_mult
    return scheduler
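

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): drives a plain CosineAnnealingLR
    # and applies warm_restart() immediately after each scheduler step, as in the
    # docstring above. The tiny linear model and the lone optimizer.step() stand
    # in for a real training epoch and are assumptions, not part of the original code.
    import torch

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)

    for epoch in range(30):
        optimizer.step()                               # stands in for one training epoch
        scheduler.step()
        scheduler = warm_restart(scheduler, T_mult=2)  # apply right after scheduler.step()
        print(f"epoch {epoch:02d}  lr={optimizer.param_groups[0]['lr']:.6f}")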