[Grug] Add assigned-token DeepEP MoE dispatch #6251
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7c6fc409f4
"ring",  # Expert-parallel all-gather + psum-scatter backend.
"ragged_all_to_all",  # Expert-parallel ragged all-to-all backend.
"deepep",  # Expert-parallel DeepEP intranode dispatch/combine backend.
"assigned_token",  # Expert-parallel plain-XLA assigned-token backend.
Update stale `ragged_all_to_all` contract test
When the Grug variant contract tests run, `tests/test_grug_variant_contracts.py` still does `dataclasses.replace(cfg, moe_implementation="ragged_all_to_all")` and later asserts that string is in the jaxpr. This diff removes `ragged_all_to_all` from `MoeImplementation`, so `GrugModelConfig.__post_init__` now calls `resolve_moe_implementation()` and raises for that test input before exercising the new backend. Please update that test to use the new `assigned_token` implementation (and expected jaxpr string) or otherwise keep the old alias valid.
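A minimal sketch of the failure mode and the two suggested fixes, using toy stand-ins rather than the real Grug code. `ToyGrugConfig`, `toy_resolve_moe_implementation`, the `allow_legacy` flag, and the alias mapping are all hypothetical names chosen for illustration; only the behavior of `dataclasses.replace` (it builds a new instance through `__init__`, so `__post_init__` validation runs again) is standard Python.

```python
import dataclasses
from typing import Literal, get_args

# Hypothetical stand-in for the real MoeImplementation type, with
# "ragged_all_to_all" removed as the review describes.
ToyMoeImplementation = Literal["ring", "deepep", "assigned_token"]

# Assumed alias mapping, only relevant if the old name is kept valid.
_TOY_LEGACY_ALIASES = {"ragged_all_to_all": "assigned_token"}


def toy_resolve_moe_implementation(name: str, *, allow_legacy: bool = False) -> str:
    """Validate a backend name; optionally map a legacy alias to its successor."""
    if allow_legacy and name in _TOY_LEGACY_ALIASES:
        return _TOY_LEGACY_ALIASES[name]
    if name not in get_args(ToyMoeImplementation):
        raise ValueError(f"Unknown moe_implementation: {name!r}")
    return name


@dataclasses.dataclass(frozen=True)
class ToyGrugConfig:
    moe_implementation: str = "ring"

    def __post_init__(self):
        # Validation runs on every construction, including dataclasses.replace.
        toy_resolve_moe_implementation(self.moe_implementation)


cfg = ToyGrugConfig()

# The stale test input: replace() re-runs __post_init__, which now raises.
try:
    dataclasses.replace(cfg, moe_implementation="ragged_all_to_all")
    stale_error = None
except ValueError as exc:
    stale_error = str(exc)

# Fix 1: point the test at the new backend name.
updated = dataclasses.replace(cfg, moe_implementation="assigned_token")

# Fix 2 (alternative): keep the old name valid as an alias.
aliased = toy_resolve_moe_implementation("ragged_all_to_all", allow_legacy=True)
```

Either fix lets the contract test reach the backend under test; with the first one, the jaxpr assertion would also need whatever string the new backend actually emits, which this sketch does not attempt to guess.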
# Conflicts:
#	lib/levanter/src/levanter/grug/_moe/ep_deepep.py
#	lib/levanter/src/levanter/grug/_moe/ep_ragged_all_to_all.py
Add an assigned-token Grug MoE EP backend and a DeepEP-backed CUDA path that avoids ring global activation buffers. Include focused correctness tests and an issue-shape benchmark harness; 4-GPU accelerator confirmation shows DeepEP median 1.94 ms versus ring 2.32 ms with dropped-count parity.
Fixes #6215
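For reference, the relative improvement implied by the two medians quoted in the description can be worked out as below. This is plain arithmetic on the reported numbers (1.94 ms and 2.32 ms), not a re-run of the benchmark; the variable names are illustrative.

```python
# Medians reported in the PR description for the 4-GPU confirmation run.
deepep_median_ms = 1.94
ring_median_ms = 2.32

# Speedup of DeepEP relative to ring, and the latency reduction in percent.
speedup = ring_median_ms / deepep_median_ms
reduction_pct = 100.0 * (ring_median_ms - deepep_median_ms) / ring_median_ms

print(f"speedup: {speedup:.2f}x, latency reduction: {reduction_pct:.1f}%")
```

On these figures DeepEP is roughly 1.2x faster than ring, which corresponds to a median latency reduction of about 16%.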