
Grug Demo, small scale feature maxed moe #3237

Closed
ClassicLarry wants to merge 12 commits into main from larrydial/grug_max_model_flop_mini

Conversation

@ClassicLarry
Contributor

Demoing the grug workflow with a feature-maxed MoE small-scale variant.

eval/paloma/c4_en/bpb: 1.1136 @ 5000 steps, ~9.14e17 model FLOPs.
MoE with pick 2 of 16 routed experts.
https://wandb.ai/marin-community/dial_moe/runs/updated_max_02_embed.

@github-actions
Contributor

github-actions Bot commented Mar 4, 2026

🤖 Grug variant diff report

| New Variant | Closest Existing Variant | Distance Score | Diff |
| --- | --- | --- | --- |
| max_model_flop_mini | moe | 471 | Open |

Artifact fallback: Download report bundle

@dlwh
Member

dlwh commented Mar 4, 2026

@claude thoughts?

@claude
Contributor

claude Bot commented Mar 4, 2026

Claude finished @dlwh's task in 1m 55s.


Review of PR #3237: Grug Demo, small scale feature maxed MoE

  • Read changed files
  • Review grug variant files (model, train, launch)
  • Review grugmuon optimizer changes
  • Post review feedback

Overall this looks like a solid grug variant following the copy-first workflow. The architecture innovations are well-documented in the docstrings, the results are linked, and the grugmuon change to support 3D expert weights is clean. A few observations:

lib/levanter/src/levanter/optim/grugmuon.py — 3D Muon support

This is the most impactful change since it affects the shared optimizer used by all grug variants. The refactoring is well done:

  1. _newtonschulz_core extraction — Pulling the pure math into a vmap-safe function and keeping resharding in the caller (_zeropower_via_newtonschulz for 2D, inline reshard + vmap for 3D) is the right decomposition (a rough sketch follows this list).

  2. No reshard-back for 3D — The old _zeropower_via_newtonschulz resharded back to orig_sharding after the Newton-Schulz iteration; the new code drops that step for both the 3D path and the 2D path (it just returns from _newtonschulz_core). This looks intentional and correct, since the downstream scale multiplication and optax.apply_updates don't depend on a specific sharding, but it is a subtle behavioral change for existing 2D callers worth noting.

  3. No tests for grugmuon — There are no unit tests for the optimizer itself (the contract test exercises optax.adam, not GrugMuonConfig). The 3D vmap path in particular would benefit from a test, even a simple filter_eval_shape check confirming the 3D branch lowers correctly. Not a blocker for this PR, but worth tracking.
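For concreteness, here is a rough sketch of the decomposition described in (1): the pure Newton-Schulz math lives in a vmap-safe core, the 2D caller replicates and invokes it directly, and the 3D path reshards and vmaps over the leading expert dimension. Function names follow the review; the bodies, the quintic coefficients, and the use of reshard under an explicit mesh are illustrative assumptions, not the actual grugmuon code.

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, reshard

def _newtonschulz_core(X, steps=5, eps=1e-7):
    # Pure math only, no resharding, so it is safe to vmap over a leading expert axis.
    # Coefficients are the commonly used Muon quintic; the real code may differ.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = X / (jnp.linalg.norm(X) + eps)
    transpose = X.shape[0] > X.shape[1]
    if transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transpose else X

def orthogonalize_update(G):
    if G.ndim == 2:
        # 2D path (_zeropower_via_newtonschulz in the review): replicate, run the core,
        # and let downstream ops determine the output sharding.
        return _newtonschulz_core(reshard(G, P(None, None)))
    # 3D expert weights: replicate as in the PR's inline path, then vmap the core over
    # the expert dimension (the review comments below suggest keeping that dim sharded).
    G = reshard(G, P(None, None, None))
    return jax.vmap(_newtonschulz_core)(G)
```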

experiments/grug/max_model_flop_mini/model.py

Clean, well-documented architecture. A few notes:

  • RMSNorm.init takes dim but doesn't use it (line 235-236) — The parameter-free RMSNorm ignores dim entirely. This is fine for the copy-first workflow (keeps the init signature compatible with other variants), but a comment noting this would help future readers.

  • value_lambda / ve_lambda initialized to 0.5 scalars (lines 181-182) — These are learnable scalars initialized to 0.5, so the VE contribution starts at 50%. Combined with the sigmoid gate (which starts at 2 * sigmoid(0) = 1.0), the initial VE contribution is 0.5 * 1.0 * ve_heads and the original value contribution is 0.5 * v. This seems intentional for a balanced start (see the sketch after this list).

  • ve_gate and attn_gate shape (gate_input_dim, num_heads) — Using only the first gate_input_dim=12 hidden dimensions as gate input is an interesting choice (saves parameters). The x[..., :g] slicing in the forward pass (lines 219, 224) relies on gate_input_dim < hidden_dim, which holds for the defaults but is not explicitly checked in __post_init__.

  • Sliding window every 4th layer (line 512: i % 4 == 3) — The long window fires on layers 3, 7, etc. With num_layers=8 in the trial config, that's layers 3 and 7 getting the full window. Seems reasonable.
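To make the gating arithmetic above concrete, here is a minimal sketch of the gated value-embedding mix; the function name, shapes, and zero-initialized gate weights are inferred from the review rather than copied from the variant, so treat it as an illustration only.

```python
import jax.numpy as jnp
from jax.nn import sigmoid

def mix_value_embedding(x, v, ve, ve_gate, value_lambda, ve_lambda, gate_input_dim=12):
    # x: (..., hidden_dim); v, ve: (..., num_heads, head_dim); ve_gate: (gate_input_dim, num_heads)
    # Only the first gate_input_dim hidden dims drive the gate; with zero-initialized
    # gate weights this starts at 2 * sigmoid(0) = 1.0 for every head.
    gate = 2.0 * sigmoid(x[..., :gate_input_dim] @ ve_gate)  # (..., num_heads)
    # With value_lambda = ve_lambda = 0.5 at init, v and the VE branch each contribute 50%.
    return value_lambda * v + ve_lambda * gate[..., None] * ve
```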

experiments/grug/max_model_flop_mini/train.py

This is a near-identical copy of the base/moe train.py, which is expected for the grug workflow. Nothing surprising here — the next_token_loss method on the model handles the auxiliary loss aggregation cleanly.
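For context, a hypothetical sketch of how a next_token_loss might fold a Switch-style load-balancing term into the main cross-entropy; the signature, coefficient, and formula here are assumptions, not the variant's actual implementation.

```python
import jax.numpy as jnp
import optax

def next_token_loss(logits, targets, router_probs, expert_mask, aux_coef=0.01):
    # logits: (B, T, V); targets: (B, T); router_probs, expert_mask: (B, T, E)
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    # Switch-Transformer-style load balancing: fraction of tokens routed to each expert
    # times the mean router probability for that expert, scaled by num_experts.
    num_experts = router_probs.shape[-1]
    frac_tokens = expert_mask.mean(axis=(0, 1))
    mean_probs = router_probs.mean(axis=(0, 1))
    aux = num_experts * jnp.sum(frac_tokens * mean_probs)
    return ce + aux_coef * aux
```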

experiments/grug/max_model_flop_mini/launch.py

  • rope_theta=1024 in the trial config (line 99) — This is quite low compared to typical values (10000 for Llama, 500000 for some long-context models). With max_seq_len=2048 this should be fine, but it would limit extrapolation to longer sequences. Presumably intentional for this small-scale trial.

  • sliding_window=2048 with max_seq_len=2048 (line 98) — The sliding window equals the sequence length, so the "short" window is 2048 // 2 = 1024 and the "long" window is the full sequence. Sliding-window attention is therefore genuinely in play: short-window layers attend to at most 1024 tokens while the long-window layers see the full context. This makes sense as a template default (see the sketch after this list).

  • weight_decay=0 — No weight decay at all, with Muon handling the 2D/3D parameters. This is a deliberate choice for this trial.
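A tiny sketch of the per-layer window assignment implied by the items above (the i % 4 == 3 rule from model.py plus the halved short window); the layout is an assumption for illustration, not the actual code.

```python
def layer_window_sizes(num_layers: int, sliding_window: int) -> list[int]:
    # Every 4th layer (i % 4 == 3) gets the long/full window, the rest the short one.
    long_w, short_w = sliding_window, sliding_window // 2
    return [long_w if i % 4 == 3 else short_w for i in range(num_layers)]

# Trial config: num_layers=8, sliding_window=2048, max_seq_len=2048
print(layer_window_sizes(8, 2048))  # [1024, 1024, 1024, 2048, 1024, 1024, 1024, 2048]
```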

Contract tests

The existing test_grug_variant_one_step_contract_lowers_with_default_ctor in tests/test_grug_variant_contracts.py should automatically discover this new variant (it walks experiments/grug/*/model.py). The test uses GrugModelConfig(vocab_size=1024) with the default constructor, and the new config's __post_init__ validation should pass with defaults. 👍
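A rough sketch of what such directory-walking discovery typically looks like (pytest parametrization over experiments/grug/*/model.py); the real test in tests/test_grug_variant_contracts.py may be structured differently.

```python
import glob
import importlib.util
import pytest

VARIANT_MODELS = sorted(glob.glob("experiments/grug/*/model.py"))

@pytest.mark.parametrize("model_path", VARIANT_MODELS)
def test_variant_config_defaults_construct(model_path):
    spec = importlib.util.spec_from_file_location("variant_model", model_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    # Each variant is expected to expose GrugModelConfig with defaults that pass
    # its own __post_init__ validation.
    config = module.GrugModelConfig(vocab_size=1024)
    assert config.vocab_size == 1024
```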

Summary

The PR is a clean grug variant addition with a well-motivated optimizer extension. The grugmuon 3D support is the only change touching shared code and looks correct. No blocking issues — the main suggestion would be adding a grugmuon unit test (could be a follow-up).


@dlwh dlwh left a comment (Member)

generally lg but the sharding stuff isn't quite there!

Keeping the experts sharded over the E axis should be relatively easy. I agree we need to replicate for the inner loop

```python
from jax.sharding import PartitionSpec as P, reshard

# Replicate then vmap the Newton-Schulz core over the batch/expert dim
x = reshard(x, P(None, None, None))
```
Member

imho better to not replicate but just keep it vmapped over expert

Member

probably best to do something like "replicate the last two dims but keep the first dim sharded"
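A minimal sketch of what that suggestion might look like, assuming the expert dimension is sharded over a mesh axis here called "expert" and reusing the surrounding function's steps/eps/coefficient_type arguments (both assumptions):

```python
# Replicate only the matrix dims, keep the expert dim sharded, and vmap the core
# over that still-sharded leading dim.
X = reshard(X, P("expert", None, None))
X = jax.vmap(lambda m: _newtonschulz_core(m, steps=steps, eps=eps,
                                          coefficient_type=coefficient_type))(X)
# ...and reshard back to the parameter's original spec before returning.
```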


```python
assert X.ndim == 2
X = reshard(X, P(None, None))
return _newtonschulz_core(X, steps=steps, eps=eps, coefficient_type=coefficient_type)
```
Member

need to restore original sharding

Member

still need to fix this i think

@ClassicLarry
Contributor Author

hmm, tests are failing because they keep changing: 3 times in 2 days (Marin header, resource spec, v4-8 vmem). I think it will be simpler to hold off on chasing tests until the manual review is done, then do one round to sync with the latest tests and merge asap.

@dlwh
Member

dlwh commented Mar 5, 2026

sorry yeah lot of moving targets. i can help merge

@github-actions
Contributor

This pull request has been inactive for 23 days and is marked as stale.
If there is no further activity within 7 days, it will be automatically closed.
If you believe this PR should remain open, please add a comment or update the PR.

@github-actions github-actions Bot added the stale label Mar 31, 2026
@github-actions
Contributor

github-actions Bot commented Apr 7, 2026

This pull request has been automatically closed due to inactivity.
If you would like to continue working on this, please reopen it or create a new PR.

@github-actions github-actions Bot closed this Apr 7, 2026
