1
- """ PyTorch Implementation of the Kron PSGD optimizer
1
+ """ PyTorch Implementation of the Kron (PSGD) optimizer
2
2
3
- FIXME attribution
4
- * https://github.com/evanatyourservice/kron_torch (direct source)
5
- * https://github.com/lixilinx/psgd_torch (original)
6
- * https://github.com/ClashLuke/HeavyBall (added improvements)
3
+ This is a PSGD optimizer using a Kronecker-factored preconditioner.
4
+
5
+ This impl was adapted from https://github.com/evanatyourservice/kron_torch
6
+ by Evan Walters, licensed CC-BY-4.0.
7
+
8
+ Contributions to the above were also made by
9
+ * Lucas Nestler, added to his https://github.com/ClashLuke/HeavyBall implementation.
10
+ * Omead Pooladzandi https://github.com/opooladz
11
+
12
+ The above work drew from https://github.com/lixilinx/psgd_torch by Xi-Lin Li
13
+
14
+ This `timm` impl
15
+ * works with a wider variety of torch versions
16
+ * fixes some checkpoint save/restore (resume) issues
17
+ * adds decoupled weight-decay option
18
+ * has some refactoring, cleanup of args, default/group items
19
+ * warns about not having opt_einsum (unusable without it)
7
20
8
21
"""
9
22
import logging
30
43
except AttributeError :
31
44
has_dynamo = False
32
45
46
+ from ._types import ParamsT
47
+
33
48
_logger = logging .getLogger (__name__ )
34
49
35
50
@@ -85,7 +100,7 @@ class Kron(torch.optim.Optimizer):
85
100
86
101
def __init__ (
87
102
self ,
88
- params ,
103
+ params : ParamsT ,
89
104
lr : float = 0.001 ,
90
105
momentum : float = 0.9 ,
91
106
weight_decay : float = 0.0 ,
@@ -94,6 +109,8 @@ def __init__(
94
109
min_ndim_triangular : int = 2 ,
95
110
memory_save_mode : Optional [str ] = None ,
96
111
momentum_into_precond_update : bool = True ,
112
+ precond_lr : float = 0.1 ,
113
+ precond_init_scale : float = 1.0 ,
97
114
mu_dtype : Optional [torch .dtype ] = None ,
98
115
precond_dtype : Optional [torch .dtype ] = None ,
99
116
decoupled_decay : bool = False ,
@@ -119,8 +136,8 @@ def __init__(
119
136
min_ndim_triangular = min_ndim_triangular ,
120
137
memory_save_mode = memory_save_mode ,
121
138
momentum_into_precond_update = momentum_into_precond_update ,
122
- precond_lr = 0.1 , # precond lr hardcoded to 0.1
123
- precond_init_scale = 1.0 , # precond init scale hardcoded to 1.0
139
+ precond_lr = precond_lr ,
140
+ precond_init_scale = precond_init_scale ,
124
141
mu_dtype = mu_dtype ,
125
142
precond_dtype = precond_dtype ,
126
143
decoupled_decay = decoupled_decay ,
0 commit comments