The 90B-total/5.3B-active Grug MoE from the cw-us-east-02a bringup (#6292, PR #6293) trains at 218K tok/s on 256 H100: 2.8% active-FLOPs MFU (218K tok/s × 6·5.3e9 FLOPs/token ÷ (256 × 989 TFLOP/s)). This issue drives it toward ~25% (≈2.0M tok/s) and is the running log of every experiment along the way. Complements the H100 MoE MFU epic #4302.
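The MFU figure and the throughput target can be reproduced with a few lines of arithmetic. The sketch below assumes the usual 6 FLOPs per active parameter per token and the 989 TFLOP/s dense bf16 peak quoted above; the function names are illustrative and not part of the repo. With the rounded 218K tok/s it lands near 2.7%, which matches the quoted 2.8% to within rounding of the inputs.

```python
# Active-FLOPs MFU: achieved training FLOP/s over aggregate peak FLOP/s.
ACTIVE_PARAMS = 5.3e9         # active parameters per token (from the issue)
NUM_GPUS = 256                # H100 count (from the issue)
PEAK_FLOPS_PER_GPU = 989e12   # dense bf16 peak per H100 used in the issue


def active_flops_mfu(tokens_per_sec: float) -> float:
    """Fraction of aggregate peak used, counting 6 FLOPs per active param per token."""
    achieved = tokens_per_sec * 6 * ACTIVE_PARAMS
    return achieved / (NUM_GPUS * PEAK_FLOPS_PER_GPU)


def tokens_per_sec_for_mfu(target_mfu: float) -> float:
    """Throughput needed to reach a given active-FLOPs MFU."""
    return target_mfu * NUM_GPUS * PEAK_FLOPS_PER_GPU / (6 * ACTIVE_PARAMS)


print(f"baseline MFU: {active_flops_mfu(218e3):.2%}")
print(f"25% MFU needs: {tokens_per_sec_for_mfu(0.25) / 1e6:.2f}M tok/s")
```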
Diagnosis (repo mining + external research)
Structurally communication-bound. At global batch 256×2048, each GPU computes 2,048 tokens/step (~66 ms ideal bf16) while FSDP/ZeRO-3 over the 32-node IB axis gathers ~45–130 GB/GPU/step of parameters (90B total, fp32 master weights) against 400 Gb/s/GPU — 1–2.6 s of comm vs 0.07 s of compute. The observed 2.4 s step matches this arithmetic. Secondary suspects: reference attention on GPU (attention_implementation unset → materialized fp32 logits ×48 layers ×2 with remat; FA4 blocked by #6226), full-block remat (+20%, #5328), XLA fused-CE fallback (#5510).
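The comm-versus-compute arithmetic above can be sketched as follows. This is a back-of-envelope model, not a measurement: it assumes the link runs at its full 400 Gb/s line rate, takes the 45–130 GB gather range from the paragraph above as given, and assumes gathered parameter bytes do not grow with batch size. The helper names are hypothetical.

```python
ACTIVE_PARAMS = 5.3e9             # active parameters per token (from the issue)
PEAK_FLOPS_PER_GPU = 989e12       # dense bf16 peak per H100 used in the issue
LINK_GBITS_PER_SEC = 400          # per-GPU InfiniBand bandwidth, gigabits/s
GATHER_GB_RANGE = (45.0, 130.0)   # gathered parameter gigabytes per GPU per step


def ideal_compute_seconds(tokens_per_gpu: int) -> float:
    """Step compute time at 100% MFU for one GPU's share of the batch."""
    return tokens_per_gpu * 6 * ACTIVE_PARAMS / PEAK_FLOPS_PER_GPU


def gather_seconds(gigabytes: float) -> float:
    """Time to move the given gigabytes at line rate (gigabits / 8 = gigabytes)."""
    return gigabytes / (LINK_GBITS_PER_SEC / 8)


comm_lo, comm_hi = (gather_seconds(gb) for gb in GATHER_GB_RANGE)

# Assumption: parameter-gather bytes stay fixed while tokens per GPU scale up.
for scale in (1, 8, 16):
    tokens = 2048 * scale
    print(f"{tokens:>6} tok/GPU: ideal compute {ideal_compute_seconds(tokens):.3f} s, "
          f"param gather {comm_lo:.1f}-{comm_hi:.1f} s")
```

Under these assumptions compute only becomes comparable to the gather time at roughly 8–16× the current tokens per GPU.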
Batch/seq scaling — the single biggest lever — is gated on the replicated-reshard pathology: XLA GSPMD reshards the [batch,seq,hidden] activation through a fully-replicated intermediate at the FSDP boundary (58 GiB tile OOM at 512×4096; see launch_cw_scale.py). Nothing in-repo has tried Shardy yet.
Plan (ranked; profile first, change one thing at a time)
Baseline profile with the in-tree harness (ProfilerConfig + lib/marin/tools/profile_summary.py, per .agents/skills/profile-training/): split the step across comm / attention / dispatch / GMM / optimizer, following the method of #4302 (H100 x 8 MOE MFU perf).
Kill the replicated reshard: try Shardy (jax_use_shardy_partitioner; default partitioner since JAX 0.7.1, targets exactly this), else restructure _batch_reshard placement (following PR #6069, [Grug] Reshard MoE block before MLP) or fold it into the EP backend.
Scale tokens/device 8–16× once unblocked (MaxText H100 configs run 8K tok/dev; OLMoE sustained ~19% MFU at 1.3B-active).
XLA flag baseline + PGLE (latency-hiding scheduler, pipelined/combined collectives, cuBLAS GEMM, double buffering; the MaxText A3 + JAX GPU perf-tips set). Note that TPU flag sweeps were a negative result (#3399, Experiment: MaxText XLA flags for Grug MoE on v5p-64); GPU is untested.
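bf16 param gathers (fp32 master weights currently risk 2× FSDP bytes; verify in HLO whether the convert sinks below the all-gather).
[...] attention_implementation).
[...] _CHECKPOINT_* names instead of bare eqx.filter_checkpoint (~+21.8% measured, #5328, Epic: make local Grug MoE GMM/MLP fast via SonicMoE-style kernels); fused-CE block-size fix (#5510, Pallas fused cross-entropy falls back to XLA on CoreWeave H100 Grug MoE canary); fp8/TE last.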
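For the XLA flag baseline item, one way to assemble a candidate flag set is sketched below. The flag names follow the JAX GPU performance tips and MaxText GPU configs as far as I know them; the threshold values are placeholders rather than tuned settings, and `build_xla_flags` is a hypothetical helper. Flag availability varies across XLA versions, and XLA aborts at startup on flags it does not recognize, so each one should be checked against the pinned jaxlib.

```python
import os

# Candidate flags for the GPU baseline; values are illustrative placeholders.
CANDIDATE_XLA_FLAGS = {
    "xla_gpu_enable_latency_hiding_scheduler": "true",
    "xla_gpu_enable_triton_gemm": "false",            # prefer cuBLAS GEMMs
    "xla_gpu_enable_while_loop_double_buffering": "true",
    "xla_gpu_enable_pipelined_all_gather": "true",
    "xla_gpu_enable_pipelined_reduce_scatter": "true",
    "xla_gpu_all_gather_combine_threshold_bytes": str(128 * 1024 * 1024),
    "xla_gpu_reduce_scatter_combine_threshold_bytes": str(64 * 1024 * 1024),
}


def build_xla_flags(flags: dict, existing: str = "") -> str:
    """Render flags as '--name=value' and append them to any existing XLA_FLAGS."""
    rendered = " ".join(f"--{name}={value}" for name, value in flags.items())
    return f"{existing} {rendered}".strip()


# Must run before JAX initializes its backend, or the flags are ignored.
os.environ["XLA_FLAGS"] = build_xla_flags(
    CANDIDATE_XLA_FLAGS, os.environ.get("XLA_FLAGS", "")
)
print(os.environ["XLA_FLAGS"])
```

The Shardy experiment is toggled separately, after import, with `jax.config.update("jax_use_shardy_partitioner", True)`. In keeping with the one-change-at-a-time rule, each flag would be enabled and measured on its own rather than as a bundle.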
Definition of Done
≥25% active-FLOPs MFU on 256 H100 (≈2.0M tok/s) at the 90B-5.3BA shape, or a documented ceiling with the remaining gap attributed (kernel vs comm vs geometry), with every experiment logged below.
Running log
Baseline (from bringup, 2026-06-09): d3072 L48 E128 top-4, batch 256 × seq 2048, mesh (data=32, expert=8), fp32 params / bf16 compute, ring EP, reference attention, full-block remat, no XLA flags. 218K tok/s = 2.8% MFU, step ~2.4 s. Canary (9.5B replicated per node, no cross-node FSDP) hit 1.39M tok/s on identical hardware — consistent with the comm-bound diagnosis.
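The baseline numbers are internally consistent, which a small check makes explicit. The `Baseline` dataclass below is a hypothetical container for the figures in the log entry above, not a config class from the repo; it confirms that the mesh covers all 256 GPUs and that tokens per step divided by the ~2.4 s step time gives roughly 218K tok/s.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Baseline:
    """Baseline figures copied from the running-log entry (hypothetical container)."""
    global_batch: int = 256
    seq_len: int = 2048
    data_axis: int = 32
    expert_axis: int = 8
    step_seconds: float = 2.4

    @property
    def num_gpus(self) -> int:
        return self.data_axis * self.expert_axis

    @property
    def tokens_per_step(self) -> int:
        return self.global_batch * self.seq_len

    @property
    def tokens_per_sec(self) -> float:
        return self.tokens_per_step / self.step_seconds


baseline = Baseline()
print(f"GPUs: {baseline.num_gpus}")
print(f"tokens/step: {baseline.tokens_per_step}")
print(f"throughput: {baseline.tokens_per_sec / 1e3:.0f}K tok/s")
```

New log entries can reuse the same structure with `dataclasses.replace(baseline, ...)` so that each experiment's throughput is derived from its own batch, sequence length, and step time.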