
Commit cd351bb

[scripts] implement max-change within customized SGD
1 parent d211d33 commit cd351bb

File tree

2 files changed: +138 -0 lines changed

+136
@@ -0,0 +1,136 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SgdMaxChange(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum and
    max change).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        max_change_per_layer (float, optional): maximum change allowed in the
            parameters of any given layer, on any given batch, measured in
            l2 norm (default: 0.75)
        max_change (float, optional): maximum change allowed in the parameters
            of the whole model, after applying the per-layer constraint
            (default: 1.5)

    Example:
        >>> optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, max_change_per_layer=0.75,
                 max_change=1.5):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if max_change_per_layer < 0.01:
            raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
        if max_change < 0.01:
            raise ValueError("Invalid max_change value: {}".format(max_change))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        max_change_per_layer=max_change_per_layer,
                        max_change=max_change)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SgdMaxChange, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SgdMaxChange, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        change = 0

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            max_change_per_layer = group['max_change_per_layer']
            max_change = group['max_change']

            # First pass: compute each parameter's update (with weight decay
            # and momentum applied) and clip it to the per-layer max change.
            params_with_grad = []
            delta = []
            total_norm = 0

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                norm = d_p.norm(2).item()
                # Per-layer constraint: scale the update so that
                # lr * ||d_p||_2 does not exceed max_change_per_layer.
                if norm * group['lr'] > max_change_per_layer:
                    d_p.mul_(max_change_per_layer / (norm * group['lr']))
                params_with_grad.append(p)
                delta.append(d_p)
                total_norm += d_p.norm(2).item() ** 2.

            total_norm = total_norm ** 0.5

            # Second pass: apply the updates, scaling the whole group's step
            # down if lr * total_norm would exceed max_change.  Iterating over
            # the collected (p, d_p) pairs keeps parameters whose grad is None
            # from shifting the indices.
            for p, d_p in zip(params_with_grad, delta):
                if total_norm * group['lr'] > max_change:
                    p.add_(d_p, alpha=-group['lr'] * max_change / (total_norm * group['lr']))
                else:
                    p.add_(d_p, alpha=-group['lr'])

            change += total_norm * group['lr']

        return loss, change
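
A minimal usage sketch (not part of the commit): it assumes the new optimizer file is importable as sgd_max_change, as the commented-out import in train.py below suggests, and uses a toy linear model purely for illustration. It exercises both constraints and prints the change value that step() returns alongside the loss.

import torch
import torch.nn.functional as F

from sgd_max_change import SgdMaxChange  # assumed module name, per train.py's commented import

# Toy model and data, for illustration only.
model = torch.nn.Linear(10, 2)
x = torch.randn(4, 10)
target = torch.randn(4, 2)

optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9,
                         max_change_per_layer=0.75, max_change=1.5)

optimizer.zero_grad()
loss = F.mse_loss(model(x), target)
loss.backward()

# step() returns (loss, change): loss is None here since no closure is passed;
# change is lr * total_norm of the applied update, after both constraints.
_, change = optimizer.step()
print('parameter change this step: {:.4f}'.format(change))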

egs/aishell/s10/chain/train.py

+2
@@ -34,6 +34,7 @@
 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
+#from sgd_max_change import SgdMaxChange
 
 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -301,6 +302,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     else:
         valid_dataloader = None
 
+    #optimizer = SgdMaxChange(model.parameters(),
     optimizer = optim.Adam(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)
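
The commit leaves Adam active and only adds SgdMaxChange as commented-out lines. A hypothetical completion of those lines, if one wanted to enable the new optimizer (the max-change values shown are just the optimizer's defaults):

# Hypothetical: swap the Adam construction in process_job() for SgdMaxChange,
# reusing the same lr and weight_decay; max-change values are the defaults.
optimizer = SgdMaxChange(model.parameters(),
                         lr=learning_rate,
                         weight_decay=5e-4,
                         max_change_per_layer=0.75,
                         max_change=1.5)

Note that SgdMaxChange.step() returns a (loss, change) tuple rather than a single value, so any code that uses the return value of optimizer.step() would need to unpack it accordingly.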
