import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR
import torch.optim as optim
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import math
if __name__ == '__main__':
    # Dataset sizes (only adver_train_data is used below to size the schedule).
    brain_train_data = 1350
    adver_train_data = int(51137 * 0.88)
    yecaichong_data = 994  # train 994 / val 109
    lr = 1e-4
    mode = 'cosineAnn'
    max_epoch = 25
    batch_size = 64
    ACCUMULATE = 2  # gradient-accumulation steps per optimizer update
    iters = math.ceil(adver_train_data / batch_size)  # batches per epoch
    # The scheduler is stepped once per optimizer update (every ACCUMULATE batches),
    # so the cosine cycle length is measured in optimizer steps, not batches.
    T = iters // ACCUMULATE * max_epoch  # cycle length for CosineAnnealingLR
    print(iters)
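    # Worked example with the values above (added for clarity, not in the original):
    #   adver_train_data = int(51137 * 0.88) = 45000 samples
    #   iters = ceil(45000 / 64)  = 704 batches per epoch
    #   T     = 704 // 2 * 25     = 8800 optimizer steps in total,
    # so CosineAnnealingLR below completes exactly one half-cosine over the whole run.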
    model = resnet18(pretrained=False)  # weights=None on torchvision >= 0.13
    optimizer = optim.SGD(model.parameters(), lr=lr)
    if mode == 'cosineAnn':
        # One smooth cosine decay from lr down to eta_min over T optimizer steps.
        scheduler = CosineAnnealingLR(optimizer, T_max=T)
    elif mode == 'cosineAnnWarm':
        # Cosine decay with warm restarts; see the docstring below.
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
    '''
    Meaning of T_0 and T_mult for CosineAnnealingWarmRestarts (illustrated with T_0=5):
    T_0:    number of scheduler steps (here: epochs) until the learning rate first
            returns to its initial value.
    T_mult: controls how the restart period grows after each restart.
    - If T_mult=1, the lr returns to its maximum (the initial lr) at
      T_0, 2*T_0, 3*T_0, ..., i*T_0, ...
      i.e. at 5, 10, 15, 20, 25, ...
    - If T_mult>1, the lr returns to its maximum at
      T_0, (1+T_mult)*T_0, (1+T_mult+T_mult**2)*T_0, ...,
      (1+T_mult+T_mult**2+...+T_mult**i)*T_0, ...
      e.g. with T_mult=2: 5, 15, 35, 75, 155, ...
    (see the sketch right after this docstring)
    '''
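    # A minimal sketch (not part of the original script) that verifies the restart
    # positions described above: step a throwaway CosineAnnealingWarmRestarts once
    # per "epoch" on a dummy optimizer and record when the lr returns to its maximum.
    def restart_epochs(T_0=5, T_mult=2, epochs=80):
        dummy_opt = optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
        dummy_sched = CosineAnnealingWarmRestarts(dummy_opt, T_0=T_0, T_mult=T_mult)
        restarts = []
        for e in range(1, epochs + 1):
            dummy_sched.step()
            # Right after a restart the lr is back at its initial value.
            if abs(dummy_opt.param_groups[0]['lr'] - 1.0) < 1e-8:
                restarts.append(e)
        return restarts
    # restart_epochs(T_0=5, T_mult=2) -> [5, 15, 35, 75]
    # restart_epochs(T_0=5, T_mult=1) -> [5, 10, 15, ..., 80]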
    plt.figure()
    cur_lr_list = []
    for epoch in range(max_epoch):
        model.train()
        print('epoch_{}'.format(epoch))
        # No forward/backward pass here: the loop only simulates how the lr
        # evolves under gradient accumulation.
        for batch in range(iters):
            if (batch + 1) % ACCUMULATE == 0:  # gradient accumulation: update every ACCUMULATE batches
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()  # one scheduler step per optimizer update
                # scheduler.step(epoch + batch / iters)
                cur_lr = optimizer.param_groups[-1]['lr']
                cur_lr_list.append(cur_lr)
                # print('cur_lr:', cur_lr)
        print('epoch_{}_end'.format(epoch))
        # scheduler.step()
    x_list = list(range(len(cur_lr_list)))
    plt.plot(x_list, cur_lr_list)
    plt.show()
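    # A minimal sketch (not part of the original script) of the alternative stepping
    # style hinted at by the commented-out scheduler.step(epoch + batch / iters) above:
    # CosineAnnealingWarmRestarts accepts a fractional epoch, so the lr can be annealed
    # smoothly within each epoch instead of once per optimizer update.
    warm_opt = optim.SGD(model.parameters(), lr=lr)
    warm_sched = CosineAnnealingWarmRestarts(warm_opt, T_0=2, T_mult=2)
    warm_lr_list = []
    for epoch in range(max_epoch):
        for batch in range(iters):
            warm_sched.step(epoch + batch / iters)  # fractional epoch position
            warm_lr_list.append(warm_opt.param_groups[0]['lr'])
    plt.figure()
    plt.plot(range(len(warm_lr_list)), warm_lr_list)
    plt.show()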