Add MoeAdamHHeuristic, drop dense layers, fix align_kv_heads sharding #4636
ClassicLarry wants to merge 5 commits into main.
- `experiments/grug/moe/heuristic.py`: port `CompletedAdamHHeuristic` from the `moe_isoflop_apr_2026` branch, rename to `MoeAdamHHeuristic`. Adds `compute_flops_per_token`, `compute_tokens_and_batch`, and `build_from_heuristic` helpers so `launch.py` can derive `(model, optimizer, batch, steps)` from a compute budget + hidden_dim.
- `experiments/grug/moe/launch.py`: use `build_from_heuristic` for the baseline `ExecutorStep` instead of hardcoding a `GrugModelConfig` and `GrugMoeAdamHConfig`. Manual specification is still supported by passing `GrugModelConfig` / `GrugMoeAdamHConfig` directly to `GrugMoeLaunchConfig`.
- `experiments/grug/moe/model.py`: remove the initial-dense-layer path (`_NUM_DENSE_LAYERS`, `dense_intermediate_dim`, `Block.dense_mlp`, the `dense_only` init branch) to match the isoflop architecture. Default EP capacity factor back to 1.0. Add the reshard on `aligned_v` that the isoflop branch had.
- `experiments/grug/moe/README.md`: current best recipe, scaling heuristic summary, v16 isoflop best runs per budget, projections (L∞=1.6), and promotion criteria.
- `lib/levanter/src/levanter/grug/attention.py`: replace `jnp.repeat` in `align_kv_heads` with a reshape+broadcast pattern. Fixes "Please pass sharding to jnp.repeat via out_sharding parameter" under abstract mesh contexts.
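The reshape+broadcast replacement for `jnp.repeat` can be sketched as follows. This is a minimal sketch: the real `align_kv_heads` in `lib/levanter/src/levanter/grug/attention.py` may have a different signature, and the `[batch, seq, heads, head_dim]` layout here is an assumption.

```python
import jax.numpy as jnp


def align_kv_heads(num_q_heads: int, kv: jnp.ndarray) -> jnp.ndarray:
    """Expand KV heads to match the query head count (GQA) without jnp.repeat.

    Assumes kv has shape [batch, seq, num_kv_heads, head_dim] (illustrative
    layout, not necessarily the repo's).
    """
    batch, seq, num_kv_heads, head_dim = kv.shape
    group = num_q_heads // num_kv_heads
    # Insert a singleton group axis, broadcast it, then fold it into the head
    # axis. Unlike jnp.repeat, broadcast_to + reshape does not require an
    # out_sharding argument under abstract mesh contexts.
    expanded = jnp.broadcast_to(
        kv[:, :, :, None, :], (batch, seq, num_kv_heads, group, head_dim)
    )
    return expanded.reshape(batch, seq, num_kv_heads * group, head_dim)
```

This produces the same head ordering as `jnp.repeat(kv, group, axis=2)`: each KV head appears `group` times consecutively.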
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c333125126
    model_cfg = h.build_model_config(hidden_dim)
    fpt = compute_flops_per_token(model_cfg)
Propagate seq_len through heuristic model construction
build_from_heuristic accepts a seq_len override, but it builds the model with h.build_model_config(hidden_dim) and then computes FLOPs from that config, so FLOPs are still based on the hardcoded 4096 sequence length. If a caller passes a non-default seq_len, the returned (batch_size, num_steps) is computed with the override while model shape/FLOP estimation remains at 4096, causing the compute budget calculation to drift and mis-size experiments.
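One way to address this is to thread the override through model construction so the FLOP estimate and the batch/steps sizing always see the same sequence length. The sketch below uses toy stand-ins (`ToyModelConfig` and a made-up FLOP formula are illustrative, not the actual `GrugModelConfig` / `MoeAdamHHeuristic` API):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyModelConfig:
    # Stand-in for the repo's model config; fields are assumptions.
    hidden_dim: int
    seq_len: int


def compute_flops_per_token(cfg: ToyModelConfig) -> float:
    # Toy cost model: an MLP term plus an attention term that scales with
    # seq_len. The real formula lives in experiments/grug/moe/heuristic.py.
    return 6.0 * cfg.hidden_dim**2 + 12.0 * cfg.hidden_dim * cfg.seq_len


def build_from_heuristic(budget: float, hidden_dim: int, seq_len: int = 4096):
    # The fix: build the model config FROM the seq_len override and compute
    # FLOPs from that same config, so the compute-budget sizing cannot drift
    # from the model shape used for estimation.
    model_cfg = ToyModelConfig(hidden_dim=hidden_dim, seq_len=seq_len)
    flops_per_token = compute_flops_per_token(model_cfg)
    total_tokens = budget / flops_per_token
    return model_cfg, total_tokens
```

With this shape, passing `seq_len=2048` both shrinks the per-token FLOP estimate and records the override on the returned config, instead of silently sizing against the hardcoded 4096.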
Layer counts in the v16 isoflop table were wrong (d768: 10→8, d1536: 14→16, d2048: 18→21). Verified against wandb run configs. Also switch URL encoding from %2B to literal + so the links resolve correctly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
- Port `CompletedAdamHHeuristic` from the `moe_isoflop_apr_2026` branch into `experiments/grug/moe/heuristic.py` as `MoeAdamHHeuristic`. Adds `compute_flops_per_token`, `compute_tokens_and_batch`, and `build_from_heuristic` so `launch.py` can derive `(model, optimizer, batch, steps)` from a compute budget + hidden_dim.
- Remove the initial-dense-layer path in `experiments/grug/moe/model.py` (`_NUM_DENSE_LAYERS`, `dense_intermediate_dim`, `Block.dense_mlp`, the `dense_only` init branch) to match the isoflop MoE architecture. Default EP capacity factor back to 1.0. Restore the `aligned_v` reshard the isoflop branch had.
- Update `experiments/grug/moe/launch.py` to use `build_from_heuristic` for the baseline `ExecutorStep` instead of hardcoding `GrugModelConfig` / `GrugMoeAdamHConfig`. Manual specification is still supported by passing configs directly.
- Add `experiments/grug/moe/README.md` covering architecture, the scaling heuristic, v16 isoflop best runs per compute budget, projections (Paloma macro L∞ pinned at 1.6), and promotion criteria.
- `lib/levanter/src/levanter/grug/attention.py::align_kv_heads`: replace `jnp.repeat` with a reshape+broadcast pattern. Resolves the `ValueError: Please pass sharding to jnp.repeat via out_sharding parameter` crash under abstract-mesh training contexts.

Test plan
- Matched the `isoflop-moe-v16-1e+18-d1024` run to within ±0.005 at matching steps on a v5p-8 iris run (`4_10_test_moe`).
- `build_from_heuristic(budget=1e18, hidden_dim=1024)` reproduces the isoflop d1024 hyperparameters exactly (lr=0.01, adam_lr=0.002308, beta2=0.999, bs=32, steps=5622).
- Pre-commit checks (`./infra/pre-commit.py`) passing.
- Verified the `align_kv_heads` reshape is equivalent to the old `jnp.repeat` under GQA.

🤖 Generated with Claude Code
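The GQA equivalence check from the test plan can be reproduced standalone. Shapes here are toy values for illustration; the actual check in the repo may use different dimensions.

```python
import jax.numpy as jnp
import numpy as np

# Toy GQA shapes: 2 KV heads expanded to 8 query heads (group size 4).
batch, seq, num_kv, head_dim, group = 2, 5, 2, 4, 4
kv = jnp.arange(batch * seq * num_kv * head_dim, dtype=jnp.float32).reshape(
    batch, seq, num_kv, head_dim
)

old = jnp.repeat(kv, group, axis=2)  # the pattern the PR replaces
new = jnp.broadcast_to(
    kv[:, :, :, None, :], (batch, seq, num_kv, group, head_dim)
).reshape(batch, seq, num_kv * group, head_dim)

# The two patterns agree element-for-element under GQA head expansion.
assert np.array_equal(np.asarray(old), np.asarray(new))
```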