Default v5e-8 mesh not efficient? #33518
Unanswered
sssshhhhhh asked this question in General
Replies: 0 comments
On a v5e-8 (2x4 topology) I get these device coords, which are already sequential, so the mesh ids are directly
jax/jax/_src/mesh_utils.py, line 72 in b08a105
This made me ask why devices 1 (1,0) and 2 (0,1), or 3 (1,1) and 7 (1,3), are adjacent in the mesh. So I implemented a bidirectional all-gather (basically the one in this animation: https://jax-ml.github.io/scaling-book/assets/img/all-gather.gif).
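A minimal pure-Python sketch of the adjacency check. It assumes device id i sits at physical coords (i % 2, i // 2), which matches every coord listed above, and that the mesh is the sequential ids reshaped to (2, 4). `coords` and `mesh_neighbor_distances` are hypothetical helpers, and distances ignore any wraparound links.

```python
def coords(dev_id):
    # Assumed mapping inferred from the coords above: x alternates, y increments every 2 ids.
    return (dev_id % 2, dev_id // 2)


def manhattan(a, b):
    # Number of grid hops between two physical coords (no wraparound).
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def mesh_neighbor_distances(mesh):
    """Return {(a, b): hops} for mesh neighbors along both axes of a 2-D list-of-lists mesh."""
    out = {}
    rows, cols = len(mesh), len(mesh[0])
    for r in range(rows):
        for c in range(cols):
            if c + 1 < cols:  # neighbor along axis 1
                a, b = mesh[r][c], mesh[r][c + 1]
                out[(a, b)] = manhattan(coords(a), coords(b))
            if r + 1 < rows:  # neighbor along axis 0
                a, b = mesh[r][c], mesh[r + 1][c]
                out[(a, b)] = manhattan(coords(a), coords(b))
    return out


# Assumed mesh: sequential ids reshaped to (2, 4).
sequential = [list(range(0, 4)), list(range(4, 8))]
for pair, hops in sorted(mesh_neighbor_distances(sequential).items()):
    print(pair, hops)
```

Under these assumptions, the pairs (1, 2) and (5, 6), and every pair along axis 0 such as (3, 7), come out 2 hops apart, so those mesh neighbors are not direct physical neighbors.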
This takes double the roofline time, but if I change the mesh to this, it's a lot faster:
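A crude pure-Python model of the bidirectional ring all-gather, not a JAX implementation. It assumes the same hypothetical coord mapping (i % 2, i // 2), models the cost of each step as the largest hop distance between ring neighbors (ignoring link contention and latency), and compares the first row of the sequential (2, 4) mesh with a hypothetical reordering [0, 1, 3, 2], which is not necessarily the mesh from the question.

```python
def coords(dev_id):
    # Assumed mapping inferred from the question's coords: (x, y) = (id % 2, id // 2).
    return (dev_id % 2, dev_id // 2)


def hops(a, b):
    # Manhattan distance on the 2x4 grid, ignoring any wraparound links.
    (ax, ay), (bx, by) = coords(a), coords(b)
    return abs(ax - bx) + abs(ay - by)


def bidirectional_all_gather(ring):
    """Simulate a bidirectional ring all-gather over `ring` (a list of device ids).

    Returns (gathered device ids per position, number of steps, model cost).
    At step s, position i receives the shard from position i - s (flow in one
    direction) and from position i + s (flow in the other direction), so every
    position holds all n shards after n // 2 steps.
    Model cost = steps * worst ring-neighbor hop distance (crude assumption).
    """
    n = len(ring)
    have = [{i} for i in range(n)]  # ring positions whose shard each position holds
    steps = n // 2
    for s in range(1, steps + 1):
        for i in range(n):
            have[i].add((i - s) % n)  # shard arriving from one direction
            have[i].add((i + s) % n)  # shard arriving from the other direction
    worst = max(hops(ring[i], ring[(i + 1) % n]) for i in range(n))
    gathered = [sorted(ring[j] for j in h) for h in have]
    return gathered, steps, steps * worst


sequential_row = [0, 1, 2, 3]  # first row of the sequential (2, 4) mesh
reordered_row = [0, 1, 3, 2]   # hypothetical: a ring around the 2x2 block of devices 0-3
for name, ring in [("sequential", sequential_row), ("reordered", reordered_row)]:
    gathered, steps, cost = bidirectional_all_gather(ring)
    print(name, steps, cost)
# sequential 2 4
# reordered 2 2
```

Under this model each step on the sequential row is bounded by a 2-hop link, so it costs twice as much as the reordered ring. In JAX, a reordered mesh could be built by passing a NumPy array of devices in the chosen order to `jax.sharding.Mesh`.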
Am I calculating the neighbors wrong? Is the mesh supposed to be like this?