Commit 6ad162c

Change flattening behaviour in Kron
1 parent 93f44d1 commit 6ad162c

File tree

1 file changed: +26 -6 lines changed

1 file changed

+26
-6
lines changed

timm/optim/kron.py

@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
        mu_dtype: Dtype of the momentum accumulator.
        precond_dtype: Dtype of the preconditioner.
        decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
        deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
    """

@@ -114,7 +115,8 @@ def __init__(
            mu_dtype: Optional[torch.dtype] = None,
            precond_dtype: Optional[torch.dtype] = None,
            decoupled_decay: bool = False,
-            flatten_dim: bool = False,
+            flatten: bool = False,
+            flatten_start_end: Tuple[int, int] = (2, -1),
            deterministic: bool = False,
    ):
        if not has_opt_einsum:

@@ -141,7 +143,8 @@ def __init__(
            mu_dtype=mu_dtype,
            precond_dtype=precond_dtype,
            decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
        )
        super(Kron, self).__init__(params, defaults)

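As a quick illustration of the renamed constructor options above (not part of this commit): a minimal, hypothetical sketch of creating Kron with the new flattening arguments, assuming opt_einsum is installed and leaving all other hyper-parameters at their defaults.

import torch
from timm.optim.kron import Kron

# Rank-4 conv weight: (out_ch, in_ch, kh, kw) = (32, 16, 3, 3)
model = torch.nn.Conv2d(16, 32, kernel_size=3)

# With the default flatten_start_end=(2, -1), dims 2..-1 of each higher-rank grad
# are merged, so the conv weight gradient is preconditioned as a (32, 16, 9) tensor.
opt = Kron(model.parameters(), flatten=True, flatten_start_end=(2, -1))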
@@ -229,8 +232,11 @@ def step(self, closure=None):

                grad = p.grad
                state = self.state[p]
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True

                if len(state) == 0:
                    state["step"] = 0

@@ -341,7 +347,7 @@ def step(self, closure=None):

                # RMS of pre_grad should be 1.0, so let's cap at 1.1
                pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
-                if group['flatten_dim']:
+                if flattened:
                    pre_grad = pre_grad.view(p.shape)

                # Apply weight decay

@@ -361,6 +367,20 @@ def step(self, closure=None):
        return loss


+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
def _init_Q_exprs(
        t,
        scale,

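For reference, a short sketch (not part of this commit) of how the new safe_flatten helper behaves with the default (2, -1) range. The helper body is copied from the hunk above; the tensor shapes are hypothetical examples.

import torch

# Copy of the helper added in this commit, reproduced for the demo.
def safe_flatten(tensor, start_dim=0, end_dim=-1):
    ndim = tensor.ndim
    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
    if ndim <= start_dim or start_dim > end_dim:
        return tensor
    return tensor.flatten(start_dim, end_dim)

conv_w = torch.zeros(32, 16, 3, 3)   # rank-4 conv weight grad
linear_w = torch.zeros(128, 64)      # rank-2 weight grad
bias = torch.zeros(128)              # rank-1 grad

print(safe_flatten(conv_w, 2, -1).shape)    # torch.Size([32, 16, 9])  trailing dims merged
print(safe_flatten(linear_w, 2, -1).shape)  # torch.Size([128, 64])    unchanged, rank < 3
print(safe_flatten(bias, 2, -1).shape)      # torch.Size([128])        unchanged

Unlike the removed flatten_dim path, which reshaped every grad to 2-D via grad.view(grad.size(0), -1), the new behaviour only merges the dims in the requested range and passes lower-rank tensors through untouched; the flattened flag then restores p.shape on the preconditioned gradient.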