Skip to content

Commit 699260e

Browse files
committed
fix sharding
1 parent ae8fd6b commit 699260e

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

lib/levanter/src/levanter/optim/grugmuon.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,15 @@ def update_fn(updates, state, params=None):
126126
def transform_array(x):
127127
if not hasattr(x, "ndim") or x.ndim not in (2, 3):
128128
return x
129-
if x.ndim == 3:
130-
from jax.sharding import PartitionSpec as P, reshard
129+
from jax.sharding import PartitionSpec as P, reshard
131130

131+
original_spec = jax.typeof(x).sharding.spec
132+
if x.ndim == 3:
132133
# Keep the first dim's existing sharding, replicate only the last 2 dims for Newton-Schulz
133-
current_spec = x.sharding.spec
134-
x = reshard(x, P(current_spec[0], None, None))
134+
x = reshard(x, P(original_spec[0], None, None))
135135
updated = jax.vmap(
136136
lambda m: _newtonschulz_core(m, steps=steps, eps=muon_eps, coefficient_type=coefficient_type)
137137
)(x)
138-
updated = reshard(updated, current_spec)
139138
# Layout per slice is (fan_in, fan_out)
140139
fan_in, fan_out = updated.shape[1], updated.shape[2]
141140
else:
@@ -147,6 +146,7 @@ def transform_array(x):
147146
else:
148147
scale = 0.2 * jnp.sqrt(jnp.maximum(fan_in, fan_out))
149148
updated *= scale
149+
updated = reshard(updated, original_spec)
150150
return updated
151151

152152
updates = jax.tree.map(transform_array, updates)

0 commit comments

Comments
 (0)