1919each helper as ``k_block != 0 or accumulate``, making it deterministic from
2020the parameter and inner loop index. Callers compute the parameter from their
2121own loop position — never from ``tiled_mma.get()`` after a sub-method return.
22+
23+ Inner k-block loops use ``cutlass.range_constexpr`` (compile-time unrolled)
24+ for maximum tcgen05 MMA dispatch throughput. Each GEMM helper constructs
25+ its own local ``TiledMma`` via ``make_trivial_tiled_mma`` so that the
26+ ``.set(ACCUMULATE, ...)`` mutations stay inside the helper's frame and
27+ never leak SSA values across the persistent tile-scheduler ``while`` loop
28+ in ``run()``. This is the same isolation pattern that the compute role
29+ and the FP8 MMA role (mla_mma_fp8.py) already use.
2230"""
2331
32+ from typing import Type
33+
2434import cutlass
2535import cutlass .cute as cute
2636import cutlass .cute .nvgpu .tcgen05 as tcgen05
37+ import cutlass .utils .blackwell_helpers as sm100_utils
2738from cutlass .pipeline import PipelineProducer , PipelineConsumer
2839from types import SimpleNamespace
2940
@@ -56,6 +67,56 @@ def __init__(self, config: MLAConfig, mainloop: MLAMainloopSpec):
5667 self .iterations_pv_n = config .iterations_pv_n
5768 self .enable_pdl = config .enable_pdl
5869 self .is_var_split_kv = config .is_var_split_kv
70+ self .use_2cta_instrs = config .use_2cta_instrs
71+ self .acc_dtype = config .acc_dtype
72+ # self.q_dtype and self.v_dtype are populated by set_dtypes() — the
73+ # operand element types are only known at __call__ time on the kernel.
74+
def set_dtypes(
    self,
    q_dtype: Type[cutlass.Numeric],
    v_dtype: Type[cutlass.Numeric],
) -> None:
    """Record the operand element types supplied at call time.

    The two values are stored unchanged as ``self.q_dtype`` and
    ``self.v_dtype``. ``_make_local_qk_mma`` reads ``self.q_dtype`` and
    ``_make_local_pv_mma`` reads ``self.v_dtype`` when they build their
    ``TiledMma`` through ``make_trivial_tiled_mma``, so this method has
    to run before either of those helpers is used.
    """
    # The two stores are independent of each other; no validation or
    # conversion is applied to the incoming types.
    self.v_dtype: Type[cutlass.Numeric] = v_dtype
    self.q_dtype: Type[cutlass.Numeric] = q_dtype
87+
@cute.jit
def _make_local_qk_mma(self) -> cute.TiledMma:
    """Return a newly constructed TiledMma for the QK GEMM.

    Every call builds a separate instance, so ``.set(...)`` mutations a
    GEMM helper applies to the returned object stay confined to that
    helper. This is what allows the inner k-block loops to be unrolled
    with ``range_constexpr`` without leaking SSA values into the
    persistent ``while`` loop in ``run()``.

    The instance is built from ``self.q_dtype`` (populated by
    ``set_dtypes``), K-major mode for both operands, ``self.acc_dtype``
    and the first two entries of ``self.mma_qk_tiler``.
    """
    # Both operand major-mode arguments of the QK MMA are K.
    k_major = tcgen05.OperandMajorMode.K
    # Only the leading two entries of the tiler are forwarded.
    tiler_head = self.mma_qk_tiler[:2]
    # Kept as a conditional expression: two-CTA group when 2-CTA
    # instructions are enabled, single-CTA group otherwise.
    group = (
        tcgen05.CtaGroup.ONE if not self.use_2cta_instrs else tcgen05.CtaGroup.TWO
    )
    return sm100_utils.make_trivial_tiled_mma(
        self.q_dtype,
        k_major,
        k_major,
        self.acc_dtype,
        group,
        tiler_head,
    )
105+
@cute.jit
def _make_local_pv_mma(self) -> cute.TiledMma:
    """Return a newly constructed TiledMma for the PV GEMM.

    Every call builds a separate instance, so ``.set(...)`` mutations a
    GEMM helper applies to the returned object stay confined to that
    helper and never reach a TiledMma owned by the caller.

    The instance is built from ``self.v_dtype`` (populated by
    ``set_dtypes``), K-major mode for the first operand and MN-major
    mode for the second, ``self.acc_dtype`` and the first two entries
    of ``self.mma_pv_tiler``.
    """
    # First and second operand major-mode arguments, in call order.
    first_major = tcgen05.OperandMajorMode.K
    second_major = tcgen05.OperandMajorMode.MN
    # Only the leading two entries of the tiler are forwarded.
    tiler_head = self.mma_pv_tiler[:2]
    # Kept as a conditional expression: two-CTA group when 2-CTA
    # instructions are enabled, single-CTA group otherwise.
    group = (
        tcgen05.CtaGroup.ONE if not self.use_2cta_instrs else tcgen05.CtaGroup.TWO
    )
    return sm100_utils.make_trivial_tiled_mma(
        self.v_dtype,
        first_major,
        second_major,
        self.acc_dtype,
        group,
        tiler_head,
    )
59120
60121 # ------------------------------------------------------------------
61122 # Tile count
@@ -97,30 +158,32 @@ def _get_k_tile_count(
97158 # state back via TiledMma mutations (they would be invisible to the
98159 # caller due to SSA pass-by-value at the @cute.jit boundary).
99160 #
100- # Inner k-block loops use ``cutlass.range()`` (dynamic scf.for),
101- # NOT ``cutlass.range_constexpr()`` (compile-time unroll).
102- # range_constexpr unrolls tiled_mma.set() calls into the enclosing
103- # scope, producing SSA values that leak across dynamic while-loop
104- # yields. range() keeps the .set() inside an scf.for scope where
105- # SSA carry-through is handled correctly.
161+ # Inner k-block loops use ``cutlass.range_constexpr`` (compile-time
162+ # unrolled) for maximum tcgen05 MMA dispatch throughput. To prevent
163+ # the unrolled ``tiled_mma.set(ACCUMULATE, ...)`` mutations from
164+ # leaking SSA values into the enclosing persistent ``while`` loop
165+ # in ``run()`` (which would cause SSA-dominance failures), each
166+ # helper constructs a fresh local TiledMma via
167+ # ``make_trivial_tiled_mma`` and mutates that local instance only.
168+ # The caller's TiledMma is never touched by the helper.
106169 # ------------------------------------------------------------------
107170
108171 @cute .jit
109172 def _gemm_qk_latent_one_stage (
110173 self ,
111174 qk_params : SimpleNamespace ,
112- tiled_mma_qk : cute .TiledMma ,
113175 s_stage_index : cutlass .Int32 ,
114176 kv_stage_index : cutlass .Int32 ,
115177 q_stage : int ,
116178 accumulate : bool ,
117179 ):
118180 """Compute one QK-latent stage: inner k-block GEMM loop."""
181+ local_mma = self ._make_local_qk_mma ()
119182 tStS = qk_params .tStS_staged [None , None , None , s_stage_index ]
120- for k_block in cutlass .range (cute .size (qk_params .tSrQ .shape [2 ])):
121- tiled_mma_qk .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
183+ for k_block in cutlass .range_constexpr (cute .size (qk_params .tSrQ .shape [2 ])):
184+ local_mma .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
122185 cute .gemm (
123- tiled_mma_qk ,
186+ local_mma ,
124187 tStS ,
125188 qk_params .tSrQ [None , None , k_block , q_stage ],
126189 qk_params .tSrKC [None , None , k_block , kv_stage_index ],
@@ -131,18 +194,18 @@ def _gemm_qk_latent_one_stage(
131194 def _gemm_qk_rope_one_stage (
132195 self ,
133196 qk_params : SimpleNamespace ,
134- tiled_mma_qk : cute .TiledMma ,
135197 s_stage_index : cutlass .Int32 ,
136198 kv_stage_index : cutlass .Int32 ,
137199 q_stage : int ,
138200 accumulate : bool ,
139201 ):
140202 """Compute one QK-rope stage: inner k-block GEMM loop."""
203+ local_mma = self ._make_local_qk_mma ()
141204 tStS = qk_params .tStS_staged [None , None , None , s_stage_index ]
142- for k_block in cutlass .range (self .rope_dim // tiled_mma_qk .shape_mnk [2 ]):
143- tiled_mma_qk .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
205+ for k_block in cutlass .range_constexpr (self .rope_dim // local_mma .shape_mnk [2 ]):
206+ local_mma .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
144207 cute .gemm (
145- tiled_mma_qk ,
208+ local_mma ,
146209 tStS ,
147210 qk_params .tSrQ_rope [None , None , k_block , q_stage ],
148211 qk_params .tSrKC [None , None , k_block , kv_stage_index ],
@@ -153,19 +216,19 @@ def _gemm_qk_rope_one_stage(
153216 def _gemm_pv_one_stage (
154217 self ,
155218 pv_params : SimpleNamespace ,
156- tiled_mma_pv : cute .TiledMma ,
157219 p_stage_index : cutlass .Int32 ,
158220 kv_stage_index : cutlass .Int32 ,
159221 p_stage : int ,
160222 acc_stage : int ,
161223 accumulate : bool ,
162224 ):
163225 """Compute one PV stage: inner k-block GEMM loop."""
226+ local_mma = self ._make_local_pv_mma ()
164227 tOtO = pv_params .tOtO_staged [None , None , None , acc_stage ]
165- for k_block in cutlass .range (pv_params .tOrP .shape [2 ]):
166- tiled_mma_pv .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
228+ for k_block in cutlass .range_constexpr (pv_params .tOrP .shape [2 ]):
229+ local_mma .set (tcgen05 .Field .ACCUMULATE , k_block != 0 or accumulate )
167230 cute .gemm (
168- tiled_mma_pv ,
231+ local_mma ,
169232 tOtO ,
170233 pv_params .tOrP [
171234 None ,
@@ -284,7 +347,6 @@ def run(
284347 kv_handle = load_kv_consumer .wait_and_advance ()
285348 self ._gemm_qk_latent_one_stage (
286349 mma_qk_params ,
287- tiled_mma_qk ,
288350 s_handle .index ,
289351 kv_handle .index ,
290352 q_stage ,
@@ -295,7 +357,6 @@ def run(
295357 kv_handle = load_kv_consumer .wait_and_advance ()
296358 self ._gemm_qk_rope_one_stage (
297359 mma_qk_params ,
298- tiled_mma_qk ,
299360 s_handle .index ,
300361 kv_handle .index ,
301362 q_stage ,
@@ -313,7 +374,6 @@ def run(
313374 kv_handle = load_kv_consumer .wait_and_advance ()
314375 self ._gemm_qk_latent_one_stage (
315376 mma_qk_params ,
316- tiled_mma_qk ,
317377 s_handle .index ,
318378 kv_handle .index ,
319379 q_stage ,
@@ -324,7 +384,6 @@ def run(
324384 kv_handle = load_kv_consumer .wait_and_advance ()
325385 self ._gemm_qk_rope_one_stage (
326386 mma_qk_params ,
327- tiled_mma_qk ,
328387 s_handle .index ,
329388 kv_handle .index ,
330389 q_stage ,
@@ -343,7 +402,6 @@ def run(
343402 kv_handle = load_kv_consumer .wait_and_advance ()
344403 self ._gemm_pv_one_stage (
345404 mma_pv_params ,
346- tiled_mma_pv ,
347405 p_handle .index ,
348406 kv_handle .index ,
349407 p_stage ,
@@ -368,7 +426,6 @@ def run(
368426 kv_handle = load_kv_consumer .wait_and_advance ()
369427 self ._gemm_pv_one_stage (
370428 mma_pv_params ,
371- tiled_mma_pv ,
372429 p_handle .index ,
373430 kv_handle .index ,
374431 p_stage ,
0 commit comments