Commit 6b58186

[scripts] implement max-change within customized SGD optimizer

1 parent ce5f93f

2 files changed: +22 −8 lines


egs/aishell/s10/chain/sgd_mc.py renamed to egs/aishell/s10/chain/sgd_max_change.py (+15 −6)
@@ -2,8 +2,9 @@
 from torch.optim.optimizer import Optimizer, required
 
 
-class SGD_MC(Optimizer):
-    r"""Implements stochastic gradient descent (optionally with momentum).
+class SgdMaxChange(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum and max
+    change).
     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.
     Args:
@@ -14,6 +15,10 @@ class SGD_MC(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
+        max_change_per_layer (float, optional): change in parameters allowed of
+            any given layer, on any given batch, measured in l2 norm
+        max_change (float, optional): change in parameters allowed of the whole
+            model, after applying the per-layer constraint
     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> optimizer.zero_grad()
@@ -49,17 +54,21 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if max_change_per_layer < 0.01:
+            raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
+        if max_change < 0.01:
+            raise ValueError("Invalid max_change value: {}".format(max_change))
 
         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov,
                         max_change_per_layer=max_change_per_layer,
                         max_change=max_change)
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-        super(SGD_MC, self).__init__(params, defaults)
+        super(SgdMaxChange, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(SGD_MC, self).__setstate__(state)
+        super(SgdMaxChange, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
 
@@ -107,7 +116,7 @@ def step(self, closure=None):
                         d_p = buf
                 norm = d_p.norm(2).item()
                 if norm * group['lr'] > max_change_per_layer:
-                    d_p.mul_(max_change_per_layer / norm)
+                    d_p.mul_(max_change_per_layer / (norm * group['lr']))
                 delta.append(d_p)
                 total_norm += d_p.norm(2).item() ** 2.
 
@@ -118,7 +127,7 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 if total_norm * group['lr'] > max_change:
-                    p.add_(delta[i], alpha=-group['lr'] * max_change / total_norm)
+                    p.add_(delta[i], alpha=-group['lr'] * max_change / (total_norm * group['lr']))
                 else:
                     p.add_(delta[i], alpha=-group['lr'])
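The two one-line changes in step() fix a scaling bug: the change actually applied to a parameter is lr * d_p, so dividing by norm alone capped the update at lr * max_change_per_layer rather than at max_change_per_layer itself. Dividing by (norm * lr) puts the applied change exactly at the cap, and in the whole-model branch the lr factors cancel the same way, leaving an effective step size of max_change / total_norm. Below is a minimal sketch of the resulting two-level constraint, assuming plain SGD with no momentum, dampening, or weight decay; the function name and the default values 0.75 and 1.5 are illustrative, not taken from this file.

def sgd_step_with_max_change(params, lr,
                             max_change_per_layer=0.75,  # illustrative value
                             max_change=1.5):            # illustrative value
    """Sketch of the two-level max-change constraint from the diff above.

    Assumes plain SGD: no momentum, dampening, or weight decay.
    """
    delta = []
    total_norm = 0.
    for p in params:
        d_p = p.grad
        norm = d_p.norm(2).item()
        # Per-layer cap: if the applied change lr * ||d_p|| would exceed
        # max_change_per_layer, rescale so it lands exactly on the cap.
        if norm * lr > max_change_per_layer:
            d_p = d_p * (max_change_per_layer / (norm * lr))
        delta.append(d_p)
        total_norm += d_p.norm(2).item() ** 2.
    # The diff accumulates squared norms; the square root is presumably
    # taken between the two hunks shown (it is not part of this diff).
    total_norm = total_norm ** 0.5
    for p, d in zip(params, delta):
        if total_norm * lr > max_change:
            # Whole-model cap: the lr factors cancel, so the step size is
            # max_change / total_norm and the total l2 change is max_change.
            p.data.add_(d, alpha=-max_change / total_norm)
        else:
            p.data.add_(d, alpha=-lr)

After loss.backward(), this would be called as sgd_step_with_max_change(list(model.parameters()), lr=1e-2).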

egs/aishell/s10/chain/train.py (+7 −2)
@@ -34,7 +34,7 @@
 from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
-from sgd_mc import SGD_MC
+from sgd_max_change import SgdMaxChange
 
 def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     feature, supervision = batch
@@ -168,6 +168,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, crit
             dropout,
             pseudo_epoch + current_epoch * len(dataloader))
 
+        tf_writer.add_scalar(
+            'train/current_batch_change',
+            curr_batch_change,
+            pseudo_epoch + current_epoch * len(dataloader))
+
         state_dict = model.state_dict()
         for key, value in state_dict.items():
             # skip batchnorm parameters
@@ -302,7 +307,7 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     else:
         valid_dataloader = None
 
-    optimizer = SGD_MC(model.parameters(),
+    optimizer = SgdMaxChange(model.parameters(),
                        lr=learning_rate,
                        weight_decay=5e-4)
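For reference, a hypothetical call mirroring the change in process_job() above but passing the new limits explicitly; the values shown are illustrative, since this call site keeps the class defaults.

# Hypothetical call; max-change values are illustrative, not set by this commit.
optimizer = SgdMaxChange(model.parameters(),
                         lr=learning_rate,
                         weight_decay=5e-4,
                         max_change_per_layer=0.75,
                         max_change=1.5)

Note that curr_batch_change, logged to TensorBoard in train_one_epoch(), is not defined anywhere in this diff; presumably it reports the actual l2 change applied on the batch and is produced elsewhere in the training loop.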
