Skip to content

Commit 1a5e305

Browse files
Add `min_ndim_triangular` argument to better catch bias and scale params
1 parent 41c2611 commit 1a5e305

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

psgd_jax/kron.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ def scale_by_kron(
4343
float, Callable[[int], float]
4444
] = precond_update_prob_schedule(),
4545
max_size_triangular: int = 8192,
46-
max_skew_triangular: int = 10,
46+
max_skew_triangular: int = float('inf'),
47+
min_ndim_triangular: int = 2,
4748
mu_dtype: Optional[Union[str, jnp.dtype]] = None,
4849
precond_dtype: Optional[Union[str, jnp.dtype]] = None,
4950
precision: str = "tensorfloat32",
@@ -60,6 +61,8 @@ def scale_by_kron(
6061
preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps.
6162
max_size_triangular: int, max size for dim's preconditioner to be triangular.
6263
max_skew_triangular: int, max skew for dim's preconditioner to be triangular.
64+
min_ndim_triangular: int, minimum number of dimensions a layer needs to have
65+
triangular preconditioners.
6366
mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
6467
Defaults to the same dtype as the parameters.
6568
precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -114,6 +117,7 @@ def init_fn(params):
114117
precond_init_scale,
115118
max_size_triangular,
116119
max_skew_triangular,
120+
min_ndim_triangular,
117121
precond_dtype,
118122
)[0]
119123
for t, s in zip(jax.tree.leaves(params), jax.tree.leaves(scanned_layers_))
@@ -193,6 +197,7 @@ def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
193197
precond_init_scale,
194198
max_size_triangular,
195199
max_skew_triangular,
200+
min_ndim_triangular,
196201
precond_dtype,
197202
existing_Q=jax.tree.map(lambda d: d[0], Q) if s else Q,
198203
)
@@ -317,7 +322,8 @@ def kron(
317322
float, Callable[[int], float]
318323
] = precond_update_prob_schedule(),
319324
max_size_triangular: int = 8192,
320-
max_skew_triangular: int = 10,
325+
max_skew_triangular: int = float('inf'),
326+
min_ndim_triangular: int = 2,
321327
mu_dtype: Optional[Union[str, jnp.dtype]] = None,
322328
precond_dtype: Optional[Union[str, jnp.dtype]] = None,
323329
precision: str = "tensorfloat32",
@@ -337,6 +343,8 @@ def kron(
337343
preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps.
338344
max_size_triangular: int, max size for dim's preconditioner to be triangular.
339345
max_skew_triangular: int, max skew for dim's preconditioner to be triangular.
346+
min_ndim_triangular: int, minimum number of dimensions a layer needs to have
347+
triangular preconditioners.
340348
mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
341349
Defaults to the same dtype as the parameters.
342350
precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -357,6 +365,7 @@ def kron(
357365
b1=b1,
358366
max_size_triangular=max_size_triangular,
359367
max_skew_triangular=max_skew_triangular,
368+
min_ndim_triangular=min_ndim_triangular,
360369
mu_dtype=mu_dtype,
361370
precond_dtype=precond_dtype,
362371
precision=precision,
@@ -424,7 +433,7 @@ def no_calc(_):
424433
return jax.lax.cond(max_abs > 0, calc, no_calc, A)
425434

426435

427-
def _init_Q_exprs(t, scale, max_size, max_skew, dtype, existing_Q=None):
436+
def _init_Q_exprs(t, scale, max_size, max_skew, min_ndim_triangular, dtype, existing_Q=None):
428437
"""
429438
For a scalar or tensor `t`, we initialize its preconditioner `Q` and reusable
430439
contraction expressions for updating `Q` and preconditioning gradient.
@@ -476,7 +485,12 @@ def _init_Q_exprs(t, scale, max_size, max_skew, dtype, existing_Q=None):
476485
# used for getting the subscripts for exprP
477486
piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
478487
for i, size in enumerate(shape):
479-
if size == 1 or size > max_size or size > max_skew * beta_size:
488+
if (
489+
size == 1
490+
or size > max_size
491+
or size > max_skew * beta_size
492+
or len(shape) < min_ndim_triangular
493+
):
480494
# use diagonal matrix as preconditioner for this dim
481495
if existing_Q is None:
482496
Q.append(scale * jnp.ones(size, dtype=dtype))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "psgd-jax"
7-
version = "0.2.0"
7+
version = "0.2.1"
88
description = "An implementation of PSGD optimizer in JAX."
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

0 commit comments

Comments
 (0)