[CuTe DSL] Fix FP8 MLA persistent perf regression and ProxyKind cu13 wheel breakage #3132
pgera wants to merge 5 commits into flashinfer-ai:main
Conversation
Code Review
This pull request refactors the attention MMA roles to use compile-time unrolled loops (cutlass.range_constexpr) instead of dynamic loops, aiming to improve throughput for the tcgen05 MMA dispatch. To prevent SSA value leakage across persistent loop boundaries during unrolling, the implementation now uses local TiledMma instances created within GEMM helper methods. Additionally, the PR introduces set_dtypes methods to provide necessary type information at call time and simplifies fence_proxy calls. I have no feedback to provide.
/bot run

CI is green.
perf(cute_dsl/mla): unblock MMA-warp constant folding for FP8 MLA decode (persistent)
The FP8 MLA decode persistent path (`is_var_seq=False`) was running ~4 %
slower than trtllm-gen on B200 (geomean 0.96x across the standard
`seq_len=8192, num_heads=128` configs). The cause was two patterns in the
FP8 MMA role that prevented the cute-DSL JIT from emitting an
unrolled MMA dispatch loop with compile-time-constant ACCUMULATE bits.
1. Inner k-block loops in `_gemm_qk_latent_one_stage`,
`_gemm_qk_rope_one_stage`, and `_gemm_pv_one_stage` had been switched
from `cutlass.range_constexpr` to `cutlass.range` to work around an
SSA-dominance failure: when the helper's unrolled `tiled_mma.set()`
chain inlined into the persistent `while` loop, the chain rooted in
the caller's TiledMma SSA value and the loop's back-edge couldn't
pick a dominating value to thread back. The runtime-loop workaround
compiles but loses tcgen05 dispatch throughput.
Each helper now constructs its own local `TiledMma` via
`sm100_utils.make_trivial_tiled_mma(...)` and mutates that local
instance only. The unrolled chain dies inside the helper frame and
never escapes, so `cutlass.range_constexpr` can be restored. This is the
same pattern the compute role already uses, for the same reason.
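The shape of this fix can be sketched as follows. This is a non-runnable sketch, not the actual kernel code: the helper signature, `num_k_blocks`, the `ACCUMULATE` field spelling, and the `cute.gemm` call details are illustrative assumptions; only `sm100_utils.make_trivial_tiled_mma`, `cutlass.range`, and `cutlass.range_constexpr` come from this PR.

```python
# Before: the helper mutates the caller's tiled_mma, so the unrolled
# .set() chain roots in the persistent while-loop's SSA carry, forcing
# the cutlass.range (runtime loop) workaround.
def _gemm_pv_one_stage(self, tiled_mma, acc, tPrP, tPrV, accumulate):
    for k in cutlass.range(num_k_blocks):  # runtime loop (workaround)
        tiled_mma.set(ACCUMULATE, accumulate or k > 0)
        cute.gemm(tiled_mma, acc, tPrP[k], tPrV[k], acc)

# After: a fresh local TiledMma is built inside the helper, so the
# unrolled .set() chain dies in the helper frame and range_constexpr
# is safe again.
def _gemm_pv_one_stage(self, acc, tPrP, tPrV, accumulate):
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        self.q_dtype, self.v_dtype, ...)  # dtypes plumbed via set_dtypes()
    for k in cutlass.range_constexpr(num_k_blocks):  # compile-time unrolled
        tiled_mma.set(ACCUMULATE, accumulate or k > 0)
        cute.gemm(tiled_mma, acc, tPrP[k], tPrV[k], acc)
```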
2. PV per-tile "first iteration?" state was tracked in a Python `bool`
(`pv_accumulated`) reassigned inside the dynamic k_tile loop. Cute-DSL
demotes such a variable to a runtime `i1` carried through the loop,
making the `accumulate or p_stage > 0` expression a runtime OR and
forcing every PV MMA to compute its ACCUMULATE bit at runtime.
Replace with `tiled_mma_pv.set(ACCUMULATE, …)` / `.get(ACCUMULATE)` so
the flag is encoded in the TiledMma's type-side metadata. The DSL
propagates that as a per-iteration compile-time constant, the OR
collapses, and each unrolled MMA gets its ACCUMULATE bit baked into
the opcode. This mirrors what BF16 `mla_mma.py` already does.
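The state-tracking change looks roughly like this (sketch only; the loop bounds and the exact `Field` enum spelling are assumptions, and elided bodies are marked `...`):

```python
# Before: a Python bool reassigned inside the dynamic k_tile loop is
# demoted to a runtime i1 loop carry, so every PV MMA computes its
# ACCUMULATE bit at runtime.
pv_accumulated = False
for k_tile in cutlass.range(num_k_tiles):
    for p_stage in cutlass.range_constexpr(num_p_stages):
        accumulate = pv_accumulated or p_stage > 0  # runtime OR
        ...
    pv_accumulated = True

# After: the flag lives in the TiledMma's type-side metadata, which the
# DSL propagates as a per-iteration compile-time constant, so the OR
# collapses and ACCUMULATE is baked into each unrolled opcode.
tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False)
for k_tile in cutlass.range(num_k_tiles):
    for p_stage in cutlass.range_constexpr(num_p_stages):
        accumulate = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) or p_stage > 0
        ...
    tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True)
```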
`mla_decode_fp8.py` adds `self.mma_role.set_dtypes(self.q_dtype, self.v_dtype)`
so the helpers know the operand types when constructing the local TiledMma.
Verified at the IR/SASS level on B200, FP8 MLA decode persistent kernel:
cubin 156 296 B → 126 648 B (-19 %)
scf.for (k-block runtime loops) gone
scf.while carry: drops one i1 (`pv_accumulated`) and 11 i32 helpers,
gains one !mma_*_128x256x32_ (PV TiledMma)
UTCQMMA SASS instructions 103 → 52 (-50 %)
BSSY/BSYNC pairs 88 → 6 (-93 %)
VOTEU 241 → 1 (-99 %)
WARPSYNC 67 → 2 (-97 %)
Perf — Table 1 (FP8 fixed-len, persistent), B200, median of 3 runs,
CUPTI timing, CUDA graph, cold L2:
Config        cute_base  cute_fix  sp_base  sp_fix  cute Δ%
(cute_* = cute-dsl kernel time; sp_* = speedup vs trtllm-gen before/after the fix)
B= 1, q=1 0.0158 0.0158 1.00x 0.99x +0.1 %
B= 32, q=1 0.0547 0.0527 0.96x 0.99x -3.7 %
B= 64, q=1 0.0795 0.0738 0.99x 1.06x -7.2 %
B=128, q=1 0.1494 0.1392 0.97x 1.04x -6.8 %
B=256, q=1 0.3020 0.2786 0.94x 1.03x -7.8 %
B= 1, q=4 0.0189 0.0181 1.00x 1.03x -3.9 %
B= 32, q=4 0.1447 0.1272 0.91x 1.04x -12.1 %
B= 64, q=4 0.2810 0.2521 0.93x 1.04x -10.3 %
B=128, q=4 0.5069 0.4691 0.95x 1.03x -7.5 %
B=256, q=4 1.0456 0.9553 0.95x 1.03x -8.6 %
geomean 0.958x 1.028x -6.8 %
cute-dsl now beats trtllm-gen on 8 / 10 configs (was 0 / 10).
Tables 2 (BF16 fixed-len), 3 (FP8 var-seq) and 4 (BF16 var-seq) re-run
as a sanity check — all unchanged within run-to-run noise.
Tests: `tests/attention/test_cute_dsl_mla_decode.py` — 297 passed,
including all 24 FP8-specific cases.
Made-with: Cursor
refactor(cute_dsl/mla): isolate MMA-warp TiledMma mutations in BF16/FP16 helpers

Apply the same fresh-local-TiledMma + range_constexpr pattern to the
BF16/FP16 MLA decode MMA role that the FP8 variant just got in the
previous commit.

Each GEMM helper (`_gemm_qk_latent_one_stage`, `_gemm_qk_rope_one_stage`,
`_gemm_pv_one_stage`) now constructs its own local `TiledMma` via
`sm100_utils.make_trivial_tiled_mma(...)` and mutates that local instance
inside a `cutlass.range_constexpr` k-block loop. The helper's parameter
TiledMma is no longer touched, so the unrolled `.set(ACCUMULATE, ...)`
chain dies inside the helper frame and never escapes into the persistent
tile-scheduler `while` loop's SSA carry. This removes the workaround the
PR introducing the modular kernel had to use (`cutlass.range` instead of
`range_constexpr`) and the misleading "SSA leak" warning that came with it.

`mla_decode.py` adds `self.mma_role.set_dtypes(self.q_dtype, self.v_dtype)`
so the helpers know the operand types when constructing the local TiledMma
(same plumbing the FP8 variant uses).

This is a pure cleanup — for BF16/FP16 the SASS is bit-identical before
and after on B200 (the back-end was already collapsing the `scf.for`
loops on its own for this dtype), so wall-clock perf doesn't move:

Table 2 (BF16 fixed-len, persistent), B200, median of 3 runs:
  cute geomean Δ = +0.12 % (within run-to-run noise)
  speedup vs trtllm-gen: 1.045x → 1.048x (effectively unchanged)
  cubin 158 088 B → 157 968 B (-120 B)
  SASS: 10 176 lines → 10 176 lines (identical)
  UTCHMMA / BSSY / VOTEU / WARPSYNC / IMAD counts: identical

The IR-level signal is real though — six runtime `scf.for` loops carrying
TiledMma values through the inner k-block loops are gone (40 → 34 total
`scf.for` ops in the kernel), matching what we saw in the FP8 fix. The
BF16 back-end happens to optimize the buggy form to the same SASS, but
keeping FP8 and BF16 helpers structurally identical means a future
cute-DSL update that changes those back-end heuristics won't silently
regress one variant and not the other.
Tests: `tests/attention/test_cute_dsl_mla_decode.py` — 297 passed. Made-with: Cursor
refactor(cute_dsl/fmha): isolate MMA-warp TiledMma mutations in prefill helpers

Apply the same fresh-local-TiledMma + range_constexpr pattern that the
FP8 and FP16/BF16 MLA decode kernels just got to the FMHA prefill MMA
role, plus the matching `pv_tiled_mma.set/get(ACCUMULATE)` carry that
replaces a Python `pv_whether_acc` bool reassigned inside the dynamic
kv loop.

In `roles/mma.py`:

* `gemm_qk` and `gemm_pv` now construct local TiledMma instances via
  `sm100_utils.make_trivial_tiled_mma(...)` and mutate those locals only.
  Their inner kphase loops switch back from `cutlass.range(unroll_full=True)`
  to `cutlass.range_constexpr` so the unrolled `.set(ACCUMULATE, ...)`
  chain dies inside the helper's frame. The caller's TiledMma is no
  longer touched, so the unused `tiled_mma` parameter is dropped from
  both helpers.
* The dead `qk_tiled_mma` is also dropped from `MmaRole.run()` itself
  (it was only forwarded to `gemm_qk`). `pv_tiled_mma` stays because
  `run()` now uses it for the per-tile `set(ACCUMULATE, False)` /
  `get(ACCUMULATE)` / `set(ACCUMULATE, True)` triplet around the kv
  loop, replacing the `pv_whether_acc` Python bool. This mirrors the
  BF16 MLA `mla_mma.py` pattern and lets the cute-DSL JIT propagate the
  ACCUMULATE bit as a per-iteration compile-time constant.
* A new `set_dtypes(q_dtype, v_dtype, q_major_mode, k_major_mode,
  v_major_mode)` method gives the helpers the metadata they need to call
  `make_trivial_tiled_mma`. `prefill.py` calls it right after
  constructing the role.

In `roles/mla_mma.py` and `roles/mla_mma_fp8.py`:

* Drop the unused `tiled_mma_qk` / `tiled_mma_pv` parameters from the
  three GEMM helpers (`_gemm_qk_latent_one_stage`,
  `_gemm_qk_rope_one_stage`, `_gemm_pv_one_stage`). `run()` keeps these
  names — they're still needed there for fragment creation and, for the
  PV variant, the ACCUMULATE carry.

For BF16 prefill the wall-clock perf is unchanged within run-to-run
noise — same outcome as the BF16 MLA refactor.
The IR is cleaner (inner runtime k-block `scf.for` ops drop from 8 to 1,
total `scf.for` from 39 to 31), but the Blackwell back-end was already
collapsing them to identical SASS, so the cubin only shrinks by 128 B
and the per-config TFLOPS measurements are flat (geomean −0.4 %, well
within noise on the single-run bench):

benchmarks/bench_blackwell_attention_cutedsl.py (BF16 causal):

Config (B, S)   base TFLOPS  fix TFLOPS    Δ
(128,   512)       238.358     237.193   -0.5 %
( 64,  1024)       348.229     346.865   -0.4 %
( 32,  2048)       544.895     540.885   -0.7 %
( 16,  4096)       721.215     715.552   -0.8 %
(  8,  8192)       898.808     896.179   -0.3 %
(  4, 16384)      1000.544     999.480   -0.1 %
(  2, 32768)      1065.185    1059.602   -0.5 %
(  1, 65536)      1103.561    1102.442   -0.1 %

So this lands as a defensive cleanup rather than a perf win for the BF16
prefill path: it aligns the prefill MMA role with the MLA roles' modern
pattern, removes the misleading "use cutlass.range to dodge SSA leaks"
workaround comments, kills three sets of dead `tiled_mma` parameters,
and protects against future cute-DSL back-end heuristic changes that
could regress the buggy form (as happened with FP8 MLA).

Tests: `tests/attention/test_modular_fmha_prefill.py` 159 passed /
20 skipped; `tests/attention/test_cute_dsl_mla_decode.py` 297 passed.

Made-with: Cursor
fix(cute_dsl/attention): use string literals for fence_proxy (cu13 wheel compat)

Fixes flashinfer-ai#3071.

`cute.arch.fence_proxy` is documented as taking string literals
(`"async.shared"`, `"cta"`); the wrapper does `ProxyKind.from_str` /
`SharedSpace.from_str` internally. The enum re-exports
`cute.arch.ProxyKind` and `cute.arch.SharedSpace` are gated upstream
behind `cutlass_dsl.target_version(exact_version="12.9")` and so are
absent from the cu13 wheel that the `flashinfer-ci-cu130` docker image
uses, producing:

  AttributeError: module 'cutlass.cute.arch' has no attribute 'ProxyKind'

at correction.py:267 and softmax.py:448. On cu12 wheels the enum form
still works but emits `DeprecationWarning: Passing enum member directly
to SharedSpace.from_str() is deprecated. Please use string literals
instead` (visible in our local test runs).

Switching the two call sites to the string form fixes both the cu13
breakage and the cu12 deprecation warning, and is the documented stable
API on every flashinfer-supported cutlass-dsl 4.4.x wheel.

Verified:

  python -m pytest tests/attention/test_modular_fmha_prefill.py \
      tests/attention/test_cute_dsl_mla_decode.py
  → 456 passed, 20 skipped (no SharedSpace/ProxyKind warnings emitted).

Made-with: Cursor
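In sketch form, the call-site change is as follows. The string-literal form is taken directly from this PR's summary; the exact enum-member spellings shown in the old form are an assumption about the pre-fix code and may not match it verbatim.

```python
# Before (cu12-only, deprecated): enum members are gated upstream behind
# cutlass_dsl.target_version("12.9") and absent on the cu13 wheel, so
# this raises AttributeError there.
cute.arch.fence_proxy(
    cute.arch.ProxyKind.async_shared,      # assumed enum spelling
    space=cute.arch.SharedSpace.shared_cta,  # assumed enum spelling
)

# After (documented string-literal API, works on cu12.9 and cu13 wheels;
# the wrapper applies ProxyKind.from_str / SharedSpace.from_str itself):
cute.arch.fence_proxy("async.shared", space="cta")
```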
Force-pushed from e15ae4e to 16574c1.
Summary
Two functional fixes plus two defensive refactors on the modular CuTe DSL attention kernels added in #2805:
* perf(cute_dsl/mla) — recover a ~6.4 % geomean perf regression on FP8 MLA decode persistent (`is_var_seq=False`) versus trtllm-gen on B200, by unblocking the cute-DSL JIT's constant folding around the MMA dispatch loop. Cute-DSL now beats trtllm-gen on 9/10 of the standard `seq_len=8192, num_heads=128` configs (was 0/10).
* fix(cute_dsl/attention) — closes #3071: the cu13 cutlass-dsl wheel doesn't re-export `cute.arch.ProxyKind` / `cute.arch.SharedSpace` (gated upstream behind `cutlass_dsl.target_version("12.9")`), causing `AttributeError` at correction.py:267 and softmax.py:448 on the CI's `flashinfer-ci-cu130` image. Switching to the documented string-literal API (`fence_proxy("async.shared", space="cta")`) works on both cu12.9 and cu13 wheels and silences the existing `Passing enum member directly to SharedSpace.from_str() is deprecated` warning on cu12.9.
* refactor(cute_dsl/mla) + refactor(cute_dsl/fmha) — apply the same fresh-local-TiledMma + `range_constexpr` pattern that the FP8 perf fix introduced to the BF16/FP16 MLA decode and FMHA prefill MMA roles. Wall-clock perf is unchanged for these (the back-end was already collapsing the buggy form to identical SASS for these dtypes/shapes), but the IR is cleaner and the helpers' API is consistent across all three MMA roles, protecting against future cute-DSL back-end heuristic changes that could regress one variant the way FP8 did.

Commits
* ba59eb61 perf(cute_dsl/mla): unblock MMA-warp constant folding for FP8 MLA decode (persistent)
* 4e66703c refactor(cute_dsl/mla): isolate MMA-warp TiledMma mutations in BF16/FP16 helpers
* 6e67879d refactor(cute_dsl/fmha): isolate MMA-warp TiledMma mutations in prefill helpers
* e15ae4ea fix(cute_dsl/attention): use string literals for fence_proxy (cu13 wheel compat)

Root cause analysis (FP8 perf regression)
PR #2805 introduced two patterns in the FP8 MMA role that prevented the cute-DSL JIT from emitting an unrolled MMA dispatch loop with compile-time-constant `ACCUMULATE` bits:

1. The inner k-block loops in `_gemm_qk_latent_one_stage`, `_gemm_qk_rope_one_stage`, and `_gemm_pv_one_stage` had been switched from `cutlass.range_constexpr` to `cutlass.range`. The original PR's comment explained this as a workaround for an SSA-dominance failure: when the helper's unrolled `tiled_mma.set(ACCUMULATE, ...)` chain inlined into the persistent `while` loop, the chain rooted in the caller's `TiledMma` SSA value and the loop's back-edge couldn't pick a dominating value to thread back. The runtime-loop workaround compiles but loses tcgen05 dispatch throughput. Fixed by having each helper construct its own local `TiledMma` via `sm100_utils.make_trivial_tiled_mma(...)` and mutate that local instance only — the unrolled chain now dies inside the helper frame and `range_constexpr` is safe again. Same pattern the compute role already uses.

2. PV per-tile "first-iteration overwrite, then accumulate" state was tracked in a Python `bool` (`pv_accumulated`) reassigned inside the dynamic k_tile loop. The cute-DSL JIT demoted that bool to a runtime `i1` carried through the loop, making `accumulate or p_stage > 0` a runtime OR and forcing every PV MMA to compute its `ACCUMULATE` bit at runtime. Fixed by storing the flag on `tiled_mma_pv.set(ACCUMULATE, ...)` / `.get(ACCUMULATE)` so the field becomes type-side metadata that the JIT propagates as a per-iteration compile-time constant — same pattern the BF16 `mla_mma.py` already used.

Verification
FP8 MLA persistent perf (Table 1 from PR #2743's reproducer, B200, median of 3 runs, CUPTI + CUDA graph + cold L2)
Geomean speedup vs trtllm-gen flipped from 0.958x → 1.029x. cute-dsl now beats trtllm-gen on 9/10 configs.
IR / SASS evidence (FP8 MLA persistent kernel):

  UTCQMMA (FP8 MMA dispatch)            103 → 52  (-50 %)
  BSSY/BSYNC (structured-sync regions)   88 → 6   (-93 %)
  VOTEU (warp predicate votes)          241 → 1   (-99 %)
  WARPSYNC                               67 → 2   (-97 %)
  scf.while outer carry: gains one TiledMma-typed value; drops one i1
  (the Python bool) and 11 i32 helpers

Other paths (defensive refactors — unchanged perf, same as predicted)
BF16 prefill (benchmarks/bench_blackwell_attention_cutedsl.py): cute geomean Δ −0.31 % (noise)

ProxyKind / cu13 wheel (verified on both wheel variants)
The enum re-exports are gated upstream behind `target_version("12.9")`: `cute.arch.ProxyKind` is present on the cu12.9 wheel (`libs-base==4.4.2`) but absent on the cu13 wheel (`libs-base + libs-cu13==4.4.2`); the string-literal form works on both.
Test plan

* `tests/attention/test_cute_dsl_mla_decode.py` — 297 passed (cu12.9), includes 24 FP8-specific cases
* `tests/attention/test_modular_fmha_prefill.py` — 159 passed, 20 skipped (cu12.9)
* `bench_pr2743_reproduce.py` (Tables 1-4, 3 runs each for FP8 fixed-len) — see numbers above
* `benchmarks/bench_blackwell_attention_cutedsl.py` — perf flat as expected

Made with Cursor