Commit d2bef3c

fix distributed all_to_sharded bias shard axis from -2 to -1 (#2987)
1 parent 3fe7794 commit d2bef3c

1 file changed: +2 -0 lines changed

python/mlx/nn/layers/distributed.py

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ def _all_to_sharded(segments):
     representation becomes a sharded representation."""
 
     def _shard_fn(path, weight):
+        if path.endswith("bias"):
+            return -1, segments
         return max(weight.ndim - 2, 0), segments
 
     return _shard_fn
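
To illustrate the effect of the change, here is a minimal sketch that compares the shard axis chosen before and after this commit. The two function bodies are copied from the diff above; everything else is an assumption made for illustration: `FakeArray` is a hypothetical stand-in that only exposes `.ndim`, the `_old` / `_new` suffixes and the parameter paths are invented, and `segments=1` is a placeholder value.

```python
from dataclasses import dataclass


@dataclass
class FakeArray:
    """Hypothetical stand-in for an array; only .ndim is needed here."""
    ndim: int


def all_to_sharded_old(segments):
    # Behaviour before the commit: every parameter uses the same rule.
    def _shard_fn(path, weight):
        return max(weight.ndim - 2, 0), segments

    return _shard_fn


def all_to_sharded_new(segments):
    # Behaviour after the commit: biases are sharded along the last axis.
    def _shard_fn(path, weight):
        if path.endswith("bias"):
            return -1, segments
        return max(weight.ndim - 2, 0), segments

    return _shard_fn


old_fn = all_to_sharded_old(1)  # segments=1 is a placeholder
new_fn = all_to_sharded_new(1)

weight_2d = FakeArray(ndim=2)
bias_1d = FakeArray(ndim=1)
bias_2d = FakeArray(ndim=2)  # e.g. a bias with a leading dimension

# Axis chosen for each parameter (first element of the returned tuple).
old_axes = {
    "weight": old_fn("layer.weight", weight_2d)[0],
    "bias_1d": old_fn("layer.bias", bias_1d)[0],
    "bias_2d": old_fn("layer.bias", bias_2d)[0],
}
new_axes = {
    "weight": new_fn("layer.weight", weight_2d)[0],
    "bias_1d": new_fn("layer.bias", bias_1d)[0],
    "bias_2d": new_fn("layer.bias", bias_2d)[0],
}
print(old_axes)
print(new_axes)
```

Under these assumptions, weights are unaffected. For a 1-D bias the old axis 0 and the new axis -1 refer to the same dimension, so the difference only shows up for a bias with more than one dimension, where the old rule picked axis `ndim - 2` (the second-to-last axis) and the new rule picks the last axis.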
