@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
@@ -118,6 +120,7 @@ def __init__(
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ def __init__(
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
@@ -353,11 +357,15 @@ def step(self, closure=None):
                     pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)
 
                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
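
For reference, a minimal usage sketch of the new flag, not part of the diff itself: the import path and toy model are assumptions made for illustration, while lr, weight_decay, decoupled_decay, and stochastic_weight_decay come from the signature above. With the flag enabled, each step rescales the decay by 2 * self.rng.random(), a draw from U(0, 2), so the decay applied on any given step is randomized but its expected value stays equal to the configured weight_decay.

# Sketch only: assumes the optimizer is importable from timm.optim.kron; adjust the path as needed.
import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(10, 10)
optimizer = Kron(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    decoupled_decay=True,          # AdamW-style decay applied directly to the weights
    stochastic_weight_decay=True,  # decay scaled by 2 * U(0, 1) each step, E[scale] = 1
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()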