-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathtest_lr_scheduler.py
67 lines (54 loc) · 1.79 KB
/
test_lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR
try:
    import matplotlib as mpl
    # Select the non-interactive Agg backend before pyplot is imported so the
    # script can write PNG files without needing a display.
    mpl.use('Agg')
    from matplotlib import pyplot as plt
    from matplotlib import ticker as mtick
except ImportError as err:
    # Nothing below can run without matplotlib; fail here with an install
    # hint instead of printing it and crashing later with a NameError.
    raise ImportError('Please install matplotlib.') from err
def main(total_iter=600000, save_path='test_lr_scheduler.png'):
    """Simulate ``CosineAnnealingRestartLR`` and save a plot of its schedule.

    An Adam optimizer with two dummy parameter groups (initial lr 4e-4 and
    2e-4) is driven for ``total_iter`` iterations. The learning rate of each
    group is recorded after every scheduler step and plotted.

    Args:
        total_iter (int): Number of iterations to simulate. The default,
            600000, equals the sum of the restart periods configured below.
        save_path (str): File the figure is written to.
    """
    optim_params = [
        {'params': [torch.zeros(3, 64, 3, 3)], 'lr': 4e-4},
        {'params': [torch.zeros(3, 64, 3, 3)], 'lr': 2e-4},
    ]
    # The top-level lr=2e-4 is only a fallback; both groups set their own lr.
    optimizer = torch.optim.Adam(optim_params, lr=2e-4, weight_decay=0, betas=(0.9, 0.99))
    period = [50000, 100000, 150000, 150000, 150000]
    restart_weights = [1, 1, 0.5, 1, 0.5]
    scheduler = CosineAnnealingRestartLR(
        optimizer,
        period,
        restart_weights=restart_weights,
        eta_min=1e-7,
    )

    lr_curves = _collect_lrs(optimizer, scheduler, total_iter)
    _plot_lrs(lr_curves, save_path)


def _collect_lrs(optimizer, scheduler, total_iter):
    """Step the optimizer and scheduler ``total_iter`` times.

    Returns:
        list[list[float]]: One list per optimizer parameter group, holding
        that group's ``'lr'`` value read after each ``scheduler.step()``.
    """
    lr_curves = [[] for _ in optimizer.param_groups]
    for _ in range(total_iter):
        # PyTorch expects optimizer.step() to precede scheduler.step(). The
        # dummy parameters never receive gradients (no backward pass), so this
        # step does not modify them.
        optimizer.step()
        scheduler.step()
        for curve, group in zip(lr_curves, optimizer.param_groups):
            curve.append(group['lr'])
    return lr_curves


def _plot_lrs(lr_curves, save_path):
    """Plot each learning-rate curve against its iteration index and save it."""
    mpl.style.use('default')
    fig, ax = plt.subplots()
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    ax.set_title('Cosine Annealing Restart Learning Rate Scheme', fontsize=16, color='k')
    for idx, curve in enumerate(lr_curves, start=1):
        ax.plot(range(len(curve)), curve, linewidth=1.5, label=f'learning rate {idx}')
    ax.legend(loc='upper right', shadow=False)
    ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
    ax.set_ylabel('Learning Rate')
    ax.set_xlabel('Iteration')
    fig.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == '__main__':
    main()