Skip to content

Commit 202715d

Browse files
committed
reshard back for 3d
1 parent bd72716 commit 202715d

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

lib/levanter/src/levanter/optim/grugmuon.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def transform_array(x):
135135
updated = jax.vmap(
136136
lambda m: _newtonschulz_core(m, steps=steps, eps=muon_eps, coefficient_type=coefficient_type)
137137
)(x)
138+
updated = reshard(updated, current_spec)
138139
# Layout per slice is (fan_in, fan_out)
139140
fan_in, fan_out = updated.shape[1], updated.shape[2]
140141
else:

0 commit comments

Comments
 (0)