
Commit f00ce36

[CuTe DSL] Fix FP8 MLA persistent perf regression and ProxyKind cu13 wheel breakage (#3132)
## Summary

Two functional fixes plus two defensive refactors on the modular CuTe DSL attention kernels added in #2805:

1. **`perf(cute_dsl/mla)`** — recover a ~6.4 % geomean perf regression on FP8 MLA decode persistent (`is_var_seq=False`) versus trtllm-gen on B200, by unblocking the cute-DSL JIT's constant folding around the MMA dispatch loop. Cute-DSL now beats trtllm-gen on 9/10 of the standard `seq_len=8192, num_heads=128` configs (was 0/10).
2. **`fix(cute_dsl/attention)`** — closes [#3071](#3071): the `cu13` cutlass-dsl wheel doesn't re-export `cute.arch.ProxyKind` / `cute.arch.SharedSpace` (gated upstream behind `cutlass_dsl.target_version("12.9")`), causing `AttributeError` at `correction.py:267` and `softmax.py:448` on the CI's `flashinfer-ci-cu130` image. Switching to the documented string-literal API (`fence_proxy("async.shared", space="cta")`) works on both cu12.9 and cu13 wheels and silences the existing `Passing enum member directly to SharedSpace.from_str() is deprecated` warning on cu12.9.
3. **`refactor(cute_dsl/mla)`** + **`refactor(cute_dsl/fmha)`** — apply the same fresh-local-`TiledMma` + `range_constexpr` pattern that the FP8 perf fix introduced to the BF16/FP16 MLA decode and FMHA prefill MMA roles. Wall-clock perf is unchanged for these (the back-end was already collapsing the buggy form to identical SASS for these dtypes/shapes), but the IR is cleaner and the helpers' API is consistent across all three MMA roles, protecting against future cute-DSL back-end heuristic changes that could regress one variant the way FP8 did.

## Commits

| SHA | What |
|---|---|
| `ba59eb61` | `perf(cute_dsl/mla): unblock MMA-warp constant folding for FP8 MLA decode (persistent)` |
| `4e66703c` | `refactor(cute_dsl/mla): isolate MMA-warp TiledMma mutations in BF16/FP16 helpers` |
| `6e67879d` | `refactor(cute_dsl/fmha): isolate MMA-warp TiledMma mutations in prefill helpers` |
| `e15ae4ea` | `fix(cute_dsl/attention): use string literals for fence_proxy (cu13 wheel compat)` |

## Root cause analysis (FP8 perf regression)

PR #2805 introduced two patterns in the FP8 MMA role that prevented the cute-DSL JIT from emitting an unrolled MMA dispatch loop with compile-time-constant `ACCUMULATE` bits (a condensed sketch of the before/after shape follows this list):

1. The inner k-block loops in `_gemm_qk_latent_one_stage`, `_gemm_qk_rope_one_stage`, and `_gemm_pv_one_stage` had been switched from `cutlass.range_constexpr` to `cutlass.range`. The original PR's comment explained this as a workaround for an SSA-dominance failure: when the helper's unrolled `tiled_mma.set(ACCUMULATE, ...)` chain inlined into the persistent `while` loop, the chain rooted in the caller's `TiledMma` SSA value and the loop's back-edge couldn't pick a dominating value to thread back. The runtime-loop workaround compiles but loses tcgen05 dispatch throughput. Fixed by having each helper construct its own local `TiledMma` via `sm100_utils.make_trivial_tiled_mma(...)` and mutate that local instance only — the unrolled chain now dies inside the helper frame and `range_constexpr` is safe again. Same pattern the compute role already uses.
2. PV per-tile "first-iteration overwrite, then accumulate" state was tracked in a Python `bool` (`pv_accumulated`) reassigned inside the dynamic k_tile loop. The cute-DSL JIT demoted that bool to a runtime `i1` carried through the loop, making `accumulate or p_stage > 0` a runtime OR and forcing every PV MMA to compute its `ACCUMULATE` bit at runtime. Fixed by storing the flag on `tiled_mma_pv.set(ACCUMULATE, ...)` / `.get(ACCUMULATE)` so the field becomes type-side metadata that the JIT propagates as a per-iteration compile-time constant — same pattern the BF16 `mla_mma.py` already used.
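To make pattern 1 concrete, here is a condensed before/after sketch. It is illustrative only: `qk_params`, `tSrQ`, `mma_qk_tiler`, `acc_dtype`, and the helper names follow `roles/mla_mma.py` in the diff below, the surrounding pipeline/TMEM plumbing is omitted, and the FP8 role (`mla_mma_fp8.py`) applies the same shape. Pattern 2 is the analogous change on the caller side: the per-tile PV accumulate flag now lives on the PV `TiledMma` via `.set(ACCUMULATE, ...)` / `.get(ACCUMULATE)` instead of being a reassigned Python `bool`.

```python
# Illustrative sketch, not the literal diff: shows only the loop/TiledMma shape
# that changed. Names follow roles/mla_mma.py; everything else is trimmed.
import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils


class _MmaRoleSketch:
    # Before: caller-owned TiledMma + runtime loop. cutlass.range() lowers to a
    # dynamic scf.for, so every ACCUMULATE bit is a runtime value and the .set()
    # chain threads the caller's TiledMma SSA value across the persistent while
    # loop in run().
    @cute.jit
    def _gemm_qk_before(self, qk_params, tiled_mma_qk, s_idx, kv_idx, q_stage, accumulate):
        tStS = qk_params.tStS_staged[None, None, None, s_idx]
        for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])):
            tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
            cute.gemm(
                tiled_mma_qk, tStS,
                qk_params.tSrQ[None, None, k_block, q_stage],
                qk_params.tSrKC[None, None, k_block, kv_idx],
                tStS,
            )

    # After: each helper builds its own TiledMma, so .set() mutations die inside
    # the helper frame and the k-block loop can be unrolled at compile time.
    @cute.jit
    def _make_local_qk_mma(self) -> cute.TiledMma:
        cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
        return sm100_utils.make_trivial_tiled_mma(
            self.q_dtype,
            tcgen05.OperandMajorMode.K,
            tcgen05.OperandMajorMode.K,
            self.acc_dtype,
            cta_group,
            self.mma_qk_tiler[:2],
        )

    @cute.jit
    def _gemm_qk_after(self, qk_params, s_idx, kv_idx, q_stage, accumulate):
        local_mma = self._make_local_qk_mma()  # fresh per call, never escapes
        tStS = qk_params.tStS_staged[None, None, None, s_idx]
        for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])):
            # k_block is a Python int in the unrolled loop, so the ACCUMULATE
            # bit folds to a compile-time constant for every MMA dispatch.
            local_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
            cute.gemm(
                local_mma, tStS,
                qk_params.tSrQ[None, None, k_block, q_stage],
                qk_params.tSrKC[None, None, k_block, kv_idx],
                tStS,
            )
```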
## Verification

### FP8 MLA persistent perf (Table 1 from PR #2743's reproducer, B200, median of 3 runs, CUPTI + CUDA graph + cold L2)

| Config | trt | cute_base | cute_fix | sp_base | sp_fix | cute Δ% |
| :----------- | -----: | --------: | -------: | ------: | -----: | -------: |
| B= 1, q=1 | 0.0157 | 0.0158 | 0.0161 | 1.00x | 0.99x | +1.9 % |
| B= 32, q=1 | 0.0521 | 0.0547 | 0.0522 | 0.96x | 1.00x | −4.7 % |
| B= 64, q=1 | 0.0784 | 0.0795 | 0.0734 | 0.99x | 1.06x | −7.7 % |
| B=128, q=1 | 0.1448 | 0.1494 | 0.1391 | 0.97x | 1.04x | −6.9 % |
| B=256, q=1 | 0.2890 | 0.3020 | 0.2808 | 0.94x | 1.03x | −7.0 % |
| B= 1, q=4 | 0.0190 | 0.0189 | 0.0181 | 1.00x | 1.05x | −3.8 % |
| B= 32, q=4 | 0.1314 | 0.1447 | 0.1257 | 0.91x | 1.05x | −13.2 % |
| B= 64, q=4 | 0.2627 | 0.2810 | 0.2540 | 0.93x | 1.03x | −9.6 % |
| B=128, q=4 | 0.4905 | 0.5069 | 0.4753 | 0.95x | 1.03x | −6.2 % |
| B=256, q=4 | 1.0050 | 1.0456 | 0.9869 | 0.95x | 1.02x | −5.6 % |
| **geomean** | | | | **0.958x** | **1.029x** | **−6.4 %** |

Geomean speedup vs trtllm-gen flipped from **0.958x → 1.029x**. cute-dsl now beats trtllm-gen on 9/10 configs.

### IR / SASS evidence (FP8 MLA persistent kernel)

| Metric | Buggy | Fixed |
| :--- | ---: | ---: |
| cubin size | 156 296 B | **126 648 B (−19 %)** |
| `UTCQMMA` (FP8 MMA dispatch) | 103 | **52 (−50 %)** |
| `BSSY` / `BSYNC` (structured-sync regions) | 88 / 88 | **6 / 6 (−93 %)** |
| `VOTEU` (warp predicate votes) | 241 | **1 (−99 %)** |
| `WARPSYNC` | 67 | **2 (−97 %)** |
| `scf.while` outer carry: TiledMma type | 0 | **1** (PV TiledMma replaces the demoted `i1`) |
| `scf.while` outer carry: `i1` (Python bool) | 2 | **0** |
| `scf.while` outer carry: `i32` | 31 | **20** (intermediate state from inner runtime loops gone) |

### Other paths (defensive refactors — unchanged perf, as predicted)

- BF16 MLA persistent: cute geomean Δ +0.12 % (single-run noise, SASS bit-identical)
- FP8 MLA var-seq: cute geomean Δ +0.52 % (noise)
- BF16 MLA var-seq: cute geomean Δ +0.79 % (noise)
- BF16 FMHA prefill (`benchmarks/bench_blackwell_attention_cutedsl.py`): cute geomean Δ −0.31 % (noise)

### ProxyKind / cu13 wheel (verified on both wheel variants)

| Wheel | `target_version("12.9")` | `cute.arch.ProxyKind` | Pre-fix code | Post-fix code |
|---|:---:|:---:|:---:|:---:|
| **cu12.9** (`libs-base==4.4.2`) | True | exists (with deprecation warning) | passes (warning) | **passes** |
| **cu13** (`libs-base + libs-cu13==4.4.2`) | False | **AttributeError** | **fails** (#3071, exact reproduction) | **passes** |
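For reference, the call-shape difference behind the table above, as a minimal sketch: only the string-literal call is what the `correction.py` diff below actually ships, and in the kernel the call sits inside `@cute.jit` device code rather than a free function as written here.

```python
# Minimal sketch of the fence_proxy portability fix. The enum form is shown as
# a comment because cute.arch.ProxyKind / SharedSpace do not resolve on the
# cu13 wheel; the string-literal form works on both wheel variants.
import cutlass.cute as cute


def fence_async_shared_view() -> None:
    # Pre-fix (cu12.9 only; AttributeError on cu13, deprecation warning on cu12.9):
    #   cute.arch.fence_proxy(
    #       cute.arch.ProxyKind.async_shared,
    #       space=cute.arch.SharedSpace.shared_cta,
    #   )
    # Post-fix, documented string-literal API:
    cute.arch.fence_proxy("async.shared", space="cta")
```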
## Test plan

- [x] `tests/attention/test_cute_dsl_mla_decode.py` — 297 passed (cu12.9), includes 24 FP8-specific cases
- [x] `tests/attention/test_modular_fmha_prefill.py` — 159 passed, 20 skipped (cu12.9)
- [x] Same test suites on cu13 wheel — 456 passed, 20 skipped (closes #3071 cleanly)
- [x] `bench_pr2743_reproduce.py` (Tables 1-4, 3 runs each for FP8 fixed-len) — see numbers above
- [x] `benchmarks/bench_blackwell_attention_cutedsl.py` — perf flat as expected
- [x] Pre-commit (mypy, ruff, etc.) — passes for all 4 commits

Made with [Cursor](https://cursor.com)

## Summary by CodeRabbit

* **Refactor**
  * Attention path now explicitly configures operand data types at runtime for more consistent numeric behavior (including FP8 paths).
  * Matrix-multiply helpers now build local compute instances, reducing external coupling and simplifying orchestration.
  * PV accumulation logic streamlined for clearer accumulate semantics.
* **Bug Fix**
  * Tightened post-shared-memory synchronization to improve stability and determinism.
1 parent fb3bb44 commit f00ce36

8 files changed

Lines changed: 290 additions & 92 deletions


flashinfer/cute_dsl/attention/mla_decode.py

Lines changed: 1 addition & 0 deletions
@@ -188,6 +188,7 @@ def _reinterpret_3d_kv(t):
         self.pt_loader_role = MLAPageTableLoaderRole(self.config)
         self.loader_role = MLALoaderRole(self.config)
         self.mma_role = MLAMmaRole(self.config, self.mainloop)
+        self.mma_role.set_dtypes(self.q_dtype, self.v_dtype)
         self.compute_role = MLAComputeRole(self.config, fusion=self.fusion)
         self.compute_role.set_dtypes(self.q_dtype)
         self.compute_role.set_barriers(self.softmax_exchange_sync_bar)

flashinfer/cute_dsl/attention/mla_decode_fp8.py

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ def _reinterpret_3d_kv(t):
         self.loader_k_role = MLAFP8LoaderKRole(self.config)
         self.loader_v_role = MLAFP8LoaderVRole(self.config)
         self.mma_role = MLAMmaFP8Role(self.config, self.mainloop)
+        self.mma_role.set_dtypes(self.q_dtype, self.v_dtype)
         self.compute_role = MLAComputeRole(self.config, fusion=self.fusion)
         self.compute_role.set_dtypes(self.q_dtype)
         self.compute_role.set_barriers(self.softmax_exchange_sync_bar)

flashinfer/cute_dsl/attention/prefill.py

Lines changed: 7 additions & 1 deletion
@@ -206,6 +206,13 @@ def __call__(
             threads_per_warp=self.schedule.threads_per_warp,
             has_logits_transform=self.has_logits_transform,
         )
+        self.mma_role.set_dtypes(
+            self.q_dtype,
+            self.v_dtype,
+            self.q_major_mode,
+            self.k_major_mode,
+            self.v_major_mode,
+        )
         self.softmax_role.set_dtypes(self.q_dtype, self.o_dtype)

         lp = build_fmha_launch_params(
@@ -495,7 +502,6 @@ def kernel(
         if warp_idx == self.schedule.mma_warp_id:
             cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_other)
             self.mma_role.run(
-                qk_tiled_mma,
                 pv_tiled_mma,
                 tStS0,
                 tStS1,

flashinfer/cute_dsl/attention/roles/correction.py

Lines changed: 1 addition & 4 deletions
@@ -263,10 +263,7 @@ def epilog(
         cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i)

         # fence view async shared
-        cute.arch.fence_proxy(
-            cute.arch.ProxyKind.async_shared,
-            space=cute.arch.SharedSpace.shared_cta,
-        )
+        cute.arch.fence_proxy("async.shared", space="cta")

     @cute.jit
     def run(

flashinfer/cute_dsl/attention/roles/mla_mma.py

Lines changed: 81 additions & 24 deletions
@@ -19,11 +19,22 @@
 each helper as ``k_block != 0 or accumulate``, making it deterministic from
 the parameter and inner loop index. Callers compute the parameter from their
 own loop position — never from ``tiled_mma.get()`` after a sub-method return.
+
+Inner k-block loops use ``cutlass.range_constexpr`` (compile-time unrolled)
+for maximum tcgen05 MMA dispatch throughput. Each GEMM helper constructs
+its own local ``TiledMma`` via ``make_trivial_tiled_mma`` so that the
+``.set(ACCUMULATE, ...)`` mutations stay inside the helper's frame and
+never leak SSA values across the persistent tile-scheduler ``while`` loop
+in ``run()``. Same isolation pattern the compute role and the FP8 MMA
+role (mla_mma_fp8.py) already use.
 """

+from typing import Type
+
 import cutlass
 import cutlass.cute as cute
 import cutlass.cute.nvgpu.tcgen05 as tcgen05
+import cutlass.utils.blackwell_helpers as sm100_utils
 from cutlass.pipeline import PipelineProducer, PipelineConsumer
 from types import SimpleNamespace

@@ -56,6 +67,56 @@ def __init__(self, config: MLAConfig, mainloop: MLAMainloopSpec):
         self.iterations_pv_n = config.iterations_pv_n
         self.enable_pdl = config.enable_pdl
         self.is_var_split_kv = config.is_var_split_kv
+        self.use_2cta_instrs = config.use_2cta_instrs
+        self.acc_dtype = config.acc_dtype
+        # self.q_dtype and self.v_dtype are populated by set_dtypes() — the
+        # operand element types are only known at __call__ time on the kernel.
+
+    def set_dtypes(
+        self,
+        q_dtype: Type[cutlass.Numeric],
+        v_dtype: Type[cutlass.Numeric],
+    ) -> None:
+        """Set tensor element types discovered at call time.
+
+        Required so the GEMM helpers can reconstruct local TiledMma
+        instances via ``make_trivial_tiled_mma``.
+        """
+        self.q_dtype: Type[cutlass.Numeric] = q_dtype
+        self.v_dtype: Type[cutlass.Numeric] = v_dtype
+
+    @cute.jit
+    def _make_local_qk_mma(self) -> cute.TiledMma:
+        """Fresh QK TiledMma — mutations on this instance never escape the
+        helper that constructs it, so the inner k-block loop can use
+        ``range_constexpr`` without leaking SSA values into the enclosing
+        persistent ``while`` loop in ``run()``."""
+        cta_group = (
+            tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+        )
+        return sm100_utils.make_trivial_tiled_mma(
+            self.q_dtype,
+            tcgen05.OperandMajorMode.K,
+            tcgen05.OperandMajorMode.K,
+            self.acc_dtype,
+            cta_group,
+            self.mma_qk_tiler[:2],
+        )
+
+    @cute.jit
+    def _make_local_pv_mma(self) -> cute.TiledMma:
+        """Fresh PV TiledMma — same isolation rationale as ``_make_local_qk_mma``."""
+        cta_group = (
+            tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
+        )
+        return sm100_utils.make_trivial_tiled_mma(
+            self.v_dtype,
+            tcgen05.OperandMajorMode.K,
+            tcgen05.OperandMajorMode.MN,
+            self.acc_dtype,
+            cta_group,
+            self.mma_pv_tiler[:2],
+        )

     # ------------------------------------------------------------------
     # Tile count
@@ -97,30 +158,32 @@ def _get_k_tile_count(
     # state back via TiledMma mutations (they would be invisible to the
     # caller due to SSA pass-by-value at the @cute.jit boundary).
     #
-    # Inner k-block loops use ``cutlass.range()`` (dynamic scf.for),
-    # NOT ``cutlass.range_constexpr()`` (compile-time unroll).
-    # range_constexpr unrolls tiled_mma.set() calls into the enclosing
-    # scope, producing SSA values that leak across dynamic while-loop
-    # yields. range() keeps the .set() inside an scf.for scope where
-    # SSA carry-through is handled correctly.
+    # Inner k-block loops use ``cutlass.range_constexpr`` (compile-time
+    # unrolled) for maximum tcgen05 MMA dispatch throughput. To prevent
+    # the unrolled ``tiled_mma.set(ACCUMULATE, ...)`` mutations from
+    # leaking SSA values into the enclosing persistent ``while`` loop
+    # in ``run()`` (which would cause SSA-dominance failures), each
+    # helper constructs a fresh local TiledMma via
+    # ``make_trivial_tiled_mma`` and mutates that local instance only.
+    # The caller's TiledMma is never touched by the helper.
     # ------------------------------------------------------------------

     @cute.jit
     def _gemm_qk_latent_one_stage(
         self,
         qk_params: SimpleNamespace,
-        tiled_mma_qk: cute.TiledMma,
         s_stage_index: cutlass.Int32,
         kv_stage_index: cutlass.Int32,
         q_stage: int,
         accumulate: bool,
     ):
         """Compute one QK-latent stage: inner k-block GEMM loop."""
+        local_mma = self._make_local_qk_mma()
         tStS = qk_params.tStS_staged[None, None, None, s_stage_index]
-        for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])):
-            tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
+        for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])):
+            local_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
             cute.gemm(
-                tiled_mma_qk,
+                local_mma,
                 tStS,
                 qk_params.tSrQ[None, None, k_block, q_stage],
                 qk_params.tSrKC[None, None, k_block, kv_stage_index],
@@ -131,18 +194,18 @@ def _gemm_qk_latent_one_stage(
     def _gemm_qk_rope_one_stage(
         self,
         qk_params: SimpleNamespace,
-        tiled_mma_qk: cute.TiledMma,
         s_stage_index: cutlass.Int32,
         kv_stage_index: cutlass.Int32,
         q_stage: int,
         accumulate: bool,
     ):
         """Compute one QK-rope stage: inner k-block GEMM loop."""
+        local_mma = self._make_local_qk_mma()
         tStS = qk_params.tStS_staged[None, None, None, s_stage_index]
-        for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]):
-            tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
+        for k_block in cutlass.range_constexpr(self.rope_dim // local_mma.shape_mnk[2]):
+            local_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
             cute.gemm(
-                tiled_mma_qk,
+                local_mma,
                 tStS,
                 qk_params.tSrQ_rope[None, None, k_block, q_stage],
                 qk_params.tSrKC[None, None, k_block, kv_stage_index],
@@ -153,19 +216,19 @@ def _gemm_qk_rope_one_stage(
     def _gemm_pv_one_stage(
         self,
         pv_params: SimpleNamespace,
-        tiled_mma_pv: cute.TiledMma,
         p_stage_index: cutlass.Int32,
         kv_stage_index: cutlass.Int32,
         p_stage: int,
         acc_stage: int,
         accumulate: bool,
     ):
         """Compute one PV stage: inner k-block GEMM loop."""
+        local_mma = self._make_local_pv_mma()
         tOtO = pv_params.tOtO_staged[None, None, None, acc_stage]
-        for k_block in cutlass.range(pv_params.tOrP.shape[2]):
-            tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
+        for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]):
+            local_mma.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate)
             cute.gemm(
-                tiled_mma_pv,
+                local_mma,
                 tOtO,
                 pv_params.tOrP[
                     None,
@@ -284,7 +347,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_qk_latent_one_stage(
                     mma_qk_params,
-                    tiled_mma_qk,
                     s_handle.index,
                     kv_handle.index,
                     q_stage,
@@ -295,7 +357,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_qk_rope_one_stage(
                     mma_qk_params,
-                    tiled_mma_qk,
                     s_handle.index,
                     kv_handle.index,
                     q_stage,
@@ -313,7 +374,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_qk_latent_one_stage(
                     mma_qk_params,
-                    tiled_mma_qk,
                     s_handle.index,
                     kv_handle.index,
                     q_stage,
@@ -324,7 +384,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_qk_rope_one_stage(
                     mma_qk_params,
-                    tiled_mma_qk,
                     s_handle.index,
                     kv_handle.index,
                     q_stage,
@@ -343,7 +402,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_pv_one_stage(
                     mma_pv_params,
-                    tiled_mma_pv,
                     p_handle.index,
                     kv_handle.index,
                     p_stage,
@@ -368,7 +426,6 @@ def run(
                 kv_handle = load_kv_consumer.wait_and_advance()
                 self._gemm_pv_one_stage(
                     mma_pv_params,
-                    tiled_mma_pv,
                     p_handle.index,
                     kv_handle.index,
                     p_stage,