Commit 1b76bbc

github-actions[bot], Helw150, and claude committed
Refactor AdamH scale-invariant update to use vmap for higher-rank tensors
Co-authored-by: William Held <Helw150@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fa54eb7 commit 1b76bbc

1 file changed: experiments/grug/moe/adamh.py
11 additions & 10 deletions
@@ -48,19 +48,20 @@ def update_fn(updates, state, params):
         )
         mu = otu.tree_cast(mu, mu_dtype)

+        def _scale_invariant_2d(p, u):
+            """Core update for a 2-D (matrix) parameter."""
+            p_norm = jnp.linalg.norm(p)
+            u_norm = jnp.linalg.norm(u)
+            new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
+            return new_p / jnp.linalg.norm(new_p) * p_norm - p
+
         def scale_invariant_update(p, u):
             if p is None:
                 return None
-            if p.ndim == 2:
-                new_p = p - learning_rate * u * jnp.linalg.norm(p) / jnp.maximum(jnp.linalg.norm(u), 1e-10)
-                return new_p / jnp.linalg.norm(new_p) * jnp.linalg.norm(p) - p
-            else:
-                axes = tuple(range(1, p.ndim))
-                p_norm = jnp.sqrt(jnp.sum(jnp.square(p), axis=axes, keepdims=True))
-                u_norm = jnp.sqrt(jnp.sum(jnp.square(u), axis=axes, keepdims=True))
-                new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
-                new_p_norm = jnp.sqrt(jnp.sum(jnp.square(new_p), axis=axes, keepdims=True))
-                return new_p / jnp.maximum(new_p_norm, 1e-10) * p_norm - p
+            if p.ndim <= 2:
+                return _scale_invariant_2d(p, u)
+            # For higher-rank tensors, vmap the 2-D logic over the leading axis.
+            return jax.vmap(_scale_invariant_2d)(p, u)

         adamh_updates = jax.tree_util.tree_map(
             scale_invariant_update,
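
The two paths should agree numerically on a rank-3 parameter: vmapping the 2-D update over the leading axis computes the same per-slice Frobenius norms that the old keepdims reductions did. Below is a minimal standalone sketch of that check, with a stand-in learning_rate (the real value comes from the optimizer's configuration) and an old_update helper that is just the pre-refactor higher-rank branch lifted out for comparison. Note the 2-D helper drops the old 1e-10 clamp on the post-update norm, so exact agreement assumes new_p is not vanishingly small.

import jax
import jax.numpy as jnp

learning_rate = 1e-2  # stand-in value; the real rate comes from the optimizer config

def _scale_invariant_2d(p, u):
    """Core update for a 2-D (matrix) parameter, as in the new code."""
    p_norm = jnp.linalg.norm(p)
    u_norm = jnp.linalg.norm(u)
    new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
    return new_p / jnp.linalg.norm(new_p) * p_norm - p

def old_update(p, u):
    """Pre-refactor higher-rank branch: per-slice norms via keepdims reductions."""
    axes = tuple(range(1, p.ndim))
    p_norm = jnp.sqrt(jnp.sum(jnp.square(p), axis=axes, keepdims=True))
    u_norm = jnp.sqrt(jnp.sum(jnp.square(u), axis=axes, keepdims=True))
    new_p = p - learning_rate * u * p_norm / jnp.maximum(u_norm, 1e-10)
    new_p_norm = jnp.sqrt(jnp.sum(jnp.square(new_p), axis=axes, keepdims=True))
    return new_p / jnp.maximum(new_p_norm, 1e-10) * p_norm - p

key_p, key_u = jax.random.split(jax.random.PRNGKey(0))
p = jax.random.normal(key_p, (4, 8, 8))  # e.g. a stack of 4 expert matrices
u = jax.random.normal(key_u, (4, 8, 8))

# vmap applies the 2-D update independently to each leading slice, which is
# exactly what the keepdims reductions over axes (1, 2) computed before.
new = jax.vmap(_scale_invariant_2d)(p, u)
ref = old_update(p, u)
print(jnp.max(jnp.abs(new - ref)))  # ~0 up to float32 rounding

A side benefit of the refactor is that the 2-D and higher-rank cases now share one code path, so any future change to the core update only needs to be made in _scale_invariant_2d.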
