
[grug] Drive 90B-5.3BA MoE MFU on cw-us-east-02a from 3% to ~25% #6304

@rjpower

Description

The 90B-total/5.3B-active Grug MoE from the cw-us-east-02a bringup (#6292, PR #6293) trains at 218K tok/s on 256 H100, which is 2.8% active-FLOPs MFU (218K tok/s × 6·5.3e9 FLOPs/token ÷ (256 × 989 TFLOP/s dense bf16 peak)). This issue drives it toward ~25% (≈2.0M tok/s) and is the running log of every experiment along the way. It complements the H100 MoE MFU epic #4302.
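A minimal sketch of the MFU arithmetic above, using the same 6·N_active FLOPs/token approximation (attention FLOPs ignored) and the 989 TFLOP/s dense bf16 figure. The helper names are illustrative and not part of the repo.

```python
# Active-FLOPs MFU for the 90B-5.3BA Grug MoE, using the 6 * N_active
# FLOPs/token approximation (attention FLOPs are ignored).
ACTIVE_PARAMS = 5.3e9              # active parameters per token
FLOPS_PER_TOKEN = 6 * ACTIVE_PARAMS
NUM_GPUS = 256
PEAK_FLOPS_PER_GPU = 989e12        # H100 dense bf16 peak, FLOP/s


def active_mfu(tokens_per_sec: float) -> float:
    """Fraction of aggregate peak bf16 FLOP/s spent on active-parameter math."""
    return tokens_per_sec * FLOPS_PER_TOKEN / (NUM_GPUS * PEAK_FLOPS_PER_GPU)


def tokens_per_sec_for(mfu: float) -> float:
    """Throughput needed to reach a target active MFU at this shape."""
    return mfu * NUM_GPUS * PEAK_FLOPS_PER_GPU / FLOPS_PER_TOKEN


baseline = active_mfu(218e3)
target = tokens_per_sec_for(0.25)
print(f"baseline MFU: {baseline:.2%}")
print(f"tokens/s for 25% MFU: {target / 1e6:.2f}M")
```

With these rounded inputs the baseline comes out at about 2.7%, so the quoted 2.8% presumably reflects unrounded throughput or parameter counts; the 25% target lands at about 1.99M tok/s.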

Diagnosis (repo mining + external research)

Structurally communication-bound. At global batch 256×2048, each GPU processes 2,048 tokens per step (~66 ms of ideal bf16 compute), while FSDP/ZeRO-3 over the 32-node IB axis gathers ~45–130 GB of parameters per GPU per step (90B total, fp32 master weights) over 400 Gb/s (50 GB/s) per GPU: roughly 1–2.6 s of communication against 0.07 s of compute. The observed 2.4 s step matches this arithmetic. Secondary suspects: reference attention on GPU (with attention_implementation unset, the reference path materializes fp32 logits in each of the 48 layers, twice under remat; FA4 is blocked by #6226), full-block remat (+20%, #5328), and the XLA fused-CE fallback (#5510).
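The comm-vs-compute arithmetic can be reproduced as a back-of-envelope script. It assumes compute at dense bf16 peak and parameter gathers at the full 400 Gb/s line rate with no overlap, and it takes the 45–130 GB/GPU/step gather range as given rather than deriving it; gradient reduce-scatters are not modeled separately.

```python
# Back-of-envelope step time for the FSDP-over-IB diagnosis.
GLOBAL_BATCH, SEQ_LEN, NUM_GPUS = 256, 2048, 256
FLOPS_PER_TOKEN = 6 * 5.3e9            # 6 * active params
PEAK_FLOPS_PER_GPU = 989e12            # H100 dense bf16 peak, FLOP/s
IB_GBYTES_PER_SEC = 400 / 8            # 400 Gb/s per GPU -> 50 GB/s

# Tokens each GPU processes per step, and the ideal time to do the math.
tokens_per_gpu = GLOBAL_BATCH * SEQ_LEN // NUM_GPUS
compute_s = tokens_per_gpu * FLOPS_PER_TOKEN / PEAK_FLOPS_PER_GPU

# Per-GPU parameter bytes gathered per step, range taken from the issue (GB).
comm_s = [gb / IB_GBYTES_PER_SEC for gb in (45, 130)]

# Step time implied by the bringup throughput of 218K tok/s.
observed_step_s = GLOBAL_BATCH * SEQ_LEN / 218e3

print(f"tokens/GPU/step: {tokens_per_gpu}")
print(f"ideal compute:   {compute_s * 1e3:.0f} ms")
print(f"param comm:      {comm_s[0]:.1f}-{comm_s[1]:.1f} s")
print(f"observed step:   {observed_step_s:.1f} s")
```

The gather range maps to roughly 0.9 to 2.6 s of communication against about 66 ms of compute, and 256×2048 tokens at 218K tok/s gives the ~2.4 s observed step.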

Batch/seq scaling, the single biggest lever, is gated on the replicated-reshard pathology: XLA GSPMD reshards the [batch, seq, hidden] activation through a fully replicated intermediate at the FSDP boundary (a 58 GiB tile OOMs at 512×4096; see launch_cw_scale.py). Nothing in the repo has tried Shardy yet.
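To see why a replicated intermediate is fatal at larger shapes, compare the per-GPU footprint of one bf16 [batch, seq, hidden] activation (hidden = 3072 from the baseline config) when it is sharded over all 256 GPUs versus fully replicated. This is an illustrative sizing only: the 58 GiB tile reported in launch_cw_scale.py comes from the compiled program and is not derived here.

```python
# Per-GPU bytes for one [batch, seq, hidden] activation: sharded vs replicated.
HIDDEN = 3072        # d3072 from the baseline config
NUM_GPUS = 256


def activation_gib(batch: int, seq: int, bytes_per_elem: int, shards: int) -> float:
    """GiB held by one GPU when the activation is split evenly over `shards` devices."""
    return batch * seq * HIDDEN * bytes_per_elem / shards / 2**30


for batch, seq in [(256, 2048), (512, 4096)]:
    sharded = activation_gib(batch, seq, 2, NUM_GPUS)   # bf16, batch-sharded
    replicated = activation_gib(batch, seq, 2, 1)       # bf16, fully replicated
    print(f"{batch}x{seq}: sharded {sharded * 1024:.0f} MiB/GPU, "
          f"replicated {replicated:.1f} GiB/GPU")
```

The replicated copy grows with the global batch (3 GiB at 256×2048 and 12 GiB at 512×4096 in bf16, per buffer), while the sharded copy stays in the tens of MiB, so any fully replicated step in the reshard path scales memory with the whole batch rather than the per-device share.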

Plan (ranked; profile first, change one thing at a time)

  1. Baseline profile with the in-tree harness (ProfilerConfig + lib/marin/tools/profile_summary.py, per .agents/skills/profile-training/): attribute step time to comm / attention / dispatch / GMM / optimizer, following the method of #4302 (H100 x 8 MOE MFU perf).
  2. Kill the replicated reshard: try Shardy (jax_use_shardy_partitioner; the default partitioner since JAX 0.7.1, and it targets exactly this case), else restructure the _batch_reshard placement (the pattern from PR #6069, "[Grug] Reshard MoE block before MLP") or fold it into the EP backend.
  3. Scale tokens/device 8–16× once unblocked (MaxText H100 configs run 8K tok/dev; OLMoE sustained ~19% MFU at 1.3B-active).
  4. XLA flag baseline + PGLE (latency-hiding scheduler, pipelined/combined collectives, cuBLAS GEMM, double buffering; the MaxText A3 + JAX GPU perf-tips set). Note that TPU flag sweeps were a negative result (#3399, "Experiment: MaxText XLA flags for Grug MoE on v5p-64"); GPU is untested.
  5. bf16 param gathers (fp32 master weights may currently double FSDP bytes; verify in the HLO whether the fp32→bf16 convert runs before the all-gather, so that bf16 is what crosses the wire).
  6. FA4 attention on the CoreWeave image (#6226, "[gpu] Make FA4 CuTe and THD work on CoreWeave": install cutlass/flash-attn-4, set attention_implementation).
  7. EP topology + dispatch: hybrid ZeRO (shard only the 87B of expert params across nodes), wider EP with resident experts (DeepSeek-V3 topology), and assigned-token/DeepEP dispatch (#6215, "[grug] Replace ring EP global buffers with assigned-token dispatch"; PR #6251, "[Grug] Add assigned-token DeepEP MoE dispatch"). Caution: naive ragged_all_to_all was 1.71× slower than ring (#6139, "[grug] Close B200 d5120/L8 MoE transport gap after #5815").
  8. Remat policy using the existing _CHECKPOINT_* names instead of a bare eqx.filter_checkpoint (~+21.8% measured in #5328, "Epic: make local Grug MoE GMM/MLP fast via SonicMoE-style kernels"); fused-CE block-size fix (#5510, "Pallas fused cross-entropy falls back to XLA on CoreWeave H100 Grug MoE canary"); fp8/TE last.
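Items 3 to 5 can be compared with a toy roofline. This is a sketch under strong assumptions: compute runs at dense bf16 peak, parameter gathers run at the 400 Gb/s line rate, gathered bytes per step do not depend on tokens/device (one microbatch, no gradient accumulation), and switching to bf16 gathers exactly halves the bytes (which item 5 still has to confirm in HLO). Its output is the fraction of the step spent in compute, an upper bound on active MFU rather than a prediction; `mfu_upper_bound` is a hypothetical helper.

```python
# Toy roofline for plan items 3-5 (tokens/device, overlap, gather dtype).
FLOPS_PER_TOKEN = 6 * 5.3e9        # 6 * active params (5.3B)
PEAK_FLOPS = 989e12                # H100 dense bf16 peak, FLOP/s
LINK_BYTES_PER_SEC = 400e9 / 8     # 400 Gb/s per GPU


def mfu_upper_bound(tokens_per_gpu: int, gathered_bytes: float, overlap: bool) -> float:
    """Fraction of the step spent computing, assuming peak FLOPs and line-rate comm.

    Gathered parameter bytes are fixed per step, so more tokens per device
    amortize the same communication over more compute.
    """
    compute = tokens_per_gpu * FLOPS_PER_TOKEN / PEAK_FLOPS
    comm = gathered_bytes / LINK_BYTES_PER_SEC
    # Perfect overlap hides the shorter phase; otherwise the phases serialize.
    step = max(compute, comm) if overlap else compute + comm
    return compute / step


FP32_GATHER_BYTES = 130e9                  # upper end of the issue's per-GPU estimate
BF16_GATHER_BYTES = FP32_GATHER_BYTES / 2  # assumption: bf16 copies cross the wire

for tokens in (2048, 16384, 32768):        # today, then 8x and 16x tokens/device
    serial_fp32 = mfu_upper_bound(tokens, FP32_GATHER_BYTES, overlap=False)
    overlapped_bf16 = mfu_upper_bound(tokens, BF16_GATHER_BYTES, overlap=True)
    print(f"{tokens:>6} tok/GPU: fp32 serial {serial_fp32:.1%}, "
          f"bf16 overlapped {overlapped_bf16:.1%}")
```

Under this model, scaling tokens/device raises the ceiling because the same gather is amortized over more tokens, while overlap (item 4) buys little until compute is a sizable fraction of communication time.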

Definition of Done

≥25% active-FLOPs MFU on 256 H100 (≈2.0M tok/s) at the 90B-5.3BA shape, or a documented ceiling with the remaining gap attributed (kernel vs comm vs geometry), with every experiment logged below.


Running log

Baseline (from bringup, 2026-06-09): d3072 L48 E128 top-4, batch 256 × seq 2048, mesh (data=32, expert=8), fp32 params / bf16 compute, ring EP, reference attention, full-block remat, no XLA flags. 218K tok/s = 2.8% MFU, step ~2.4 s. Canary (9.5B replicated per node, no cross-node FSDP) hit 1.39M tok/s on identical hardware — consistent with the comm-bound diagnosis.
