
Add canonical Grug MoE module and template variant #3046

Merged
dlwh merged 23 commits into main from codex/grug-moe-ep-sensible-defaults on Mar 3, 2026

Conversation

@dlwh (Member) commented Feb 26, 2026

Summary

  • add a canonical compact Grug MoE implementation in lib/levanter/src/levanter/grug/grug_moe.py
  • keep routing (router matmul + top-k + softmax) inline in MoEMLP.__call__ (a short sketch follows this list)
  • make moe_mlp the reusable dispatch/permute/unpermute (+EP) kernel over precomputed routing selections/weights
  • add template-first Grug MoE experiment surface under experiments/grug/moe/ (model.py, train.py, launch.py, __init__.py) aligned with experiments/grug/base
  • update activation wiring to use ActivationFunctionEnum consistently and extend levanter.utils.activation with relu2
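
For readers new to the split, here is a minimal sketch of the routing-in-caller / kernel-takes-precomputed-selections shape described above. The helper name `route_tokens`, the shape conventions, and the `moe_mlp` argument order are illustrative assumptions, not the actual signatures in `grug_moe.py`:

```python
import jax
import jax.numpy as jnp

def route_tokens(x_flat: jax.Array, router_w: jax.Array, num_experts_per_token: int):
    """Routing preamble kept in MoEMLP.__call__: router matmul -> top-k -> softmax.

    Shapes are illustrative: x_flat is (tokens, d_model), router_w is (d_model, num_experts).
    """
    router_logits = jnp.einsum("td,de->te", x_flat.astype(jnp.float32), router_w.astype(jnp.float32))
    topk_logits, selected_experts = jax.lax.top_k(router_logits, num_experts_per_token)
    combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x_flat.dtype)
    return selected_experts, combine_weights

# The precomputed routing then feeds the reusable kernel (argument names illustrative):
# routed_flat = moe_mlp(x_flat, w13, w2, selected_experts, combine_weights, ...)
```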

Not In This PR

  • Grug variant contract/test refactors were split and landed separately in #3169.

Validation

  • uv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug/test_grugformer_moe.py -q
  • uv run --python 3.11 --package levanter --group test pytest lib/levanter/tests/grug -q

@dlwh added the agent-generated (Created by automation/agent) label on Feb 26, 2026
Three review threads on lib/levanter/src/levanter/grug/grug_moe.py (outdated, collapsed)
@dlwh (Member, Author) commented Feb 26, 2026

@claude ptal

@claude bot (Contributor) commented Feb 26, 2026

Claude finished @dlwh's task in 2m 57s.


🤖

  • Read changed files and gather context
  • Review grug_moe.py (new, +603 lines)
  • Review activation.py (modified, +29 lines)
  • Review test_grugformer_moe.py (new, +243 lines)
  • Address existing review comments from dlwh
  • Post comprehensive review

Review: Refactor Grug MoE core and activation handling

Overall this is a clean, well-structured MoE implementation. The separation of the functional moe_mlp core from the MoEMLP equinox module is a good design, and the ring EP path is clearly documented with issue references. A few items below.


Agreeing with self-review items

1. jaxtyping annotations on CausalSelfAttention fields (line 116)

+1 to the self-review. The existing model.py uses jaxtyping on its __call__ signatures but not on module fields, so this MoE file is at least consistent. But given the more complex weight shapes here (especially the (E, D, I2) fused gate/up), explicit shape annotations on w_q/w_k/w_v/w_o and on the MoEMLP fields (moe_router, moe_w13, moe_w2, shared_w13, shared_w2) would help readers reason about sharding specs.
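
For illustration, field-level jaxtyping annotations along these lines would make the sharding story easier to follow. Apart from the fused (E, D, I2) gate/up noted above, the shapes below are guesses, not the actual definitions:

```python
import equinox as eqx
from jaxtyping import Array, Float

class MoEMLP(eqx.Module):
    # Axis letters: E = experts, D = model dim, I = expert intermediate, I2 = 2 * I (fused gate/up).
    moe_router: Float[Array, "D E"]       # shape assumed
    moe_w13: Float[Array, "E D I2"]       # fused gate/up, per the review above
    moe_w2: Float[Array, "E I D"]         # shape assumed
    shared_w13: Float[Array, "D I2"]      # shape assumed
    shared_w2: Float[Array, "I D"]        # shape assumed
```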

2. Comment wording (line 260)

Agree with the suggested edit: "# NB: this means we receive all tokens on the DP axis, best for low EP" adds helpful context.

3. batch_spec derivable from x.sharding (line 375)

Agree. Currently _batch_spec(mesh) reconstructs the spec from mesh axis names, but x already carries its sharding. Deriving from x.sharding would be more robust and self-documenting, and would remove the batch_spec parameter concern entirely. Something like:

batch_spec = P(x.sharding.spec[0], *([None] * (x.ndim - 1)))

or simply pulling the first partition from x.sharding.spec.


Additional observations

4. Hardcoded activation in MoEMLP.__call__ (lines 474, 488)

The functional moe_mlp accepts a configurable MoeActivation, but MoEMLP.__call__ hardcodes ActivationFunctionEnum.silu for the routed path and jax.nn.silu for the shared dense path. If the intent is for MoEMLP to always use silu, this is fine — but it means the MoeActivation type alias and the activation parameter on moe_mlp are only exercisable by direct callers of the functional API. Consider making MoEMLP store the activation choice (as a static field) so model configs can control it, especially since relu2 was just added to the enum.
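
A sketch of the configurable-activation pattern, using a toy module rather than the real MoEMLP; the import path for ActivationFunctionEnum is an assumption:

```python
import equinox as eqx
import jax

from levanter.utils.activation import ActivationFunctionEnum  # import path assumed

class TinyMoEMLP(eqx.Module):
    """Toy stand-in for MoEMLP showing a static activation field."""
    w: jax.Array
    activation: ActivationFunctionEnum = eqx.field(static=True, default=ActivationFunctionEnum.silu)

    def __call__(self, x: jax.Array) -> jax.Array:
        act_fn = self.activation.to_jax_fn()  # enum -> callable, e.g. jax.nn.silu
        return act_fn(x @ self.w)
```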

5. Activation dispatch: callable() check (line 346)

if callable(activation):
    activation_fn = activation
else:
    activation_fn = activation.to_jax_fn()

Since ActivationFunctionEnum(str, enum.Enum) instances aren't callable, this works correctly. However, the intent reads more clearly as isinstance(activation, ActivationFunctionEnum) with the branches swapped:

if isinstance(activation, ActivationFunctionEnum):
    activation_fn = activation.to_jax_fn()
else:
    activation_fn = activation  # already a Callable

This is more explicit about what the union type MoeActivation actually holds and avoids a subtle "what if someone passes a class type?" edge.

6. mesh is None on line 364 is dead after line 343–344

if mesh is None:
    mesh = get_abstract_mesh()  # line 343-344
...
if mesh is None or mesh.empty:  # line 364

If get_abstract_mesh() can return None, then mesh.empty on line 364 would raise AttributeError. If it can't return None, then the mesh is None branch on line 364 is dead code. Suggest either:

  • Guard: if mesh is None or (hasattr(mesh, 'empty') and mesh.empty):
  • Or just if mesh is None or mesh.empty: with a comment that get_abstract_mesh() may return None

Minor nit but worth clarifying the contract.

7. Duplicated routing logic between _moe_mlp_local and _moe_mlp_ep_ring_local

Both functions independently compute router_logits → top_k → softmax. The EP path then does its own sort/dispatch because it needs global tokens, so the duplication is understandable. But the first three lines (einsum + top_k + softmax) are identical — factoring out a _route_tokens(x_flat, moe_router, num_experts_per_token, dtype) helper would reduce the surface area for divergence as routing evolves.

8. EP path jnp.bincount with weights=valid_i32 (line 307)

group_sizes = jnp.bincount(expert_local, weights=valid_i32, length=local_experts).astype(jnp.int32)

jnp.bincount's weights parameter typically expects floats. This works but an explicit .astype(jnp.float32) on valid_i32 before passing would make the intent clearer and avoid any future JAX strictness changes.

9. EP num_experts divisibility check is redundant

moe_mlp (line 379–380) validates num_experts % expert_axis_size != 0, and _moe_mlp_ep_ring_local (line 279–282) validates num_experts % local_experts != 0. Since local_experts = moe_w13_local.shape[0] and the shard_map divides by expert_axis_size, these are essentially the same check. Not a bug, just belt-and-suspenders.

10. Test coverage

Test coverage is solid: EP mesh, dense mesh, abstract mesh lowering, loss, shared expert disable, and functional activation enum vs callable equivalence. Two suggestions:

  • A test exercising relu2 (or another non-silu activation) through the enum path would validate the new activation.py additions end-to-end.
  • The _moe_mlp_ep_ring_local capacity overflow path (where tokens exceed local_capacity) is not explicitly tested. A test with a very small capacity_factor could cover this.

11. activation.py changes

Clean and correct. _relu2_jax matches the haliax relu_squared implementation, and both mapping dicts are updated consistently.


Summary

Good refactor overall. The main actionable items are:

  1. Add jaxtyping annotations (per self-review)
  2. Derive batch_spec from x.sharding (per self-review)
  3. Consider making MoEMLP activation configurable rather than hardcoded
  4. Use isinstance instead of callable for activation dispatch clarity
  5. Factor out the shared routing preamble if you want to reduce duplication

@dlwh (Member, Author) commented Feb 26, 2026

🤖 Followed up on the extra Claude review points as part of commit 01757f4c5:

  • switched MoE activation dispatch to explicit isinstance(activation, ActivationFunctionEnum) handling
  • adjusted EP bincount weighting to use float weights for clarity/stability
  • kept the new batch_spec derivation from x.sharding.spec[0] with mesh fallback

I did not make MoEMLP activation configurable yet; it remains intentionally hardcoded to SiLU in the module wrapper, while moe_mlp(...) stays configurable for direct functional callers.

@dlwh changed the title from "Refactor Grug MoE core and activation handling" to "Add canonical Grug MoE module, template variant, and contracts" on Feb 28, 2026
@dlwh (Member, Author) commented Feb 28, 2026

@claude ptal

@claude bot (Contributor) commented Feb 28, 2026

Claude finished @dlwh's task in 2m 27s.


🤖

  • Read changed files and gather context
  • Review grug_moe.py (core MoE implementation)
  • Review experiments/grug/moe/ (template variant)
  • Review test_variant_contracts.py and test_grugformer_moe.py
  • Review remaining changes (activation.py, test_grug_base_template.py, README.md)
  • Check follow-up commit addressing previous review points
  • Post comprehensive review

Review: Grug MoE Module, Template Variant, and Contracts

The core MoE implementation in grug_moe.py is solid — the functional moe_mlp / equinox MoEMLP split is clean, the EP ring path is well-documented with issue references, and the follow-up commit (01757f4c5) correctly addressed the isinstance dispatch, float bincount weights, and batch_spec derivation from x.sharding. The activation.py changes are correct, the variant contract test pattern is a nice addition, and test coverage is strong.

I have a few items below, ranging from structural concerns to minor nits.


1. Near-total duplication between grug_moe.py and experiments/grug/moe/model.py

model.py copies CausalSelfAttention, RMSNorm, Block, Transformer, GrugModelConfig, _mesh_has_axis, _mesh_axis_size, _batch_spec, _init_weight, and _shared_dense_mlp verbatim (or near-verbatim) from grug_moe.py, then only imports moe_mlp and MoeActivation from it. That's ~250 lines of duplicated code.

I understand the template philosophy ("variants iterate independently"), but right now the only MoE-specific code in model.py is the MoEMLP class (~70 lines) and debug_mesh_and_token_pspec (~20 lines). Everything else — attention, norms, the transformer shell, config — is identical. If these diverge independently, the shared correctness surface doubles; if they don't diverge, the duplication is pure maintenance cost.

Consider either:

  • Importing the non-MoE building blocks from grug_moe.py (or from a shared grug_blocks module), or
  • At minimum, adding a comment at the top of model.py noting which classes are intentional copies vs which are expected to stay in sync with grug_moe.py.

2. Weight naming divergence creates checkpoint incompatibility

grug_moe.py's MoEMLP uses field names moe_router, moe_w13, moe_w2, shared_w13, shared_w2, while model.py's MoEMLP uses router, w_up_gate, w_down, shared_w_up_gate, shared_w_down. Since Equinox serializes by field name, checkpoints from one implementation can't load into the other. If these are truly meant to be two distinct checkpointing surfaces, that's fine but worth documenting. If they should be interchangeable, the field names need to align.

3. Loss function naming inconsistency across variants

The base variant uses compute_next_token_loss while the MoE variant uses next_token_loss. The variant contract test (test_variant_contracts.py:46-51) works around this with _infer_loss_fn_name and hasattr checks, but this means:

  • Variants aren't drop-in replaceable
  • The train loop in moe/train.py is tightly coupled to the MoE naming

Standardizing on one name (preferably next_token_loss to match the Levanter library convention in grug_moe.py) across both variants would clean this up. The _infer_loss_fn_name heuristic could then become a simple assertion.

4. moe/train.py is 482 lines with only 4 lines different from base/train.py

The diff between the two train files is:

< from experiments.grug.base.model import GrugModelConfig, Transformer
> from experiments.grug.moe.model import GrugModelConfig, Transformer
< per_pos_loss = model.compute_next_token_loss(
> per_pos_loss = model.next_token_loss(
< glu=False,
> glu=True, num_experts=..., num_shared_experts=..., num_experts_per_tok=...
< return compute_params.compute_next_token_loss(
> return compute_params.next_token_loss(

If the loss function name were standardized (point 3), the train loop could be shared entirely with just the model import and FLOP params as variant-specific config. This would avoid 480 lines of copy-paste that will need to be kept in sync for every train loop improvement.

5. EP capacity overflow is silent

In _moe_mlp_ep_ring_local (grug_moe.py:293), when local_count > local_capacity, excess tokens are silently dropped via jnp.nonzero(..., size=local_capacity, fill_value=0). This is fine for JIT-friendliness but could silently degrade model quality under skewed routing. Consider adding a debug counter (even if it's only enabled under a flag or logged periodically) so users can detect when this is happening.
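
A minimal sketch of the counter being suggested here, with axis names taken from the EP path described above; a later commit adds essentially this via report_capacity_overflow, as the Mar 1 review below notes:

```python
import jax
import jax.numpy as jnp

def dropped_assignments(local_count: jax.Array, local_capacity: int) -> jax.Array:
    """Inside the shard_map'd EP kernel: count routing assignments that overflow capacity."""
    dropped_local = jnp.maximum(local_count - local_capacity, 0)
    # psum over the dispatch axes gives a replicated global count that can be logged periodically.
    return jax.lax.psum(dropped_local, ("data", "expert"))
```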

6. Shared expert weights are fully replicated

In both MoEMLP.init implementations, shared_w13 and shared_w2 use P(None, None) — fully replicated. For the trial config this is fine (shared_expert_intermediate_dim=1792), but at larger scales this becomes significant memory overhead. The base variant's dense MLP shards along ("data", "model") / ("model", "data") axes. Consider sharding the shared expert weights similarly, or adding a comment noting the replication is intentional and why.
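
For concreteness, the difference in PartitionSpecs being discussed; the axis names follow the base variant's dense MLP as described above, and this is a sketch rather than the PR's actual specs:

```python
from jax.sharding import PartitionSpec as P

shared_w13_replicated = P(None, None)      # what the PR uses today: fully replicated
shared_w13_sharded = P("data", "model")    # up/gate projection sharded like the base dense MLP
shared_w2_sharded = P("model", "data")     # down projection
```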

7. Minor: AssertionError usage in test_variant_contracts.py (non-issue)

Lines 32 and 51 raise AssertionError directly:

raise AssertionError(f"No grug variants with model.py found under {grug_dir}")
raise AssertionError("Transformer variant must define either compute_next_token_loss or next_token_loss")

I initially read this as a misspelling, but the name matches Python's built-in AssertionError, so this is fine as written.

8. Minor: _shared_dense_mlp signature divergence

In grug_moe.py, _shared_dense_mlp takes activation_fn: Callable (pre-resolved). In model.py, it takes activation: MoeActivation (resolves internally). This means the model.py version does its own enum-to-callable dispatch redundantly. Minor, but it's another surface where the two files can drift apart.


Summary

The MoE core (grug_moe.py) is well-implemented and well-tested. The main structural concern is the volume of near-verbatim duplication between the library module and the experiment template — the two model.py + train.py files together add ~850 lines of copy-paste for ~90 lines of MoE-specific logic. If the template philosophy is firm, I'd suggest at least standardizing the loss function name and weight field names across variants to reduce the divergence footprint.

Actionable items by priority:

  1. Standardize loss function naming (next_token_loss everywhere) — low effort, removes a structural divergence and simplifies the contract test
  2. Align MoEMLP field names between grug_moe.py and model.py if checkpoint interop is desired
  3. Consider sharing the train loop with variant-specific config rather than full copy
  4. Consider sharing non-MoE building blocks from grug_moe.py into model.py
  5. Add capacity overflow observability to the EP path
  6. Document shared expert replication decision


@dlwh (Member, Author) commented Feb 28, 2026

🤖 Grug variant diff report

| New Variant | Closest Existing Variant | Distance Score | Diff |
| ----------- | ------------------------ | -------------- | ---- |
| moe         | base                     | 382            | Open |

Artifact fallback: Download report bundle

@dlwh (Member, Author) commented Mar 1, 2026

@claude ptal

@claude bot (Contributor) commented Mar 1, 2026

Claude finished @dlwh's task in 3m 7s.


🤖

  • Read changed files and gather context
  • Review latest commits since last review (Feb 28)
  • Review grug_moe.py (core MoE implementation)
  • Review experiments/grug/moe/ (template variant)
  • Review test_variant_contracts.py and test_grugformer_moe.py
  • Review remaining changes
  • Post comprehensive review

Review: Grug MoE — Post-Refactor Pass

The refactoring since the last review is a substantial structural improvement. The core change — stripping all module/model classes out of grug_moe.py and making it purely functional kernels — is the right call. It cleanly resolves the prior review's main concern (library ↔ template duplication) by having grug_moe.py provide only moe_mlp and MoeActivation, with all model wiring living in the template-first experiments/grug/moe/model.py.

Key changes I verified:

  • Routing (router matmul → top_k → softmax) moved out of moe_mlp into the caller (MoEMLP.__call__), so the kernel only handles dispatch/permute/unpermute from precomputed selected_experts/combine_weights. Cleaner separation, and the shard_map specs are updated correctly.
  • Loss function standardized to next_token_loss across both base and moe variants. base/train.py and base/model.py renamed from compute_next_token_loss.
  • Capacity overflow now returns a dropped assignment count (report_capacity_overflow=True) instead of debug logging. The EP psum across ("data", "expert") produces a replicated global count matching the P() out_spec.
  • Variant contract tests auto-discover variants and exercise both abstract-mesh lowering and train-step parity. The _small_model_config helper correctly filters MoE-specific fields for base.
  • PR review guide (github-pr-review.md) documents that grug variant duplication is intentional.

One item:

1. _report_capacity_overflow parameter is dead code in _moe_mlp_ep_ring_local

def _moe_mlp_ep_ring_local(
    ...
    _report_capacity_overflow: bool,  # never read
) -> tuple[Float[Array, "TL D"], Int[Array, ""]]:
    ...
    dropped_local = jnp.maximum(local_count - local_capacity, 0)
    ...
    dropped_total = jax.lax.psum(dropped_local, ("data", "expert"))
    return out_local, dropped_total

The function always computes and returns dropped_total regardless of _report_capacity_overflow. The branching happens in the outer moe_mlp function, which decides whether to surface the count to callers. The parameter could be removed from the inner function since it serves no purpose there. Not a bug — just dead code that adds a parameter without affecting behavior.

Otherwise this looks good. The functional kernel is well-tested (EP mesh, dense mesh, abstract mesh lowering, shape validation, activation dispatch, dropped count reporting), the template variant is self-contained per the grug copy-first philosophy, and the contract tests verify cross-variant consistency.



@dlwh marked this pull request as ready for review on March 1, 2026, 08:26
@claude bot (Contributor) commented Mar 1, 2026

Claude finished @dlwh's task in 2m 51s.


  • Read changed files and gather context
  • Review grug_moe.py (functional MoE kernels)
  • Review experiments/grug/moe/ (template variant)
  • Review test_variant_contracts.py and test_grugformer_moe.py
  • Review remaining changes (activation, loss, base variant updates)
  • Post review

Review

Adds a functional MoE kernel (moe_mlp) in levanter, a self-contained MoE grug template variant under experiments/grug/moe/, variant contract tests, and dynamic batch-axis handling in the loss function. Fulfills the stated PR objectives.

Specification check: No linked spec or .agents/projects/ design doc. At ~2200 lines added, a specification would normally be expected per docs/recipes/agent-coding.md. The iterative review history on this PR serves as partial documentation.

Findings:

  • _report_capacity_overflow parameter is dead in _moe_mlp_ep_ring_local (grug_moe.py:130): The function always computes and returns dropped_total regardless of this parameter's value. The branching happens only in the outer moe_mlp. The parameter should be removed from the inner function signature (and from the partial call at line 281). Not a bug, just dead code that adds a misleading parameter.

  • sharding.py bug fix is correct but undocumented: The change from P((None,)) to P(None) fixes a real bug (P((None,)) creates PartitionSpec((None,),) -- a tuple-of-tuple, not replicated). Worth mentioning in the PR description since it affects the base variant too. A small illustration follows this findings list.

  • loss.py dynamic _batch_axis_spec is correct for EP: The change from hard-coded P(("data",)) to dynamically inferring the batch axis from hidden.sharding is necessary for the MoE variant's ("data", "expert") batch spec. The psum reduction and shard_map specs are updated consistently. The fallback to ("data",) when sharding info is absent is safe.

  • MoE next_token_loss eval path: moe/train.py:155 calls model.next_token_loss(...) without return_router_metrics, which defaults to False. This correctly avoids the router stats computation path during eval. No issue.

No other correctness issues found. Test coverage is solid -- the variant contract tests verify abstract mesh lowering and train-step shape parity across discovered variants, and test_grugformer_moe.py covers EP/non-EP/abstract mesh paths, activation dispatch, shape validation, and capacity overflow reporting.
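
As a small illustration of the sharding.py fix mentioned above (a sketch only; the assert simply shows the two specs are not equivalent):

```python
from jax.sharding import PartitionSpec as P

replicated = P(None)    # a single, explicitly replicated dimension
nested = P((None,))     # first entry is the tuple (None,) -- a tuple-of-tuple spec, not "replicated"

assert replicated != nested
```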

dlwh added a commit that referenced this pull request Mar 2, 2026
## Context

Grug code is deliberately template-y/copypasta, which makes reviewing
diffs a bit tricky.

As a quick workaround, we made a tool to diff the nearest neighbor grug
variant

See a slightly faked version of the workflow in action here: #3046 (comment)

## Summary
- add a one-page HTML visual directory diff tool for Grug/template-style
code copies
- document usage in `experiments/grug/README.md`
- add CI helper to detect newly added Grug variants, find closest
existing variant, and generate diff reports
- add a PR workflow that publishes rendered reports to `gh-pages` and
comments links on the PR

## Testing
- `uvx ruff check scripts/grug_dir_diff.py
scripts/grug_variant_diff_ci.py tests/test_grug_dir_diff.py
tests/test_grug_variant_diff_ci.py`
- `uv run --with pytest --with pytest-timeout python -m pytest -q
tests/test_grug_dir_diff.py tests/test_grug_variant_diff_ci.py`
dlwh added a commit that referenced this pull request Mar 2, 2026
## Summary

This PR splits Grug variant contract work into its own change so it can
land before MoE implementation changes ( #3046 )

## Included
- Add `tests/test_grug_variant_contracts.py` as the contract suite for
Grug variants.
- Remove legacy `tests/test_grug_base_template.py`.
- Update `experiments/grug/base` to satisfy the contract surface:
  - `Transformer.next_token_loss(...)`
  - `debug_mesh_and_token_pspec(num_devices)`
  - `initial_state(...)` in `train.py`
- Update Grug docs to point to the new contract test path and document
the contract.

## Validation
- `uv run --with pytest-timeout pytest
tests/test_grug_variant_contracts.py -q`
- `uv run python infra/pre-commit.py --all-files`
@dlwh changed the title from "Add canonical Grug MoE module, template variant, and contracts" to "Add canonical Grug MoE module and template variant" on Mar 3, 2026
@dlwh (Member, Author) commented Mar 3, 2026

@claude one last time?

@claude bot (Contributor) commented Mar 3, 2026

Claude finished @dlwh's task in 2m 1s.


🤖

  • Read changed files and gather context
  • Review grug_moe.py (functional MoE kernels)
  • Review experiments/grug/moe/ (template variant)
  • Review tests (test_grugformer_moe.py)
  • Review remaining changes (activation.py, loss.py, sharding.py, etc.)
  • Post final review

Review: Grug MoE Module and Template Variant (Final Pass)

Adds a functional MoE kernel (moe_mlp) in levanter with argsort-grouped dispatch and ring-style EP, a self-contained MoE grug template variant under experiments/grug/moe/, and supporting changes to loss, sharding, and activation handling. The code addresses all items raised in prior reviews.

No correctness issues found. This looks good to merge.

Changes verified:

  • grug_moe.py is now purely functional kernels (moe_mlp, _prepare_moe_dispatch, _moe_mlp_local, _moe_mlp_ep_ring_local). Routing stays in the caller. _report_capacity_overflow dead parameter was removed from inner functions per prior feedback.
  • loss.py dynamically infers batch axis spec from hidden.sharding via _batch_axis_spec, correctly handling both ("data",) and ("data", "expert") batch axes for the shard_map and psum reduction. The jax.typeof(x) fallback covers tracing contexts. Debug print statements removed.
  • sharding.py fixes a real bug: P((None,)) → P(None) (the former creates PartitionSpec((None,),) — a single-element tuple partition, not replicated).
  • activation.py adds relu2 to both TO_FN and TO_JAX_FN dicts and the enum, consistent with existing entries.
  • Loss function naming standardized to next_token_loss across both base and MoE variants.
  • model.py MoEMLP keeps shared expert weights replicated with an explicit comment explaining the decision.
  • Capacity overflow reporting returns a scalar dropped-assignment count via report_capacity_overflow=True, with correct psum over ("data", "expert") to produce a replicated global count.

Test coverage: Dense mesh, EP mesh, abstract mesh lowering, activation enum vs callable equivalence, and capacity overflow reporting. Sufficient for the kernel surface.



Comment thread experiments/grug/moe/model.py Outdated
return mesh, P(("data", "expert"), None)


GrugMoeModelConfig = GrugModelConfig
Contributor:
I can't think of a scenario yet where I'll be importing under the name GrugMoeModelConfig.

levanter.tracker.log({"throughput/loading_time": iterator.this_load_time}, step=step)
router_metrics = {key: value for key, value in metrics.items() if key.startswith("train/router/")}
if router_metrics:
    levanter.tracker.log(router_metrics, step=step)
Contributor:
Like where the logging for MoE metrics lives now!

combine_weights = jax.nn.softmax(topk_logits, axis=-1).astype(x.dtype)
router_stats = _routing_stats_from_selected_experts(selected_experts, num_experts=self.cfg.num_experts)

routed_flat = moe_mlp(
Contributor:
Really like having this be the interface level. At some point I think we can extend this to support non-gated MLPs.

Comment thread experiments/grug/moe/model.py Outdated


@named_call
def shared_dense_mlp(
Contributor:
Conceptually it might be simpler if MLP() is its own eqx class with up_gate/down, and then MoEMLP has an attribute called shared_mlp of type MLP. This shared variable for gate/up may affect how Muon orthogonalization behaves on the matrix. Here it is easy for the user to modify, but for the moe_mlp() function in levanter/grug we might want to think about how to enable specification of matrix slicing for orthogonalization, or something of that sort.
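
A minimal sketch of the structure being suggested; field names are borrowed from the variant's model.py, while the gate/up split order and shapes are assumptions:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class MLP(eqx.Module):
    """Standalone gated MLP; shapes assumed: w_up_gate is (D, 2*I), w_down is (I, D)."""
    w_up_gate: jax.Array
    w_down: jax.Array

    def __call__(self, x: jax.Array) -> jax.Array:
        gate, up = jnp.split(x @ self.w_up_gate, 2, axis=-1)  # gate/up ordering assumed
        return (jax.nn.silu(gate) * up) @ self.w_down

class MoEMLP(eqx.Module):
    """MoE layer that owns its shared dense expert as a sub-module."""
    router: jax.Array        # (D, E)
    w_up_gate: jax.Array     # (E, D, 2*I)
    w_down: jax.Array        # (E, I, D)
    shared_mlp: MLP          # the shared_mlp attribute suggested above
```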

@ClassicLarry (Contributor) commented:

Approving, no reason to block merge on these comments.

@dlwh enabled auto-merge (squash) on March 3, 2026, 20:45
@dlwh disabled auto-merge on March 3, 2026, 20:49
@dlwh merged commit 9181fd7 into main on Mar 3, 2026
14 of 15 checks passed
@dlwh deleted the codex/grug-moe-ep-sensible-defaults branch on March 3, 2026, 20:49
@dlwh mentioned this pull request on Mar 3, 2026
ruili33 pushed a commit that referenced this pull request Mar 25, 2026 (the Grug variant diff tooling; same message as the Mar 2 commit above)
ruili33 pushed a commit that referenced this pull request Mar 25, 2026 (the variant contract split; same message as the Mar 2 commit above)
ruili33 pushed a commit that referenced this pull request Mar 25, 2026 (the squashed MoE change; same message as this PR's description)