@@ -126,16 +126,15 @@ def update_fn(updates, state, params=None):
126126 def transform_array (x ):
127127 if not hasattr (x , "ndim" ) or x .ndim not in (2 , 3 ):
128128 return x
129- if x .ndim == 3 :
130- from jax .sharding import PartitionSpec as P , reshard
129+ from jax .sharding import PartitionSpec as P , reshard
131130
131+ original_spec = jax .typeof (x ).sharding .spec
132+ if x .ndim == 3 :
132133 # Keep the first dim's existing sharding, replicate only the last 2 dims for Newton-Schulz
133- current_spec = x .sharding .spec
134- x = reshard (x , P (current_spec [0 ], None , None ))
134+ x = reshard (x , P (original_spec [0 ], None , None ))
135135 updated = jax .vmap (
136136 lambda m : _newtonschulz_core (m , steps = steps , eps = muon_eps , coefficient_type = coefficient_type )
137137 )(x )
138- updated = reshard (updated , current_spec )
139138 # Layout per slice is (fan_in, fan_out)
140139 fan_in , fan_out = updated .shape [1 ], updated .shape [2 ]
141140 else :
@@ -147,6 +146,7 @@ def transform_array(x):
147146 else :
148147 scale = 0.2 * jnp .sqrt (jnp .maximum (fan_in , fan_out ))
149148 updated *= scale
149+ updated = reshard (updated , original_spec )
150150 return updated
151151
152152 updates = jax .tree .map (transform_array , updates )
0 commit comments