schedulers.py
from keras.callbacks import Callback
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np


class CyclicLRScheduler(Callback):
    """Triangular cyclical learning rate schedule.

    The learning rate oscillates linearly between `start_lr` and `end_lr`
    over a cycle of `2 * step_size` iterations. With decay='fixed' the whole
    schedule is halved every cycle; with decay='exp' it is multiplied by
    `gamma` once per iteration.
    """

    def __init__(self, start_lr=1e-3, end_lr=6e-3, step_size=2000, decay=None,
                 gamma=1.):
        super(CyclicLRScheduler, self).__init__()
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.step_size = step_size
        self.decay = decay
        self.gamma = gamma
        self.history = {}
        self.iterations = 0.

    def on_train_begin(self, logs=None):
        K.set_value(self.model.optimizer.lr, self.cal_lr())

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.iterations += 1
        self.history.setdefault('lrs', []).append(
            K.get_value(self.model.optimizer.lr))
        K.set_value(self.model.optimizer.lr, self.cal_lr())
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

    def cal_lr(self):
        # Position within the current triangular cycle: x runs 1 -> 0 -> 1,
        # so the rate ramps from start_lr up to end_lr and back down.
        cycle = np.floor(1 + self.iterations / (2 * self.step_size))
        x = np.abs((self.iterations / self.step_size) - (2 * cycle) + 1)
        new_lr = self.start_lr + \
            (self.end_lr - self.start_lr) * np.maximum(0, 1 - x)
        if self.decay == 'fixed':
            new_lr /= (2. ** (cycle - 1))  # halve the schedule every cycle
        elif self.decay == 'exp':
            new_lr *= (self.gamma ** self.iterations)  # per-iteration decay
        return new_lr

    def plot_lr(self):
        plt.plot(range(len(self.history['lrs'])), self.history['lrs'])
        plt.xlabel('Iterations')
        plt.ylabel('Learning rate')
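
# Usage sketch (hypothetical model/data, not part of the original module):
# the callback is passed straight to `fit`, so it updates the rate once per
# batch; `start_lr`, `end_lr` and `step_size` below are illustrative only.
#
#     clr = CyclicLRScheduler(start_lr=1e-4, end_lr=6e-4, step_size=500)
#     model.fit(X_train, y_train, batch_size=64, epochs=10, callbacks=[clr])
#     clr.plot_lr()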


class SGDRScheduler(Callback):
    """Cosine-annealed learning rate with warm restarts (SGDR-style).

    Within a cycle the rate is annealed from `start_lr` to `end_lr` along a
    half-cosine over `steps_per_epoch * cycle_len` batches. Each call to
    `fit` counts as one cycle: `on_train_end` stretches the next cycle by
    `cycle_mult` and decays its starting rate by `lr_decay`, so the scheduler
    is meant to be driven by `fit_cycle` below.
    """

    def __init__(self, start_lr=1e2, end_lr=1e-2, lr_decay=1., cycle_len=1,
                 cycle_mult=1, steps_per_epoch=10):
        super(SGDRScheduler, self).__init__()
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.lr_decay = lr_decay
        self.cycle_len = cycle_len
        self.cycle_mult = cycle_mult
        self.steps_per_epoch = steps_per_epoch
        self.history = {}

    def on_train_begin(self, logs=None):
        K.set_value(self.model.optimizer.lr, self.start_lr)
        self.batch_cycle = 0.

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.batch_cycle += 1
        self.history.setdefault('lrs', []).append(
            K.get_value(self.model.optimizer.lr))
        K.set_value(self.model.optimizer.lr, self.cal_lr())
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

    def on_train_end(self, logs=None):
        # One fit call == one cycle: lengthen the next cycle and decay its
        # starting learning rate.
        self.cycle_len = np.ceil(self.cycle_len * self.cycle_mult)
        self.start_lr *= self.lr_decay

    def cal_lr(self):
        # Fraction of the current cycle completed, then cosine anneal.
        pct = self.batch_cycle / (self.steps_per_epoch * self.cycle_len)
        cos_out = 1 + np.cos(np.pi * pct)
        return self.end_lr + 0.5 * (self.start_lr - self.end_lr) * cos_out

    def plot_lr(self):
        plt.plot(range(len(self.history['lrs'])), self.history['lrs'])
        plt.xlabel('Iterations')
        plt.ylabel('Learning rate')
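
# Usage note (not from the original module): `steps_per_epoch` should match
# the real number of batches per epoch (roughly len(X_train) // batch_size),
# otherwise `pct` in cal_lr never reaches 1 and the anneal is cut short, or
# passes 1 and the rate starts climbing again. Sketch with hypothetical
# X_train and b_size:
#
#     sched = SGDRScheduler(start_lr=1e-1, end_lr=1e-3, lr_decay=0.9,
#                           cycle_len=1, cycle_mult=2,
#                           steps_per_epoch=len(X_train) // b_size)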


def fit_cycle(model, X_train, y_train, valid_data=None, b_size=64, n_cycles=1,
              cycle_len=1, cycle_mult=1, lr_sched=None, callbacks=None):
    """Train for `n_cycles` restart cycles, calling `fit` once per cycle and
    multiplying the epoch count by `cycle_mult` after each cycle."""
    cb_list = []
    if lr_sched is not None:
        cb_list.append(lr_sched)
    if callbacks is not None:
        cb_list.extend(callbacks)
    for _ in range(n_cycles):
        model.fit(X_train, y_train, epochs=cycle_len, batch_size=b_size,
                  validation_data=valid_data, callbacks=cb_list)
        cycle_len *= cycle_mult  # lengthen the next cycle
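

# Minimal end-to-end sketch (not part of the original module): builds a tiny
# Keras model on random data purely to show how SGDRScheduler and fit_cycle
# are wired together. The layer sizes, learning rates, batch size and the
# random data are illustrative assumptions, not recommendations.
if __name__ == '__main__':
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.optimizers import SGD

    # Hypothetical toy data: 1000 samples, 20 features, binary labels.
    X_train = np.random.rand(1000, 20)
    y_train = np.random.randint(0, 2, size=(1000, 1))

    model = Sequential([
        Dense(32, activation='relu', input_shape=(20,)),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer=SGD(lr=0.01), loss='binary_crossentropy',
                  metrics=['accuracy'])

    # One scheduler instance is reused across cycles; fit_cycle calls
    # model.fit once per cycle, so on_train_end lengthens each new cycle.
    sched = SGDRScheduler(start_lr=1e-1, end_lr=1e-3, lr_decay=0.9,
                          cycle_len=1, cycle_mult=2,
                          steps_per_epoch=len(X_train) // 64)
    fit_cycle(model, X_train, y_train, b_size=64, n_cycles=3,
              cycle_len=1, cycle_mult=2, lr_sched=sched)
    sched.plot_lr()
    plt.show()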