@@ -48,19 +48,20 @@ def update_fn(updates, state, params):
     )
     mu = otu.tree_cast(mu, mu_dtype)

+    def _scale_invariant_2d(p, u):
+        """Core scale-invariant update for a rank <= 2 parameter; returns the delta."""
+        p_norm = jnp.linalg.norm(p)
+        u_norm = jnp.linalg.norm(u)
+        new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
+        return new_p / jnp.maximum(jnp.linalg.norm(new_p), 1e-10) * p_norm - p
+
     def scale_invariant_update(p, u):
         if p is None:
             return None
-        if p.ndim == 2:
-            new_p = p - learning_rate * u * jnp.linalg.norm(p) / jnp.maximum(jnp.linalg.norm(u), 1e-10)
-            return new_p / jnp.linalg.norm(new_p) * jnp.linalg.norm(p) - p
-        else:
-            axes = tuple(range(1, p.ndim))
-            p_norm = jnp.sqrt(jnp.sum(jnp.square(p), axis=axes, keepdims=True))
-            u_norm = jnp.sqrt(jnp.sum(jnp.square(u), axis=axes, keepdims=True))
-            new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
-            new_p_norm = jnp.sqrt(jnp.sum(jnp.square(new_p), axis=axes, keepdims=True))
-            return new_p / jnp.maximum(new_p_norm, 1e-10) * p_norm - p
+        if p.ndim <= 2:
+            return _scale_invariant_2d(p, u)
+        # For higher-rank tensors, vmap the rank-2 logic over the leading axis.
+        return jax.vmap(_scale_invariant_2d)(p, u)

     adamh_updates = jax.tree_util.tree_map(
         scale_invariant_update,
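For context: the refactored helper scales the raw step by ||p|| / ||u|| and then rescales the result back to the original parameter norm, so each parameter (or, under vmap, each leading-axis slice) keeps its norm after the update. Below is a minimal standalone sketch of that property. The helper body mirrors the diff (including the 1e-10 floors); the fixed learning_rate, the random test data, and the print checks are illustrative assumptions, not part of the change.

import jax
import jax.numpy as jnp

learning_rate = 0.1  # assumed constant here; in the real code it comes from the caller

def _scale_invariant_2d(p, u):
    # Step against u, scaled by ||p|| / ||u||, then renormalize the result
    # back to the original norm of p; return the delta rather than new_p.
    p_norm = jnp.linalg.norm(p)
    u_norm = jnp.linalg.norm(u)
    new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
    return new_p / jnp.maximum(jnp.linalg.norm(new_p), 1e-10) * p_norm - p

key_p, key_u = jax.random.split(jax.random.PRNGKey(0))
p = jax.random.normal(key_p, (4, 3))
u = jax.random.normal(key_u, (4, 3))
delta = _scale_invariant_2d(p, u)
print(jnp.linalg.norm(p + delta), jnp.linalg.norm(p))  # equal up to float error

# Rank-3 case: vmap applies the same rule per leading-axis slice, so each
# slice's Frobenius norm is preserved independently.
p3 = jax.random.normal(key_p, (2, 4, 3))
u3 = jax.random.normal(key_u, (2, 4, 3))
delta3 = jax.vmap(_scale_invariant_2d)(p3, u3)
print(jnp.linalg.norm(p3 + delta3, axis=(1, 2)))
print(jnp.linalg.norm(p3, axis=(1, 2)))

Note that for a rank-3 tensor this vmap matches the removed axes-based branch, which also computed norms over all non-leading axes; the visible behavior change is that rank-1 parameters now take the full-norm path instead of the element-wise one.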