Commit b3a83b8

Prep Kron for merge, add detail to attributions note, README.
1 parent 67ef6f0 commit b3a83b8

File tree

2 files changed: +31, -8 lines

README.md

+6

```diff
@@ -12,6 +12,11 @@
 
 ## What's New
 
+## Jan 27, 2025
+* Add Kron Optimizer (PSGD w/ Kronecker-factored preconditioner)
+* Code from https://github.com/evanatyourservice/kron_torch
+* See also https://sites.google.com/site/lixilinx/home/psgd
+
 ## Jan 19, 2025
 * Fix loading of LeViT safetensor weights, remove conversion code which should have been deactivated
 * Add 'SO150M' ViT weights trained with SBB recipes, decent results, but not optimal shape for ImageNet-12k/1k pretrain/ft
@@ -461,6 +466,7 @@ Included optimizers available via `timm.optim.create_optimizer_v2` factory method
 * `adamp` and `sgdp` by [Naver ClovAI](https://github.com/clovaai) - https://arxiv.org/abs/2006.08217
 * `adan` an implementation of Adan adapted from https://github.com/sail-sg/Adan - https://arxiv.org/abs/2208.06677
 * `adopt` ADOPT adapted from https://github.com/iShohei220/adopt - https://arxiv.org/abs/2411.02853
+* `kron` PSGD w/ Kronecker-factored preconditioner from https://github.com/evanatyourservice/kron_torch - https://sites.google.com/site/lixilinx/home/psgd
 * `lamb` an implementation of Lamb and LambC (w/ trust-clipping) cleaned up and modified to support use with XLA - https://arxiv.org/abs/1904.00962
 * `laprop` optimizer from https://github.com/Z-T-WANG/LaProp-Optimizer - https://arxiv.org/abs/2002.04839
 * `lars` an implementation of LARS and LARC (w/ trust-clipping) - https://arxiv.org/abs/1708.03888
```
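Per the README entry above, the new optimizer is selectable through the `timm.optim.create_optimizer_v2` factory. A minimal usage sketch, assuming `kron` resolves in the factory as the list implies (the model choice and hyperparameters here are illustrative, not from this commit):

```python
import timm
from timm.optim import create_optimizer_v2

# Any model works; resnet18 is just a small stand-in for this example.
model = timm.create_model('resnet18', num_classes=10)

# Select the optimizer by its registered name; the factory forwards
# additional kwargs on to the optimizer constructor.
optimizer = create_optimizer_v2(model, opt='kron', lr=1e-3, weight_decay=0.0)
```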

timm/optim/kron.py

+25, -8

```diff
@@ -1,9 +1,22 @@
-""" PyTorch Implementation of the Kron PSGD optimizer
+""" PyTorch Implementation of the Kron (PSGD) optimizer
 
-FIXME attribution
-* https://github.com/evanatyourservice/kron_torch (direct source)
-* https://github.com/lixilinx/psgd_torch (original)
-* https://github.com/ClashLuke/HeavyBall (added improvements)
+This is a PSGD optimizer using a Kronecker-factored preconditioner.
+
+This impl was adapted from https://github.com/evanatyourservice/kron_torch
+by Evan Walters, licensed CC-BY-4.0.
+
+Contributions to above also made by
+* Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
+* Omead Pooladzandi https://github.com/opooladz
+
+The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
+
+This `timm` impl
+* works with a wider variety of torch versions
+* fixes some checkpoint save/restore (resume issues)
+* adds decoupled weight-decay option
+* has some refactoring, cleanup of args, default/group items
+* warning about not having opt_einsum (unusable without)
 
 """
 import logging
@@ -30,6 +43,8 @@
 except AttributeError:
     has_dynamo = False
 
+from ._types import ParamsT
+
 _logger = logging.getLogger(__name__)
 
 
@@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
 
     def __init__(
             self,
-            params,
+            params: ParamsT,
             lr: float = 0.001,
             momentum: float = 0.9,
             weight_decay: float = 0.0,
@@ -94,6 +109,8 @@ def __init__(
             min_ndim_triangular: int = 2,
             memory_save_mode: Optional[str] = None,
             momentum_into_precond_update: bool = True,
+            precond_lr: float = 0.1,
+            precond_init_scale: float = 1.0,
             mu_dtype: Optional[torch.dtype] = None,
             precond_dtype: Optional[torch.dtype] = None,
             decoupled_decay: bool = False,
@@ -119,8 +136,8 @@ def __init__(
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
             momentum_into_precond_update=momentum_into_precond_update,
-            precond_lr=0.1,  # precond lr hardcoded to 0.1
-            precond_init_scale=1.0,  # precond init scale hardcoded to 1.0
+            precond_lr=precond_lr,
+            precond_init_scale=precond_init_scale,
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             decoupled_decay=decoupled_decay,
```
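With this change, `precond_lr` and `precond_init_scale`, previously hardcoded to 0.1 and 1.0, become tunable constructor arguments. A short sketch of direct instantiation using only parameters visible in the signature above; the import path follows the file location shown, and the single training step is a generic placeholder, not code from this commit:

```python
import torch
from timm.optim.kron import Kron

model = torch.nn.Linear(128, 10)

optimizer = Kron(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=1e-4,
    decoupled_decay=True,    # the decoupled weight-decay option noted in the docstring
    precond_lr=0.1,          # previously hardcoded to 0.1
    precond_init_scale=1.0,  # previously hardcoded to 1.0
)

# Generic single step; per the docstring, the optimizer warns it is
# unusable without opt_einsum installed.
x, y = torch.randn(16, 128), torch.randint(0, 10, (16,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```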
