 from torch.optim.optimizer import Optimizer, required


-class SGD_MC(Optimizer):
-    r"""Implements stochastic gradient descent (optionally with momentum).
+class SgdMaxChange(Optimizer):
+    r"""Implements stochastic gradient descent (optionally with momentum and max
+    change).

     Nesterov momentum is based on the formula from
     `On the importance of initialization and momentum in deep learning`__.

     Args:
@@ -14,6 +15,10 @@ class SGD_MC(Optimizer):
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         dampening (float, optional): dampening for momentum (default: 0)
         nesterov (bool, optional): enables Nesterov momentum (default: False)
+        max_change_per_layer (float, optional): maximum change in parameters
+            allowed for any given layer, on any given batch, measured in l2 norm
+        max_change (float, optional): maximum change in parameters allowed for
+            the whole model, after applying the per-layer constraint

     Example:
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> optimizer.zero_grad()
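
The two new arguments are easiest to see in use. Below is a minimal usage sketch, assuming SgdMaxChange is importable from the module this diff edits and that max_change_per_layer and max_change are ordinary keyword arguments of the constructor shown in the next hunk; the concrete values are illustrative, not defaults taken from this diff.

import torch

model = torch.nn.Linear(10, 2)
# Hypothetical values; only the >= 0.01 validation below is visible in the diff.
optimizer = SgdMaxChange(model.parameters(), lr=0.1, momentum=0.9,
                         max_change_per_layer=0.75, max_change=1.5)

loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # each layer's update, and the whole-model update, is l2-clipped
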
@@ -49,17 +54,21 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if max_change_per_layer < 0.01:
+            raise ValueError("Invalid max_change_per_layer value: {}".format(max_change_per_layer))
+        if max_change < 0.01:
+            raise ValueError("Invalid max_change value: {}".format(max_change))

         defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                         weight_decay=weight_decay, nesterov=nesterov,
                         max_change_per_layer=max_change_per_layer,
                         max_change=max_change)
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
-        super(SGD_MC, self).__init__(params, defaults)
+        super(SgdMaxChange, self).__init__(params, defaults)

     def __setstate__(self, state):
-        super(SGD_MC, self).__setstate__(state)
+        super(SgdMaxChange, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('nesterov', False)
@@ -107,7 +116,7 @@ def step(self, closure=None):
                         d_p = buf

                 norm = d_p.norm(2).item()
                 if norm * group['lr'] > max_change_per_layer:
-                    d_p.mul_(max_change_per_layer / norm)
+                    d_p.mul_(max_change_per_layer / (norm * group['lr']))
                 delta.append(d_p)
                 total_norm += d_p.norm(2).item() ** 2.
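
The fix in this hunk changes what gets bounded: the change actually applied to a layer is lr * ||d_p||, so scaling d_p by max_change_per_layer / norm only bounds ||d_p|| itself, while dividing by norm * lr bounds the applied change. A small numeric sketch of the intended invariant, with illustrative values:

# Check the per-layer constraint implied by the corrected scaling.
lr = 0.1
max_change_per_layer = 0.5
norm = 20.0                                   # l2 norm of d_p for one layer

if norm * lr > max_change_per_layer:          # 2.0 > 0.5, so the update is clipped
    scale = max_change_per_layer / (norm * lr)
    applied_change = lr * (norm * scale)      # l2 norm of the applied update
    assert abs(applied_change - max_change_per_layer) < 1e-12
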
@@ -118,7 +127,7 @@ def step(self, closure=None):
                 if p.grad is None:
                     continue
                 if total_norm * group['lr'] > max_change:
-                    p.add_(delta[i], alpha=-group['lr'] * max_change / total_norm)
+                    p.add_(delta[i], alpha=-group['lr'] * max_change / (total_norm * group['lr']))
                 else:
                     p.add_(delta[i], alpha=-group['lr'])
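
The same reasoning applies at the whole-model level in this hunk. Assuming total_norm has been reduced to an l2 norm over all per-layer deltas (the square root over the accumulated squares is not shown in this diff), the two lr factors in the corrected alpha cancel, leaving alpha = -max_change / total_norm, so the combined applied change is exactly max_change when clipping kicks in. Another illustrative check, with hypothetical values:

# Check the whole-model constraint implied by the corrected alpha.
lr = 0.1
max_change = 1.0
total_norm = 25.0                              # combined l2 norm of all deltas

if total_norm * lr > max_change:               # 2.5 > 1.0, so the update is clipped
    alpha = -lr * max_change / (total_norm * lr)   # the lr factors cancel
    applied_change = abs(alpha) * total_norm       # == max_change
    assert abs(applied_change - max_change) < 1e-12
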