@@ -43,7 +43,8 @@ def scale_by_kron(
         float, Callable[[int], float]
     ] = precond_update_prob_schedule(),
     max_size_triangular: int = 8192,
-    max_skew_triangular: int = 10,
+    max_skew_triangular: int = float('inf'),
+    min_ndim_triangular: int = 2,
     mu_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_dtype: Optional[Union[str, jnp.dtype]] = None,
     precision: str = "tensorfloat32",
@@ -60,6 +61,8 @@ def scale_by_kron(
             preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: int, max size for dim's preconditioner to be triangular.
         max_skew_triangular: int, max skew for dim's preconditioner to be triangular.
+        min_ndim_triangular: int, minimum number of dimensions a layer needs to have
+            triangular preconditioners.
         mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
             Defaults to the same dtype as the parameters.
         precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -114,6 +117,7 @@ def init_fn(params):
                 precond_init_scale,
                 max_size_triangular,
                 max_skew_triangular,
+                min_ndim_triangular,
                 precond_dtype,
             )[0]
             for t, s in zip(jax.tree.leaves(params), jax.tree.leaves(scanned_layers_))
@@ -193,6 +197,7 @@ def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
                 precond_init_scale,
                 max_size_triangular,
                 max_skew_triangular,
+                min_ndim_triangular,
                 precond_dtype,
                 existing_Q=jax.tree.map(lambda d: d[0], Q) if s else Q,
             )
@@ -317,7 +322,8 @@ def kron(
         float, Callable[[int], float]
     ] = precond_update_prob_schedule(),
     max_size_triangular: int = 8192,
-    max_skew_triangular: int = 10,
+    max_skew_triangular: int = float('inf'),
+    min_ndim_triangular: int = 2,
     mu_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_dtype: Optional[Union[str, jnp.dtype]] = None,
     precision: str = "tensorfloat32",
@@ -337,6 +343,8 @@ def kron(
             preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps.
         max_size_triangular: int, max size for dim's preconditioner to be triangular.
         max_skew_triangular: int, max skew for dim's preconditioner to be triangular.
+        min_ndim_triangular: int, minimum number of dimensions a layer needs to have
+            triangular preconditioners.
         mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
             Defaults to the same dtype as the parameters.
         precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -357,6 +365,7 @@ def kron(
         b1=b1,
         max_size_triangular=max_size_triangular,
         max_skew_triangular=max_skew_triangular,
+        min_ndim_triangular=min_ndim_triangular,
         mu_dtype=mu_dtype,
         precond_dtype=precond_dtype,
         precision=precision,
@@ -424,7 +433,7 @@ def no_calc(_):
     return jax.lax.cond(max_abs > 0, calc, no_calc, A)


-def _init_Q_exprs(t, scale, max_size, max_skew, dtype, existing_Q=None):
+def _init_Q_exprs(t, scale, max_size, max_skew, min_ndim_triangular, dtype, existing_Q=None):
    """
    For a scalar or tensor `t`, we initialize its preconditioner `Q` and reusable
    contraction expressions for updating `Q` and preconditioning gradient.
@@ -476,7 +485,12 @@ def _init_Q_exprs(t, scale, max_size, max_skew, dtype, existing_Q=None):
        # used for getting the subscripts for exprP
        piece1P, piece2P, piece3P, piece4P = ([], [], "", "")
        for i, size in enumerate(shape):
-            if size == 1 or size > max_size or size > max_skew * beta_size:
+            if (
+                size == 1
+                or size > max_size
+                or size > max_skew * beta_size
+                or len(shape) < min_ndim_triangular
+            ):
                # use diagonal matrix as preconditioner for this dim
                if existing_Q is None:
                    Q.append(scale * jnp.ones(size, dtype=dtype))