Commit 5ccb456

[scripts] implement max-change within customized SGD optimizer
1 parent d211d33 commit 5ccb456

2 files changed (+151, -9 lines)

New file (+136 lines):
1+
import torch
2+
from torch.optim.optimizer import Optimizer, required
3+
4+
5+
class SgdMaxChange(Optimizer):
6+
r"""Implements stochastic gradient descent (optionally with momentum and max
7+
change).
8+
Nesterov momentum is based on the formula from
9+
`On the importance of initialization and momentum in deep learning`__.
10+
Args:
11+
params (iterable): iterable of parameters to optimize or dicts defining
12+
parameter groups
13+
lr (float): learning rate
14+
momentum (float, optional): momentum factor (default: 0)
15+
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
16+
dampening (float, optional): dampening for momentum (default: 0)
17+
nesterov (bool, optional): enables Nesterov momentum (default: False)
18+
max_change_per_layer (float, optional): change in parameters allowed of
19+
any given layer, on any given batch, measured in l2 norm
20+
max_change (float, optional): change in parameters allowed of the whole
21+
model, after applying the per-layer constraint
22+
Example:
23+
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
24+
>>> optimizer.zero_grad()
25+
>>> loss_fn(model(input), target).backward()
26+
>>> optimizer.step()
27+
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
28+
.. note::
29+
The implementation of SGD with Momentum/Nesterov subtly differs from
30+
Sutskever et. al. and implementations in some other frameworks.
31+
Considering the specific case of Momentum, the update can be written as
32+
.. math::
33+
\begin{aligned}
34+
v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
35+
p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
36+
\end{aligned}
37+
where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
38+
parameters, gradient, velocity, and momentum respectively.
39+
This is in contrast to Sutskever et. al. and
40+
other frameworks which employ an update of the form
41+
.. math::
42+
\begin{aligned}
43+
v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
44+
p_{t+1} & = p_{t} - v_{t+1}.
45+
\end{aligned}
46+
The Nesterov version is analogously modified.
47+
"""

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, max_change_per_layer=0.75,
                 max_change=1.5):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay))
        if max_change_per_layer < 0.01:
            raise ValueError("Invalid max_change_per_layer value: {}".format(
                max_change_per_layer))
        if max_change < 0.01:
            raise ValueError("Invalid max_change value: {}".format(max_change))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        max_change_per_layer=max_change_per_layer,
                        max_change=max_change)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SgdMaxChange, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SgdMaxChange, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            A tuple (loss, change), where change is the total l2 norm of the
            parameter update, accumulated over parameter groups.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        change = 0

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            max_change_per_layer = group['max_change_per_layer']
            max_change = group['max_change']

            delta = []  # one update per parameter that has a gradient
            total_norm = 0

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(
                            d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # Per-layer max-change: if this parameter's step (lr * norm)
                # would exceed max_change_per_layer in l2 norm, scale it down.
                norm = d_p.norm(2).item()
                if norm * group['lr'] > max_change_per_layer:
                    d_p.mul_(max_change_per_layer / (norm * group['lr']))
                delta.append(d_p)
                total_norm += d_p.norm(2).item() ** 2.

            total_norm = total_norm ** 0.5

            # Global max-change: if the combined step (after the per-layer
            # constraint) would exceed max_change in l2 norm, scale the whole
            # step down.  ``delta`` only holds entries for parameters that had
            # a gradient, so walk it with its own index.
            i = 0
            for p in group['params']:
                if p.grad is None:
                    continue
                if total_norm * group['lr'] > max_change:
                    p.add_(delta[i], alpha=-group['lr'] * max_change /
                           (total_norm * group['lr']))
                else:
                    p.add_(delta[i], alpha=-group['lr'])
                i += 1

            # Note: this is the update norm before the global rescaling.
            change += total_norm * group['lr']

        return loss, change
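
For orientation, here is a minimal usage sketch of the optimizer above, outside the Kaldi recipe. The toy model, data, and loss below are made-up placeholders (not part of this commit); the module name sgd_max_change is taken from the import added to train.py further down.

import torch

from sgd_max_change import SgdMaxChange

# Hypothetical toy setup, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(10, 32),
                            torch.nn.ReLU(),
                            torch.nn.Linear(32, 2))
x = torch.randn(8, 10)
target = torch.randint(0, 2, (8,))
criterion = torch.nn.CrossEntropyLoss()

optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9,
                         weight_decay=5e-4,
                         max_change_per_layer=0.75, max_change=1.5)

optimizer.zero_grad()
loss = criterion(model(x), target)
loss.backward()
# Unlike torch.optim.SGD, step() returns a (loss, change) tuple; change is
# the total l2 norm of the update, accumulated over parameter groups
# (computed before the global max_change rescaling).
_, change = optimizer.step()
print('parameter change this step: {:.4f}'.format(change))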

egs/aishell/s10/chain/train.py

+15 -9 lines changed
@@ -34,6 +34,7 @@
 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
+from sgd_max_change import SgdMaxChange

 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -67,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
         supervision, nnet_output,
         xent_output)
     objf = objf_l2_term_weight[0]
+    change = 0
     if training:
         optimizer.zero_grad()
         objf.backward()
-        clip_grad_value_(model.parameters(), 5.0)
-        optimizer.step()
+        # clip_grad_value_(model.parameters(), 5.0)
+        _, change = optimizer.step()

     objf_l2_term_weight = objf_l2_term_weight.detach().cpu()

     total_objf = objf_l2_term_weight[0].item()
     total_weight = objf_l2_term_weight[2].item()
     total_frames = nnet_output.shape[0]

-    return total_objf, total_weight, total_frames
-
+    return total_objf, total_weight, total_frames, change

 def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     total_objf = 0.
@@ -90,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     model.eval()

     for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
-        objf, weight, frames = get_objf(
+        objf, weight, frames, _ = get_objf(
             batch, model, device, criterion, opts, den_graph, False)
         total_objf += objf
         total_weight += weight
@@ -116,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             len(dataloader)) / (len(dataloader) * num_epochs)
         _, dropout = _get_dropout_proportions(
             dropout_schedule, data_fraction)[0]
-        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
+        curr_batch_objf, curr_batch_weight, curr_batch_frames, curr_batch_change = get_objf(
            batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)

        total_objf += curr_batch_objf
@@ -127,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
            logging.info(
                'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
                'global average objf: {:.6f} over {} '
-               'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
+               'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
                .format(
                    device.index, batch_idx, pseudo_epoch, len(dataloader),
                    float(pseudo_epoch) / len(dataloader) * 100,
                    total_objf / total_weight, total_frames,
                    curr_batch_objf / curr_batch_weight,
-                   curr_batch_frames, current_epoch))
+                   curr_batch_frames, curr_batch_change, current_epoch))

        if valid_dataloader and batch_idx % 1000 == 0:
            total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -167,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
                dropout,
                pseudo_epoch + current_epoch * len(dataloader))

+           tf_writer.add_scalar(
+               'train/current_batch_change',
+               curr_batch_change,
+               pseudo_epoch + current_epoch * len(dataloader))
+
            state_dict = model.state_dict()
            for key, value in state_dict.items():
                # skip batchnorm parameters
@@ -301,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
    else:
        valid_dataloader = None

-   optimizer = optim.Adam(model.parameters(),
+   optimizer = SgdMaxChange(model.parameters(),
                           lr=learning_rate,
                           weight_decay=5e-4)
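
The curr_batch_change value logged above is the change returned by SgdMaxChange.step(). As a quick illustration of how the per-layer and global thresholds interact, here is a short sketch with made-up layer norms (not taken from the recipe), ignoring momentum and weight decay:

# Illustration with hypothetical numbers, mirroring the logic in step().
lr = 0.1
max_change_per_layer = 0.75
max_change = 1.5

grad_norms = [20.0, 5.0, 3.0]  # pretend l2 norms of three layers' gradients

# Per-layer constraint: each layer's step (lr * norm) is capped at 0.75.
per_layer_steps = [min(lr * n, max_change_per_layer) for n in grad_norms]
print(per_layer_steps)         # [0.75, 0.5, 0.3]

# Global constraint: the combined step is capped at 1.5 in l2 norm.
total_step = sum(s ** 2 for s in per_layer_steps) ** 0.5
scale = min(1.0, max_change / total_step)
print(round(total_step, 4), scale)  # 0.95 1.0 -> no global rescaling needed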
