
Commit ce5f93f

[scripts] implement max-change within customized SGD optimizer
1 parent d211d33

File tree

2 files changed (+137, -9 lines)

egs/aishell/s10/chain/sgd_mc.py

+127 lines (new file)

@@ -0,0 +1,127 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SGD_MC(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}
        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}
        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, max_change_per_layer=0.75, max_change=1.5):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        max_change_per_layer=max_change_per_layer,
                        max_change=max_change)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD_MC, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD_MC, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        change = 0

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            max_change_per_layer = group['max_change_per_layer']
            max_change = group['max_change']

            delta = []
            total_norm = 0

            for i in range(len(group['params'])):
                p = group['params'][i]
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                norm = d_p.norm(2).item()
                if norm * group['lr'] > max_change_per_layer:
                    d_p.mul_(max_change_per_layer / norm)
                delta.append(d_p)
                total_norm += d_p.norm(2).item() ** 2.

            total_norm = total_norm ** 0.5

            for i in range(len(group['params'])):
                p = group['params'][i]
                if p.grad is None:
                    continue
                if total_norm * group['lr'] > max_change:
                    p.add_(delta[i], alpha=-group['lr'] * max_change / total_norm)
                else:
                    p.add_(delta[i], alpha=-group['lr'])

            change += total_norm * group['lr']

        return loss, change
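
For readers skimming the diff, step() applies the max-change constraint at two levels: a layer whose update norm times the learning rate exceeds max_change_per_layer is rescaled so that its norm equals max_change_per_layer, and when the learning rate times the combined norm across all layers in the group exceeds max_change, every update is applied with the reduced step size lr * max_change / total_norm. The sketch below is not part of the commit; the helper name max_change_scales, the gradient norms, and the learning rate are made up for illustration, and it only reproduces the scaling arithmetic so the effect of the two caps can be checked by hand.

# Standalone sketch (not from the commit): mirrors the two-level max-change
# arithmetic in SGD_MC.step() for a single parameter group.
def max_change_scales(grad_norms, lr, max_change_per_layer=0.75, max_change=1.5):
    # Per-layer cap: if lr * ||d_p|| exceeds max_change_per_layer, the update
    # is rescaled so its (pre-lr) norm equals max_change_per_layer.
    capped = [max_change_per_layer if n * lr > max_change_per_layer else n
              for n in grad_norms]
    total_norm = sum(c ** 2 for c in capped) ** 0.5
    # Global cap: if lr * total_norm exceeds max_change, every layer's update
    # is applied with alpha = lr * max_change / total_norm instead of lr.
    alpha = lr * max_change / total_norm if total_norm * lr > max_change else lr
    applied = [c * alpha for c in capped]   # norms of the steps actually taken
    change = total_norm * lr                # the value step() accumulates into `change`
    return applied, change

# Made-up example: three layers with update norms 2.0, 10.0 and 40.0 at lr=0.05.
applied, change = max_change_scales([2.0, 10.0, 40.0], lr=0.05)
print(applied)  # only the 40.0-norm layer trips the per-layer cap of 0.75
print(change)   # the per-minibatch "change" that train.py logs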

egs/aishell/s10/chain/train.py

+10, -9 lines

@@ -34,6 +34,7 @@
 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
+from sgd_mc import SGD_MC
 
 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -67,20 +68,20 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
                                              supervision, nnet_output,
                                              xent_output)
     objf = objf_l2_term_weight[0]
+    change = 0
     if training:
         optimizer.zero_grad()
         objf.backward()
-        clip_grad_value_(model.parameters(), 5.0)
-        optimizer.step()
+        # clip_grad_value_(model.parameters(), 5.0)
+        _, change = optimizer.step()
 
     objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
 
     total_objf = objf_l2_term_weight[0].item()
     total_weight = objf_l2_term_weight[2].item()
     total_frames = nnet_output.shape[0]
 
-    return total_objf, total_weight, total_frames
-
+    return total_objf, total_weight, total_frames, change
 
 def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     total_objf = 0.
@@ -90,7 +91,7 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     model.eval()
 
     for batch_idx, (pseudo_epoch, batch) in enumerate(dataloader):
-        objf, weight, frames = get_objf(
+        objf, weight, frames, _ = get_objf(
             batch, model, device, criterion, opts, den_graph, False)
         total_objf += objf
         total_weight += weight
@@ -116,7 +117,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
                         len(dataloader)) / (len(dataloader) * num_epochs)
         _, dropout = _get_dropout_proportions(
             dropout_schedule, data_fraction)[0]
-        curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
+        curr_batch_objf, curr_batch_weight, curr_batch_frames, curr_batch_change = get_objf(
            batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)
 
         total_objf += curr_batch_objf
@@ -127,13 +128,13 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             logging.info(
                 'Device ({}) processing batch {}, current pseudo-epoch is {}/{}({:.6f}%), '
                 'global average objf: {:.6f} over {} '
-                'frames, current batch average objf: {:.6f} over {} frames, epoch {}'
+                'frames, current batch average objf: {:.6f} over {} frames, minibatch change: {:.6f}, epoch {}'
                 .format(
                     device.index, batch_idx, pseudo_epoch, len(dataloader),
                     float(pseudo_epoch) / len(dataloader) * 100,
                     total_objf / total_weight, total_frames,
                     curr_batch_objf / curr_batch_weight,
-                    curr_batch_frames, current_epoch))
+                    curr_batch_frames, curr_batch_change, current_epoch))
 
         if valid_dataloader and batch_idx % 1000 == 0:
             total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -301,7 +302,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     else:
         valid_dataloader = None
 
-    optimizer = optim.Adam(model.parameters(),
+    optimizer = SGD_MC(model.parameters(),
                            lr=learning_rate,
                            weight_decay=5e-4)

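Beyond the recipe, the new optimizer can be exercised in isolation. The snippet below is a minimal usage sketch, not part of the commit: the toy linear model, batch, and learning rate are hypothetical (only the weight_decay=5e-4 value mirrors train.py), but it follows how train.py now constructs SGD_MC and unpacks the (loss, change) pair returned by step().

import torch
import torch.nn as nn

from sgd_mc import SGD_MC  # the optimizer class added by this commit

# Hypothetical toy model and batch, for illustration only.
model = nn.Linear(20, 5)
x = torch.randn(8, 20)
y = torch.randint(0, 5, (8,))

optimizer = SGD_MC(model.parameters(),
                   lr=1e-3,
                   weight_decay=5e-4,
                   max_change_per_layer=0.75,
                   max_change=1.5)
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
# step() returns (loss, change); as in train.py, the loss slot is None when
# no closure is passed, so only the change term is kept and logged.
_, change = optimizer.step()
print('loss {:.4f}, minibatch change {:.6f}'.format(loss.item(), change))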