@@ -116,7 +116,8 @@ def __init__(
116
116
precond_dtype : Optional [torch .dtype ] = None ,
117
117
decoupled_decay : bool = False ,
118
118
flatten : bool = False ,
119
- flatten_start_end : Tuple [int , int ] = (2 , - 1 ),
119
+ flatten_start_dim : int = 2 ,
120
+ flatten_end_dim : int = - 1 ,
120
121
deterministic : bool = False ,
121
122
):
122
123
if not has_opt_einsum :
@@ -144,7 +145,8 @@ def __init__(
144
145
precond_dtype = precond_dtype ,
145
146
decoupled_decay = decoupled_decay ,
146
147
flatten = flatten ,
147
- flatten_start_end = flatten_start_end ,
148
+ flatten_start_dim = flatten_start_dim ,
149
+ flatten_end_dim = flatten_end_dim ,
148
150
)
149
151
super (Kron , self ).__init__ (params , defaults )
150
152
@@ -235,7 +237,7 @@ def step(self, closure=None):
235
237
236
238
flattened = False
237
239
if group ['flatten' ]:
238
- grad = safe_flatten (grad , * group ["flatten_start_end" ])
240
+ grad = safe_flatten (grad , group ["flatten_start_dim" ], group ["flatten_end_dim" ])
239
241
flattened = True
240
242
241
243
if len (state ) == 0 :
0 commit comments