@@ -94,7 +94,8 @@ class Kron(torch.optim.Optimizer):
         mu_dtype: Dtype of the momentum accumulator.
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
-        flatten_dim: Flatten dim >= 2 instead of relying on expressions
+        flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
+        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

@@ -114,7 +115,8 @@ def __init__(
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
-            flatten_dim: bool = False,
+            flatten: bool = False,
+            flatten_start_end: Tuple[int, int] = (2, -1),
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -141,7 +143,8 @@ def __init__(
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
-            flatten_dim=flatten_dim,
+            flatten=flatten,
+            flatten_start_end=flatten_start_end,
         )
         super(Kron, self).__init__(params, defaults)

@@ -229,8 +232,11 @@ def step(self, closure=None):

                 grad = p.grad
                 state = self.state[p]
-                if group['flatten_dim']:
-                    grad = grad.view(grad.size(0), -1)
+
+                flattened = False
+                if group['flatten']:
+                    grad = safe_flatten(grad, *group["flatten_start_end"])
+                    flattened = True

                 if len(state) == 0:
                     state["step"] = 0
@@ -341,7 +347,7 @@ def step(self, closure=None):

                 # RMS of pre_grad should be 1.0, so let's cap at 1.1
                 pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
-                if group['flatten_dim']:
+                if flattened:
                     pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
@@ -361,6 +367,20 @@ def step(self, closure=None):
         return loss


+def safe_flatten(tensor, start_dim=0, end_dim=-1):
+    ndim = tensor.ndim
+
+    # Convert negative end_dim to positive and clip to end
+    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
+
+    # If tensor has fewer dims than start_dim or start > end, return tensor as is
+    if ndim <= start_dim or start_dim > end_dim:
+        return tensor
+
+    # Now safe to flatten
+    return tensor.flatten(start_dim, end_dim)
+
+
 def _init_Q_exprs(
         t,
         scale,
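For reference, a minimal standalone sketch of how safe_flatten behaves with the default flatten_start_end=(2, -1): parameters of rank greater than 2 are flattened from dim 2 onward before preconditioning, while rank-1 and rank-2 parameters pass through unchanged. The tensor shapes below are illustrative only, not taken from the patch.

import torch


def safe_flatten(tensor, start_dim=0, end_dim=-1):
    ndim = tensor.ndim
    # Convert negative end_dim to positive and clip to the last dim
    end_dim = min(end_dim if end_dim >= 0 else ndim + end_dim, ndim - 1)
    # Degenerate ranges: return the tensor untouched
    if ndim <= start_dim or start_dim > end_dim:
        return tensor
    return tensor.flatten(start_dim, end_dim)


conv_w = torch.randn(64, 32, 3, 3)   # hypothetical rank-4 conv weight
linear_w = torch.randn(512, 768)     # rank-2 linear weight
bias = torch.randn(768)              # rank-1 bias

print(safe_flatten(conv_w, 2, -1).shape)    # torch.Size([64, 32, 9])
print(safe_flatten(linear_w, 2, -1).shape)  # unchanged: torch.Size([512, 768])
print(safe_flatten(bias, 2, -1).shape)      # unchanged: torch.Size([768])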