
Add production-safe local Grug MoE backends #5718

Open · dlwh wants to merge 1 commit into main from codex/moe-quack-down-backend


Conversation

@dlwh (Member)

@dlwh dlwh commented May 13, 2026

Add local Grug MoE backends for scatter, Sonic-style XLA dispatch, interleaved W13 layout, and a production-safe custom-VJP down/gather boundary. MoEExpertMlp owns expert weight initialization and W13 layout so model code can select the backend without layout plumbing.

Part of #5328

@dlwh dlwh added the agent-generated Created by automation/agent label May 13, 2026
@claude (Contributor)

claude Bot commented May 13, 2026

Claude finished @dlwh's task in 6m 22s.


Code review

Reviewed for bugs and CLAUDE.md/AGENTS.md compliance.

One issue flagged inline (the duplicated _gather_sum_reference helper; see the review comment below).

Other things considered and not flagged:

  • MoeLocalImplementation typed as Literal[...] rather than StrEnum — follows the established pattern of the pre-existing MoeImplementation in the same module, so flagging would be inconsistent with current code.
  • Boolean dispatch flags (interleaved_w13, quack_down, split_moe_w13_output(interleaved=...)) — these are localized dispatch keys, not the kind of compound state machine the AGENTS.md rule targets.
  • Optional-dep guard in grug_moe.py:131-156 only catches ModuleNotFoundError; a transitive ImportError of a different subclass (or a partial install where find_spec("cuda") succeeds but cuda.bindings.driver import fails) would bypass the friendly message and bypass the test's skip check at test_grugformer_moe.py:257-279. Low-risk in the typical install path; mentioning here for awareness rather than as a blocking issue.
  • custom_vjp residuals/cotangents arity in quack_interleaved_down_gather_sum, the interleaved/concat W13 layout symmetry, and the dispatch_positions construction in _prepare_moe_dispatch_indices_with_assignment_ids all look consistent across forward/backward.
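The optional-dependency concern in the third bullet can be addressed by catching ImportError broadly rather than only ModuleNotFoundError. A minimal sketch, assuming a generic loader (the function name and error messages here are illustrative, not the actual code in grug_moe.py):

```python
import importlib
import importlib.util


def load_optional_backend(module_name: str):
    """Import an optional kernel backend, failing with a friendly message.

    Catching ImportError (not just its ModuleNotFoundError subclass) also
    covers partial installs where find_spec() succeeds but a transitive
    import inside the package fails.
    """
    if importlib.util.find_spec(module_name) is None:
        raise RuntimeError(
            f"Backend {module_name!r} is not installed; install the optional extra to use it."
        )
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        raise RuntimeError(
            f"Backend {module_name!r} is installed but failed to import: {e}"
        ) from e
```

With this shape, both a missing package and a broken partial install surface the friendly message, and a test skip-check that looks for the RuntimeError catches both paths.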

@dlwh (Member, Author)

dlwh commented May 13, 2026

🤖 Specification for the updated >500 LOC backend PR.

Problem: moe_mlp only exposed the EP ring/ragged implementations on main, and model code had to own expert MLP weight layout. The earlier PR version also tried to make the fastest path depend on CuTe/QuACK kernels, but our GH200 lowering work showed the d-gated ragged TMA path is not production-safe yet. We still need an opt-in local path that captures the useful W13 interleaving and custom-VJP boundary without CUDA-only optional imports.

Approach: lib/levanter/src/levanter/grug/grug_moe.py now accepts ring, ragged_all_to_all, scatter, sonic_xla, sonic_xla_interleaved_w13, and sonic_xla_interleaved_w13_custom_vjp_down. EP choices stay EP-only; local choices route through separate straight-line functions. MoEExpertMlp owns expert weight initialization, sharding, backend selection, and W13 concat/interleaved layout. experiments/grug/moe/model.py delegates expert MLP execution to MoEExpertMlp. The custom-VJP backend lives in lib/levanter/src/levanter/grug/custom_vjp_moe.py and uses JAX ragged_dot plus an explicit VJP for the down/gather boundary; it adds no new package dependencies.
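To make the W13 layout distinction concrete, here is a hedged sketch of what splitting a fused gate/up projection output looks like under the two layouts (the helper name is hypothetical; the PR's actual helper is split_moe_w13_output):

```python
import jax.numpy as jnp


def split_w13(w13_out: jnp.ndarray, interleaved: bool):
    """Split a fused gate/up (W1/W3) projection output into (gate, up).

    concat layout:      [g0 g1 ... g_{h-1} | u0 u1 ... u_{h-1}]
    interleaved layout: [g0 u0 g1 u1 ... g_{h-1} u_{h-1}]
    """
    if interleaved:
        # Even lanes are gate outputs, odd lanes are up outputs.
        return w13_out[..., 0::2], w13_out[..., 1::2]
    half = w13_out.shape[-1] // 2
    return w13_out[..., :half], w13_out[..., half:]
```

Because MoEExpertMlp owns both the weight initialization and this split, model code can select any backend string without knowing which layout the weights were packed in.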

Key code:

@jax.custom_vjp
def custom_vjp_interleaved_down_gather_sum(
    w13_out_interleaved,
    combine_weights,
    w_down,
    token_ids_sort,
    sorted_assignment_ids,
    dispatch_positions,
    group_sizes,
):
    out, _ = _custom_vjp_interleaved_down_gather_sum_forward(
        w13_out_interleaved,
        combine_weights,
        w_down,
        token_ids_sort,
        sorted_assignment_ids,
        dispatch_positions,
        group_sizes,
    )
    return out

The backward saves the dispatch output from forward, replays only the SwiGLU activation pullback and down-projection pullback, and computes the combine-weight gradient as sum(dout_sorted * dispatch_output) before unsorting by assignment id.
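The save-and-replay structure described above can be sketched with a simplified custom-VJP over just the weighted gather/sum boundary (this is a toy stand-in, not the PR's custom_vjp_interleaved_down_gather_sum; the SwiGLU and down-projection pullbacks are omitted, and integer indices are closed over since they carry no gradient):

```python
import jax
import jax.numpy as jnp


def make_gather_sum(dispatch_positions):
    """Weighted gather-sum with an explicit VJP at the gather boundary.

    dispatch_positions: (tokens, topk) int32 indices into dispatch_output.
    """

    @jax.custom_vjp
    def gather_sum(dispatch_output, combine_weights):
        gathered = dispatch_output[dispatch_positions]  # (tokens, topk, hidden)
        return jnp.einsum("tkh,tk->th", gathered, combine_weights)

    def fwd(dispatch_output, combine_weights):
        gathered = dispatch_output[dispatch_positions]
        out = jnp.einsum("tkh,tk->th", gathered, combine_weights)
        # Save the gathered dispatch output as a residual, mirroring the
        # PR's "save the dispatch output from forward" strategy.
        return out, (gathered, combine_weights, dispatch_output.shape)

    def bwd(res, dout):
        gathered, weights, disp_shape = res
        # Combine-weight gradient: reduce dout against the saved dispatch
        # output, i.e. sum(dout * dispatch_output) per (token, topk) slot.
        d_weights = jnp.einsum("th,tkh->tk", dout, gathered)
        # Dispatch-output gradient: scatter-add the weighted cotangents back.
        d_dispatch = jnp.zeros(disp_shape, dout.dtype).at[dispatch_positions].add(
            dout[:, None, :] * weights[:, :, None]
        )
        return d_dispatch, d_weights

    gather_sum.defvjp(fwd, bwd)
    return gather_sum
```

The real backend additionally unsorts the combine-weight gradient by assignment id and replays the activation pullbacks; this sketch only shows the residual-saving and cotangent-scatter structure.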

Tests: lib/levanter/tests/grug/test_grugformer_moe.py covers scatter vs Sonic-style XLA value/gradient parity, concat vs interleaved W13 value/gradient parity, MoEExpertMlp layout-hiding parity, custom-VJP down/gather value/gradient parity against the interleaved XLA path, and EP abstract-mesh lowering. Local validation ran ./infra/pre-commit.py --all-files --fix and uv run --package marin --group dev pytest -m 'not slow' lib/levanter/tests/grug/test_grugformer_moe.py -q.
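The value/gradient parity checks listed above share a common shape, which can be sketched as a small helper (the helper name and the stand-in functions in the test are illustrative, not the actual test code):

```python
import jax
import jax.numpy as jnp


def assert_value_and_grad_parity(f_ref, f_new, x, atol=1e-5):
    """Check two backend implementations agree on values and gradients.

    Mirrors the scatter-vs-XLA and concat-vs-interleaved parity tests:
    same forward output, and same gradient of a scalar loss w.r.t. x.
    """
    assert jnp.allclose(f_ref(x), f_new(x), atol=atol)
    g_ref = jax.grad(lambda v: f_ref(v).sum())(x)
    g_new = jax.grad(lambda v: f_new(v).sum())(x)
    assert jnp.allclose(g_ref, g_new, atol=atol)
```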

Comment on lines +56 to +68
def _gather_sum_reference(
    dispatch_output: jax.Array,
    dispatch_positions: jax.Array,
    combine_weights: jax.Array,
) -> jax.Array:
    acc = jnp.zeros((dispatch_positions.shape[0], dispatch_output.shape[1]), dtype=dispatch_output.dtype)
    weights = combine_weights.astype(dispatch_output.dtype)
    for topk_index in range(dispatch_positions.shape[1]):
        gathered = jnp.take(dispatch_output, dispatch_positions[:, topk_index], axis=0)
        acc = (acc + (gathered * weights[:, topk_index, None]).astype(dispatch_output.dtype)).astype(
            dispatch_output.dtype
        )
    return acc
Contributor

Parallel implementation of _gather_sum_reference in grug_moe.py — the body is identical (only the type annotations differ: Float[Array, ...] vs plain jax.Array). Since quack_moe is in the same package and is itself imported lazily from grug_moe, you can import the helper directly without circular-import concerns.

Per root AGENTS.md: "If a suitable implementation exists, use it. Do not create parallel implementations."

Suggested change (drop the duplicated local definition in favor of a direct import):

from levanter.grug.grug_moe import _gather_sum_reference

@dlwh force-pushed the codex/moe-quack-down-backend branch 5 times, most recently from 4a55948 to 5ec6613 on May 14, 2026 at 03:17
@dlwh force-pushed the codex/moe-quack-down-backend branch from 5ec6613 to 134d7b4 on May 15, 2026 at 08:09
@dlwh changed the title from "Add optional QuACK MoE local backend" to "Add production-safe local Grug MoE backends" on May 15, 2026