
snow-tyan/classifier

show_lr.py (2.16 KB)
snow-tyan committed on 2021-11-20 14:53 · add all
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR
import torch.optim as optim
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import math
if __name__ == '__main__':
    brain_train_data = 1350
    adver_train_data = int(51137 * 0.88)
    yecaichong_data = 994  # train 994, val 109
    lr = 1e-4
    mode = 'cosineAnn'
    max_epoch = 25
    batch_size = 64
    ACCUMULATE = 2  # gradient accumulation steps
    iters = math.ceil(adver_train_data / batch_size)  # mini-batches per epoch
    T = iters // ACCUMULATE * max_epoch  # total optimizer steps = one cosine cycle
    print(iters)
    model = resnet18(pretrained=False)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    if mode == 'cosineAnn':
        scheduler = CosineAnnealingLR(optimizer, T_max=T)
    elif mode == 'cosineAnnWarm':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
        '''
        Taking T_0=5, T_mult=1 as an example:
        T_0: the epoch at which the learning rate first returns to its initial value.
        T_mult: controls how quickly the learning rate restarts.
        - If T_mult=1, the learning rate returns to its maximum (the initial LR)
          at epochs T_0, 2*T_0, 3*T_0, ..., i*T_0, ...
          i.e. at 5, 10, 15, 20, 25, ...
        - If T_mult>1, the learning rate returns to its maximum at epochs
          T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ...,
          (1+T_mult+...+T_mult**i)*T_0, ...
          i.e. at 5, 15, 35, 75, 155, ...
        example:
        T_0=5, T_mult=1
        '''
    plt.figure()
    cur_lr_list = []
    for epoch in range(max_epoch):
        model.train()
        print('epoch_{}'.format(epoch))
        for batch in range(iters):
            if (batch + 1) % ACCUMULATE == 0:  # gradient accumulation
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                # scheduler.step(epoch + batch / iters)
            cur_lr = optimizer.param_groups[-1]['lr']
            cur_lr_list.append(cur_lr)
            # print('cur_lr:', cur_lr)
        print('epoch_{}_end'.format(epoch))
        # scheduler.step()
    x_list = list(range(len(cur_lr_list)))
    plt.plot(x_list, cur_lr_list)
    plt.show()
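As a quick sanity check on the restart positions described in the docstring and on the `T_max` arithmetic the script uses, the snippet below (a hypothetical helper, not part of this repo) computes the epochs at which a warm-restart cosine schedule returns to the initial learning rate, plus the total optimizer-step count that the script passes to `CosineAnnealingLR` as `T_max`:

```python
import math

def restart_epochs(T_0, T_mult, n):
    """First n epochs at which a warm-restart cosine schedule
    returns to the initial learning rate (hypothetical helper)."""
    epochs, t, cycle = [], 0, T_0
    for _ in range(n):
        t += cycle          # restarts happen at cumulative cycle lengths
        epochs.append(t)
        cycle *= T_mult     # each new cycle is T_mult times longer
    return epochs

print(restart_epochs(5, 1, 5))   # [5, 10, 15, 20, 25]
print(restart_epochs(5, 2, 5))   # [5, 15, 35, 75, 155]

# Optimizer-step count used as T_max in the script above:
adver_train_data = int(51137 * 0.88)      # 45000 samples
iters = math.ceil(adver_train_data / 64)  # 704 mini-batches per epoch
T = iters // 2 * 25                       # steps per epoch (accumulate=2) x 25 epochs
print(T)                                  # 8800
```

Because the scheduler is stepped only on accumulation boundaries, `T_max` must count optimizer steps (`iters // ACCUMULATE * max_epoch`), not raw mini-batches, for the cosine to complete exactly one cycle over training.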
https://gitee.com/snow-tyan/classifier.git