Add canonical Grug MoE module and template variant#3046
Conversation
@claude ptal
Claude finished @dlwh's task in 2m 57s 🤖
**Review: Refactor Grug MoE core and activation handling**

Overall this is a clean, well-structured MoE implementation. The separation of the functional kernel is a good call.

Agreeing with self-review items:

1. jaxtyping annotations: +1 to the self-review.
2. Comment wording (line 260): agree with the suggested edit.
3. Agree. Currently:

   ```python
   batch_spec = x.sharding.spec[:1] + (P(None),) * (x.ndim - 1)
   ```

   or simply pulling the first partition from the input sharding.

Additional observations:

4. Hardcoded activation in the functional kernel.
5. Activation dispatch. The current form:

   ```python
   if callable(activation):
       activation_fn = activation
   else:
       activation_fn = activation.to_jax_fn()
   ```

   Consider instead:

   ```python
   if isinstance(activation, ActivationFunctionEnum):
       activation_fn = activation.to_jax_fn()
   else:
       activation_fn = activation  # already a Callable
   ```

   This is more explicit about what the union type admits.
6. Mesh resolution checks `None` twice:

   ```python
   if mesh is None:
       mesh = get_abstract_mesh()  # line 343-344
   ...
   if mesh is None or mesh.empty:  # line 364
   ```

   Minor nit but worth clarifying the contract between the two checks.
7. Duplicated routing logic: both functions independently compute the routing.
8. EP path:

   ```python
   group_sizes = jnp.bincount(expert_local, weights=valid_i32, length=local_experts).astype(jnp.int32)
   ```
9. EP.
10. Test coverage is solid: EP mesh, dense mesh, abstract mesh lowering, loss, shared expert disable, and functional activation enum vs callable equivalence. Two suggestions:
11. Clean and correct.

**Summary**

Good refactor overall. The main actionable items are the observations above.
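The dispatch pattern discussed in item 5 can be sketched end to end. Everything below (`ActivationFunctionEnum`, `to_jax_fn`, `resolve_activation`, the GELU stand-in) is a hypothetical stand-in for illustration, not the actual levanter definitions:

```python
import math
from enum import Enum
from typing import Callable, Union


def gelu(x: float) -> float:
    # tanh approximation of GELU, used here as a stand-in activation
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


class ActivationFunctionEnum(Enum):
    # hypothetical enum mirroring the one discussed in the review
    RELU = "relu"
    GELU = "gelu"

    def to_jax_fn(self) -> Callable[[float], float]:
        if self is ActivationFunctionEnum.RELU:
            return lambda x: max(x, 0.0)
        return gelu


def resolve_activation(activation: Union[ActivationFunctionEnum, Callable]) -> Callable:
    # isinstance-first dispatch, as the review suggests: the enum branch is
    # named explicitly instead of relying on callable() falling through
    if isinstance(activation, ActivationFunctionEnum):
        return activation.to_jax_fn()
    return activation  # already a Callable
```

The point of the isinstance-first ordering is that enum members could themselves be made callable someday, at which point a `callable()` check would silently take the wrong branch.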
🤖 Followed up on the extra Claude review points as part of commit
I did not make
@claude ptal
Claude finished @dlwh's task in 2m 27s 🤖
**Review: Grug MoE Module, Template Variant, and Contracts**

The core MoE implementation in `grug_moe.py` is in good shape. I have a few items below, ranging from structural concerns to minor nits.

1. Near-total duplication between the library module and the template variant.
🤖 Grug variant diff report
Artifact fallback: Download report bundle
@claude ptal
Claude finished @dlwh's task in 3m 7s 🤖
**Review: Grug MoE — Post-Refactor Pass**

The refactoring since the last review is a substantial structural improvement. The core change is stripping all module/model classes out of the core kernel module.

Key changes I verified:
One item:

1. `_report_capacity_overflow` is never read:

```python
def _moe_mlp_ep_ring_local(
    ...
    _report_capacity_overflow: bool,  # never read
) -> tuple[Float[Array, "TL D"], Int[Array, ""]]:
    ...
    dropped_local = jnp.maximum(local_count - local_capacity, 0)
    ...
    dropped_total = jax.lax.psum(dropped_local, ("data", "expert"))
    return out_local, dropped_total
```

The function always computes and returns `dropped_total`, so the flag is dead.

Otherwise this looks good. The functional kernel is well-tested (EP mesh, dense mesh, abstract mesh lowering, shape validation, activation dispatch, dropped count reporting), the template variant is self-contained per the grug copy-first philosophy, and the contract tests verify cross-variant consistency.
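For intuition, the dropped-count arithmetic in the snippet above can be reproduced on the host with plain NumPy. The function name, token assignment, and capacity numbers here are hypothetical; the real kernel does this per shard and sums with `psum`:

```python
import numpy as np


def dropped_per_expert(expert_ids: np.ndarray, num_experts: int, capacity: int) -> np.ndarray:
    """Count how many tokens each expert must drop once its capacity is exceeded.

    Mirrors the review snippet's pattern:
        dropped_local = max(local_count - local_capacity, 0)
    """
    counts = np.bincount(expert_ids, minlength=num_experts)
    return np.maximum(counts - capacity, 0)


# six tokens routed to experts 0..2, each expert has capacity 2
ids = np.array([0, 0, 0, 1, 2, 2])
# dropped_per_expert(ids, num_experts=3, capacity=2) -> [1, 0, 0]
```

Summing the per-expert drops over all shards is the host-side analogue of the `jax.lax.psum` over the `("data", "expert")` axes.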
Claude finished @dlwh's task in 2m 51s
**Review**

Adds a functional MoE kernel.

Specification check: no linked spec or design doc.

Findings:

No other correctness issues found. Test coverage is solid: the variant contract tests verify abstract mesh lowering and train-step shape parity across discovered variants.
## Context

Grug code is deliberately template-y/copypasta, which makes reviewing diffs a bit tricky. As a quick workaround, we made a tool to diff against the nearest-neighbor Grug variant. See a slightly faked version of the workflow in action here: #3046 (comment)

## Summary

- add a one-page HTML visual directory diff tool for Grug/template-style code copies
- document usage in `experiments/grug/README.md`
- add a CI helper to detect newly added Grug variants, find the closest existing variant, and generate diff reports
- add a PR workflow that publishes rendered reports to `gh-pages` and comments links on the PR

## Testing

- `uvx ruff check scripts/grug_dir_diff.py scripts/grug_variant_diff_ci.py tests/test_grug_dir_diff.py tests/test_grug_variant_diff_ci.py`
- `uv run --with pytest --with pytest-timeout python -m pytest -q tests/test_grug_dir_diff.py tests/test_grug_variant_diff_ci.py`
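The "find the closest existing variant" step can in principle be approximated with the stdlib's `difflib`. This toy sketch (the function name and variant map are invented) compares flattened source text, whereas the real CI helper works over whole directories:

```python
import difflib


def closest_variant(new_src: str, existing: dict[str, str]) -> str:
    """Pick the existing variant whose source is most similar to the new one.

    A toy stand-in for the CI helper described above: compares concatenated
    source text with difflib's similarity ratio.
    """
    def similarity(old_src: str) -> float:
        return difflib.SequenceMatcher(None, new_src, old_src).ratio()

    return max(existing, key=lambda name: similarity(existing[name]))


variants = {
    "base": "def train():\n    run_base()\n",
    "moe": "def train():\n    run_moe()\n    log_router_stats()\n",
}
```

A new variant copied from `moe` with a small tweak would then resolve to `moe` as its diff baseline.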
## Summary

This PR splits the Grug variant contract work into its own change so it can land before the MoE implementation changes (#3046).

## Included

- Add `tests/test_grug_variant_contracts.py` as the contract suite for Grug variants.
- Remove legacy `tests/test_grug_base_template.py`.
- Update `experiments/grug/base` to satisfy the contract surface:
  - `Transformer.next_token_loss(...)`
  - `debug_mesh_and_token_pspec(num_devices)`
  - `initial_state(...)` in `train.py`
- Update Grug docs to point to the new contract test path and document the contract.

## Validation

- `uv run --with pytest-timeout pytest tests/test_grug_variant_contracts.py -q`
- `uv run python infra/pre-commit.py --all-files`
@claude one last time?
Claude finished @dlwh's task in 2m 1s 🤖
**Review: Grug MoE Module and Template Variant (Final Pass)**

Adds a functional MoE kernel. No correctness issues found. This looks good to merge.

Changes verified:
Test coverage: dense mesh, EP mesh, abstract mesh lowering, activation enum vs callable equivalence, and capacity overflow reporting. Sufficient for the kernel surface.
```python
return mesh, P(("data", "expert"), None)
```
```python
GrugMoeModelConfig = GrugModelConfig
```
I can't think of a scenario yet where I'll be importing under the name GrugMoeModelConfig.
```python
levanter.tracker.log({"throughput/loading_time": iterator.this_load_time}, step=step)
router_metrics = {key: value for key, value in metrics.items() if key.startswith("train/router/")}
if router_metrics:
    levanter.tracker.log(router_metrics, step=step)
```
I like where the logging for MoE metrics lives now!
```python
combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype)
router_stats = _routing_stats_from_selected_experts(selected_experts, num_experts=self.cfg.num_experts)
...
routed_flat = moe_mlp(
```
Really like having this be the interface level. At some point I think we can extend this to support non-gated MLPs.
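For reference, the top-k-then-softmax routing shown in the quoted snippet can be reproduced in NumPy. Shapes and names here are illustrative, not the levanter implementation:

```python
import numpy as np


def topk_route(router_logits: np.ndarray, k: int):
    """Pick k experts per token and softmax only the selected logits
    into combine weights, as in the quoted snippet."""
    # indices of the k largest logits per token (unsorted within the k)
    selected = np.argpartition(router_logits, -k, axis=-1)[..., -k:]
    topk_logits = np.take_along_axis(router_logits, selected, axis=-1)
    # numerically stable softmax over just the selected experts,
    # so combine weights sum to 1 per token
    z = topk_logits - topk_logits.max(axis=-1, keepdims=True)
    w = np.exp(z)
    combine_weights = w / w.sum(axis=-1, keepdims=True)
    return selected, combine_weights


logits = np.array([[2.0, 0.0, 1.0, -1.0]])  # one token, four experts
sel, wts = topk_route(logits, k=2)
# the two selected experts are {0, 2}, and their weights sum to 1
```

Note the softmax runs over the top-k logits only, matching `jax.nn.softmax(topk_logits, axis=-1)` rather than a softmax over all experts.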
```python
@named_call
def shared_dense_mlp(
```
Conceptually it might be simpler if `MLP()` is its own eqx class with up_gate/down, and then `MoEMLP` has an attribute called `shared_mlp` of type `MLP`. This shared variable for gate/up may affect how Muon orthogonalization behaves on the matrix: here it is super easy for the user to modify, but for the `moe_mlp()` function in levanter/grug we might want to think about how to enable specification of matrix slicing for orthogonalization or something of that sort.
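The suggested factoring, an MLP class owning up/gate/down that can double as the shared expert, might look roughly like this NumPy sketch. The class, weight names, and initializer scale are assumptions, not levanter code:

```python
import numpy as np


class GatedMLP:
    """Standalone gated MLP: one class owning up/gate/down projections,
    reusable both as a routed expert and as MoEMLP's shared_mlp attribute."""

    def __init__(self, d_model: int, d_hidden: int, rng: np.random.Generator):
        self.w_gate = rng.normal(0.0, 0.02, (d_model, d_hidden))
        self.w_up = rng.normal(0.0, 0.02, (d_model, d_hidden))
        self.w_down = rng.normal(0.0, 0.02, (d_hidden, d_model))

    def __call__(self, x: np.ndarray) -> np.ndarray:
        gate = x @ self.w_gate
        silu = gate / (1.0 + np.exp(-gate))  # SiLU nonlinearity on the gate
        return (silu * (x @ self.w_up)) @ self.w_down


rng = np.random.default_rng(0)
shared_mlp = GatedMLP(d_model=8, d_hidden=16, rng=rng)
y = shared_mlp(np.ones((2, 8)))  # output keeps the (tokens, d_model) shape
```

Keeping gate and up as separate named matrices, rather than one fused up_gate buffer, is exactly what makes per-matrix treatment (e.g. Muon orthogonalization) straightforward to express.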
Approving, no reason to block merge on these comments.
## Summary

- add a canonical compact Grug MoE implementation in `lib/levanter/src/levanter/grug/grug_moe.py`
- keep routing (router matmul + top-k + softmax) inline in `MoEMLP.__call__`
- make `moe_mlp` the reusable dispatch/permute/unpermute (+EP) kernel over precomputed routing selections/weights
- add a template-first Grug MoE experiment surface under `experiments/grug/moe/` (`model.py`, `train.py`, `launch.py`, `__init__.py`) aligned with `experiments/grug/base`
- update activation wiring to use `ActivationFunctionEnum` consistently and extend `levanter.utils.activation` with `relu2`

## Not In This PR

- Grug variant contract/test refactors were split and landed separately in [#3169](#3169).

## Validation

- `uv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug/test_grugformer_moe.py -q`
- `uv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug -q`
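For readers unfamiliar with `relu2`: a common reading is squared ReLU, sketched below. Whether levanter's `relu2` matches this exactly should be checked against `levanter.utils.activation`:

```python
import numpy as np


def relu2(x: np.ndarray) -> np.ndarray:
    # squared ReLU: relu(x) ** 2 -- one common definition of "relu2";
    # this is an assumption, not a quote of the levanter implementation
    r = np.maximum(x, 0.0)
    return r * r


# relu2(np.array([-2.0, 0.5, 3.0])) -> [0.0, 0.25, 9.0]
```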