Commit f1deec3

Fix comment, add 'stochastic weight decay' idea because why not
1 parent d0f28d5 commit f1deec3

File tree

1 file changed: +12 -4 lines changed

timm/optim/kron.py (+12 -4)
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
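Aside (not part of the diff): a minimal sketch of what the documented flatten range means for a higher-rank parameter, using torch.flatten with the documented defaults; the conv-like shape here is hypothetical.

    import torch

    # Hypothetical rank-4 conv weight; merging dims 2..-1 leaves a rank-3 tensor,
    # which is the kind of shape the flatten option hands to the preconditioner.
    w = torch.randn(64, 32, 3, 3)
    flat = w.flatten(start_dim=2, end_dim=-1)   # documented defaults: (2, -1)
    print(flat.shape)                           # torch.Size([64, 32, 9])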

@@ -118,6 +120,7 @@ def __init__(
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ def __init__(
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
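Aside (not part of the diff): a minimal usage sketch with a hypothetical model, showing how the new flag would be passed alongside the options wired through above; note that Kron also requires opt_einsum to be installed.

    import torch
    from timm.optim.kron import Kron

    model = torch.nn.Linear(16, 4)
    optimizer = Kron(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-2,
        decoupled_decay=True,            # AdamW style decoupled weight decay
        stochastic_weight_decay=True,    # new flag from this commit
    )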

@@ -353,11 +357,15 @@ def step(self, closure=None):
                 pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)

                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
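The hunk above is the core of the 'stochastic weight decay' idea: when the flag is set, the nominal decay is rescaled each step by 2 * U(0, 1), so the applied value varies in [0, 2 * wd) while its expectation stays at wd. A standalone sketch of that modulation (using Python's random module rather than the optimizer's own RNG):

    import random

    rng = random.Random(0)
    wd = 1e-2
    samples = [2 * rng.random() * wd for _ in range(100_000)]
    print(min(samples), max(samples))    # roughly 0 and just under 2 * wd
    print(sum(samples) / len(samples))   # close to wd on average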
