
Add MoeAdamHHeuristic, drop dense layers, fix align_kv_heads sharding #4636

Open
ClassicLarry wants to merge 5 commits into main from grug_moe_heuristic

Conversation

@ClassicLarry
Contributor

Summary

  • Port CompletedAdamHHeuristic from the moe_isoflop_apr_2026 branch into experiments/grug/moe/heuristic.py as MoeAdamHHeuristic. Adds compute_flops_per_token, compute_tokens_and_batch, and build_from_heuristic so launch.py can derive (model, optimizer, batch, steps) from a compute budget + hidden_dim.
  • Remove the initial-dense-layer path from experiments/grug/moe/model.py (_NUM_DENSE_LAYERS, dense_intermediate_dim, Block.dense_mlp, the dense_only init branch) to match the isoflop MoE architecture. Restore the default EP capacity factor to 1.0 and the aligned_v reshard that the isoflop branch had.
  • Rewire experiments/grug/moe/launch.py to use build_from_heuristic for the baseline ExecutorStep instead of hardcoding GrugModelConfig / GrugMoeAdamHConfig. Manual specification is still supported by passing configs directly.
  • Add experiments/grug/moe/README.md covering architecture, the scaling heuristic, v16 isoflop best runs per compute budget, projections (Paloma macro L∞ pinned at 1.6), and promotion criteria.
  • Fix lib/levanter/src/levanter/grug/attention.py::align_kv_heads: replace jnp.repeat with a reshape+broadcast pattern. Resolves the ValueError: Please pass sharding to jnp.repeat via out_sharding parameter crash under abstract-mesh training contexts.

Test plan

  • Verified loss curves match the published isoflop-moe-v16-1e+18-d1024 run to within ±0.005 at matching steps on a v5p-8 iris run (4_10_test_moe).
  • build_from_heuristic(budget=1e18, hidden_dim=1024) reproduces the isoflop d1024 hyperparameters exactly (lr=0.01, adam_lr=0.002308, beta2=0.999, bs=32, steps=5622).
  • Pre-commit linter + type check (./infra/pre-commit.py) passing.
  • Reviewer: confirm the align_kv_heads reshape is equivalent to the old jnp.repeat under GQA.
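The budget-to-steps derivation the test plan exercises can be sketched as simple bookkeeping: divide the FLOP budget by a per-token FLOP estimate to get total tokens, then divide by tokens per step. This is a toy sketch under assumed inputs, not the real `compute_tokens_and_batch` (which also chooses the batch size from the budget); the function name and signature here are hypothetical.

```python
# Toy sketch of the compute-budget bookkeeping behind the heuristic.
# Hypothetical signature: the real helper lives in
# experiments/grug/moe/heuristic.py and differs in detail.
def compute_tokens_and_steps(budget_flops, flops_per_token, batch_size, seq_len):
    # Total tokens the budget can afford at this model size.
    total_tokens = budget_flops / flops_per_token
    # Tokens consumed per optimizer step.
    tokens_per_step = batch_size * seq_len
    num_steps = round(total_tokens / tokens_per_step)
    return total_tokens, num_steps
```

With a 1e18 FLOP budget and ~1e9 FLOPs/token, a bs=32, seq_len=4096 run gets roughly 7.6k steps, the same order as the d1024 run cited above.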

🤖 Generated with Claude Code

- experiments/grug/moe/heuristic.py: port CompletedAdamHHeuristic from the
  moe_isoflop_apr_2026 branch, rename to MoeAdamHHeuristic. Adds
  compute_flops_per_token, compute_tokens_and_batch, and build_from_heuristic
  helpers so launch.py can derive (model, optimizer, batch, steps) from a
  compute budget + hidden_dim.
- experiments/grug/moe/launch.py: use build_from_heuristic for the baseline
  ExecutorStep instead of hardcoding a GrugModelConfig and GrugMoeAdamHConfig.
  Manual specification is still supported by passing GrugModelConfig /
  GrugMoeAdamHConfig directly to GrugMoeLaunchConfig.
- experiments/grug/moe/model.py: remove the initial-dense-layer path
  (_NUM_DENSE_LAYERS, dense_intermediate_dim, Block.dense_mlp, dense_only
  init branch) to match the isoflop architecture. Default EP capacity factor
  back to 1.0. Add the reshard on aligned_v that the isoflop branch had.
- experiments/grug/moe/README.md: current best recipe, scaling heuristic
  summary, v16 isoflop best runs per budget, projections (L∞=1.6), and
  promotion criteria.
- lib/levanter/src/levanter/grug/attention.py: replace jnp.repeat in
  align_kv_heads with a reshape+broadcast pattern. Fixes
  "Please pass sharding to jnp.repeat via out_sharding parameter" under
  abstract mesh contexts.
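The reshape+broadcast replacement for jnp.repeat can be sketched as below. Shapes and the standalone function are illustrative (the real align_kv_heads lives in lib/levanter/src/levanter/grug/attention.py); numpy is used here for a self-contained demo, but jnp.broadcast_to/reshape behave the same way, and unlike jnp.repeat they do not require an out_sharding under abstract-mesh contexts.

```python
import numpy as np

def align_kv_heads(kv, num_q_heads):
    """Replicate each KV head across its query-head group (GQA) without repeat.

    kv: [batch, seq, num_kv_heads, head_dim]; returns [batch, seq, num_q_heads, head_dim].
    """
    b, s, num_kv_heads, d = kv.shape
    group = num_q_heads // num_kv_heads
    # Insert a group axis, broadcast across it, then fold it into the head
    # axis. Head ordering matches repeat along axis 2: [h0, h0, ..., h1, h1, ...].
    kv = kv[:, :, :, None, :]                              # [b, s, kv, 1, d]
    kv = np.broadcast_to(kv, (b, s, num_kv_heads, group, d))
    return kv.reshape(b, s, num_kv_heads * group, d)
```

The reviewer question above reduces to checking `align_kv_heads(kv, q) == np.repeat(kv, q // kv_heads, axis=2)` elementwise, which holds because the folded group axis is contiguous per KV head.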
@ClassicLarry ClassicLarry added the agent-generated Created by automation/agent label Apr 10, 2026

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c333125126


Comment on lines +259 to +260
model_cfg = h.build_model_config(hidden_dim)
fpt = compute_flops_per_token(model_cfg)

P2: Propagate seq_len through heuristic model construction

build_from_heuristic accepts a seq_len override, but it builds the model with h.build_model_config(hidden_dim) and then computes FLOPs from that config, so FLOPs are still based on the hardcoded 4096 sequence length. If a caller passes a non-default seq_len, the returned (batch_size, num_steps) is computed with the override while model shape/FLOP estimation remains at 4096, causing the compute budget calculation to drift and mis-size experiments.
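The suggested fix is to thread the seq_len override into model construction so FLOP estimation and batch/step sizing use the same sequence length. The sketch below is a self-contained toy under assumed names (`ToyModelConfig`, the quadratic-plus-attention FLOP estimate, and the signatures are all hypothetical); the real `build_model_config` and `compute_flops_per_token` differ.

```python
from dataclasses import dataclass

DEFAULT_SEQ_LEN = 4096

@dataclass
class ToyModelConfig:
    hidden_dim: int
    seq_len: int

def build_model_config(hidden_dim, seq_len=DEFAULT_SEQ_LEN):
    return ToyModelConfig(hidden_dim, seq_len)

def compute_flops_per_token(cfg):
    # Toy estimate: a parameter term plus an attention term that grows
    # with seq_len, standing in for the real FLOP count.
    n_params = 12 * cfg.hidden_dim ** 2
    return 6 * n_params + 12 * cfg.hidden_dim * cfg.seq_len

def build_from_heuristic(budget, hidden_dim, seq_len=DEFAULT_SEQ_LEN):
    # The fix: pass seq_len through instead of building the model config
    # from hidden_dim alone, so the FLOP estimate tracks the override.
    cfg = build_model_config(hidden_dim, seq_len=seq_len)
    fpt = compute_flops_per_token(cfg)
    total_tokens = budget / fpt
    return cfg, total_tokens
```

With the override threaded through, a longer seq_len raises FLOPs/token and correspondingly shrinks the token budget, instead of silently sizing the run at 4096.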


ClassicLarry and others added 4 commits April 10, 2026 13:45
Layer counts in the v16 isoflop table were wrong (d768: 10→8, d1536: 14→16,
d2048: 18→21). Verified against wandb run configs. Also switch URL encoding
from %2B to literal + so the links resolve correctly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
