Add production-safe local Grug MoE backends #5718
Claude finished @dlwh's task in 6m 22s
Code review
Reviewed for bugs and CLAUDE.md/AGENTS.md compliance. One issue flagged inline:
Other things considered and not flagged:
🤖 Specification for the updated >500 LOC backend PR.
Problem:
Approach:
Key code:
@jax.custom_vjp
def custom_vjp_interleaved_down_gather_sum(
    w13_out_interleaved,
    combine_weights,
    w_down,
    token_ids_sort,
    sorted_assignment_ids,
    dispatch_positions,
    group_sizes,
):
    out, _ = _custom_vjp_interleaved_down_gather_sum_forward(
        w13_out_interleaved,
        combine_weights,
        w_down,
        token_ids_sort,
        sorted_assignment_ids,
        dispatch_positions,
        group_sizes,
    )
    return out
The backward saves the dispatch output from forward, replays only the SwiGLU activation pullback and down-projection pullback, and computes the combine-weight gradient as
Tests:
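One way to exercise a boundary like this is to compare its hand-written gradients against plain autodiff. The sketch below is illustrative only and is not the PR's code or its tests: `toy_gather_sum`, `_toy_forward`, `_toy_backward`, `autodiff_reference`, and `max_grad_gap` are hypothetical names, and it covers only the weighted gather-sum step, leaving out the SwiGLU and down-projection stages. It assumes the output is `out[t] = sum_k combine_weights[t, k] * dispatch_output[dispatch_positions[t, k]]`, for which the combine-weight gradient is the dot product of the upstream cotangent with each gathered row. This is the textbook gradient for that expression and may not match the expression the PR uses.

```python
import jax
import jax.numpy as jnp


def _toy_forward(dispatch_output, dispatch_positions, combine_weights):
    # out[t] = sum_k combine_weights[t, k] * dispatch_output[dispatch_positions[t, k]]
    gathered = dispatch_output[dispatch_positions]  # [tokens, top_k, hidden]
    out = jnp.einsum("tkd,tk->td", gathered, combine_weights)
    # Residuals: everything the hand-written backward needs.
    return out, (dispatch_output, dispatch_positions, combine_weights)


@jax.custom_vjp
def toy_gather_sum(dispatch_output, dispatch_positions, combine_weights):
    # Same shape as the snippet above: call the forward helper, drop the residuals.
    out, _ = _toy_forward(dispatch_output, dispatch_positions, combine_weights)
    return out


def _toy_backward(residuals, g):
    dispatch_output, dispatch_positions, combine_weights = residuals
    gathered = dispatch_output[dispatch_positions]
    # d out / d combine_weights[t, k] is the gathered row, so contract with g.
    d_weights = jnp.einsum("td,tkd->tk", g, gathered)
    # Scatter-add the weighted cotangent back to the rows that were gathered.
    contrib = g[:, None, :] * combine_weights[:, :, None]
    d_dispatch = jnp.zeros_like(dispatch_output).at[dispatch_positions].add(contrib)
    # Integer positions carry no gradient; None stands for a zero cotangent.
    return d_dispatch, None, d_weights


toy_gather_sum.defvjp(_toy_forward, _toy_backward)


def autodiff_reference(dispatch_output, dispatch_positions, combine_weights):
    # Same math without custom_vjp, so JAX derives the gradient itself.
    gathered = dispatch_output[dispatch_positions]
    return jnp.einsum("tkd,tk->td", gathered, combine_weights)


def max_grad_gap(dispatch_output, dispatch_positions, combine_weights):
    # Largest absolute difference between custom and autodiff gradients.
    def grads(fn):
        loss = lambda y, w: jnp.sum(fn(y, dispatch_positions, w) ** 2)
        return jax.grad(loss, argnums=(0, 1))(dispatch_output, combine_weights)

    got, want = grads(toy_gather_sum), grads(autodiff_reference)
    return max(float(jnp.max(jnp.abs(a - b))) for a, b in zip(got, want))
```

Calling `max_grad_gap` on small random inputs with repeated positions should return a value near zero if the backward is consistent with the forward.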
def _gather_sum_reference(
    dispatch_output: jax.Array,
    dispatch_positions: jax.Array,
    combine_weights: jax.Array,
) -> jax.Array:
    acc = jnp.zeros((dispatch_positions.shape[0], dispatch_output.shape[1]), dtype=dispatch_output.dtype)
    weights = combine_weights.astype(dispatch_output.dtype)
    for topk_index in range(dispatch_positions.shape[1]):
        gathered = jnp.take(dispatch_output, dispatch_positions[:, topk_index], axis=0)
        acc = (acc + (gathered * weights[:, topk_index, None]).astype(dispatch_output.dtype)).astype(
            dispatch_output.dtype
        )
    return acc
This is a parallel implementation of _gather_sum_reference from grug_moe.py: the body is identical, and only the type annotations differ (Float[Array, ...] vs. plain jax.Array). Since quack_moe is in the same package and is itself imported lazily from grug_moe, you can import the helper directly without circular-import concerns.
Per root AGENTS.md: "If a suitable implementation exists, use it. Do not create parallel implementations."
- def _gather_sum_reference(
-     dispatch_output: jax.Array,
-     dispatch_positions: jax.Array,
-     combine_weights: jax.Array,
- ) -> jax.Array:
-     acc = jnp.zeros((dispatch_positions.shape[0], dispatch_output.shape[1]), dtype=dispatch_output.dtype)
-     weights = combine_weights.astype(dispatch_output.dtype)
-     for topk_index in range(dispatch_positions.shape[1]):
-         gathered = jnp.take(dispatch_output, dispatch_positions[:, topk_index], axis=0)
-         acc = (acc + (gathered * weights[:, topk_index, None]).astype(dispatch_output.dtype)).astype(
-             dispatch_output.dtype
-         )
-     return acc
+ from levanter.grug.grug_moe import _gather_sum_reference
4a55948 to 5ec6613
5ec6613 to 134d7b4
Add local Grug MoE backends for scatter, Sonic-style XLA dispatch, interleaved W13 layout, and a production-safe custom-VJP down/gather boundary. MoEExpertMlp owns expert weight initialization and W13 layout so model code can select the backend without layout plumbing.
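To make the "interleaved W13 layout" concrete, here is a minimal sketch under stated assumptions. It assumes W13 is the fused gate (w1) and up (w3) projection of a SwiGLU MLP, and that "interleaved" means the fused columns alternate gate, up, gate, up. The description above does not spell out the layout, so the real ordering in MoEExpertMlp may differ. `interleave_w13`, `swiglu_interleaved`, and `swiglu_separate` are hypothetical helper names, not functions from the PR.

```python
import jax
import jax.numpy as jnp


def interleave_w13(w1, w3):
    # Assumed layout: [d_model, hidden] x 2 -> [d_model, 2 * hidden],
    # with columns ordered gate_0, up_0, gate_1, up_1, ...
    d_model, hidden = w1.shape
    return jnp.stack([w1, w3], axis=-1).reshape(d_model, 2 * hidden)


def swiglu_interleaved(x, w13_interleaved):
    # One matmul produces gate and up together; even columns are the gate,
    # odd columns are the up projection (under the assumed layout).
    fused = x @ w13_interleaved
    gate, up = fused[..., 0::2], fused[..., 1::2]
    return jax.nn.silu(gate) * up


def swiglu_separate(x, w1, w3):
    # Conventional two-matmul SwiGLU, used here only as a cross-check.
    return jax.nn.silu(x @ w1) * (x @ w3)
```

Keeping the layout conversion in one place, as the description says MoEExpertMlp does, means callers only ever see `w1`/`w3`-shaped semantics and the backend choice stays a configuration detail.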
Part of #5328