Skip to content

Commit d0f28d5

Browse files
committed
Change start/end args
1 parent 6ad162c commit d0f28d5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

timm/optim/kron.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def __init__(
116116
precond_dtype: Optional[torch.dtype] = None,
117117
decoupled_decay: bool = False,
118118
flatten: bool = False,
119-
flatten_start_end: Tuple[int, int] = (2, -1),
119+
flatten_start_dim: int = 2,
120+
flatten_end_dim: int = -1,
120121
deterministic: bool = False,
121122
):
122123
if not has_opt_einsum:
@@ -144,7 +145,8 @@ def __init__(
144145
precond_dtype=precond_dtype,
145146
decoupled_decay=decoupled_decay,
146147
flatten=flatten,
147-
flatten_start_end=flatten_start_end,
148+
flatten_start_dim=flatten_start_dim,
149+
flatten_end_dim=flatten_end_dim,
148150
)
149151
super(Kron, self).__init__(params, defaults)
150152

@@ -235,7 +237,7 @@ def step(self, closure=None):
235237

236238
flattened = False
237239
if group['flatten']:
238-
grad = safe_flatten(grad, *group["flatten_start_end"])
240+
grad = safe_flatten(grad, group["flatten_start_dim"], group["flatten_end_dim"])
239241
flattened = True
240242

241243
if len(state) == 0:

0 commit comments

Comments
 (0)