diff --git a/benchmarks/bench_blackwell_attention_cutedsl.py b/benchmarks/bench_blackwell_attention_cutedsl.py new file mode 100644 index 0000000000..05bf0e4e05 --- /dev/null +++ b/benchmarks/bench_blackwell_attention_cutedsl.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import sys + +import numpy as np +import torch + +import flashinfer +from flashinfer.cute_dsl.utils import is_cute_dsl_available +from flashinfer.testing.utils import bench_gpu_time +from flashinfer.utils import is_sm100a_supported + +if not is_cute_dsl_available(): + print("Skipping: nvidia-cutlass-dsl package not installed") + sys.exit(0) + +from flashinfer.cute_dsl.attention import BatchPrefillCuteDSLWrapper + + +def bench_fmha_blackwell( + batch_size, + qkv_len, + num_heads, + head_dim, + causal, + dtype, +): + q = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + k = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + v = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + + qo_segment_offsets = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + ) + kv_segment_offsets = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + ) + wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( + torch.empty(128 * 1024 * 1024, dtype=dtype, device="cuda"), + kv_layout="NHD", + backend="cutlass", + ) + wrapper.plan( + qo_segment_offsets, + kv_segment_offsets, + num_heads, + num_heads, + head_dim, + head_dim_vo=head_dim, + causal=causal, + q_data_type=dtype, + kv_data_type=dtype, + ) + o = wrapper.run(q, k, v) + measurements = bench_gpu_time( + lambda: wrapper.run(q, k, v), + dry_run_time_ms=100, + repeat_time_ms=1000, + enable_cupti=True, + ) + ms = np.median(measurements) + + def flops(ms): + if causal: + return batch_size * qkv_len * 
qkv_len * num_heads * head_dim * 2 / ms / 1e9 + else: + return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 + + def io(ms): + mem_size = ( + q.numel() * q.element_size() + + k.numel() * k.element_size() + + v.numel() * v.element_size() + + o.numel() * o.element_size() + ) + return mem_size / ms / 1e6 + + print( + f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s" + ) + + +def bench_fmha_cutedsl( + batch_size, + qkv_len, + num_heads, + head_dim, + causal, + dtype, + sm_scale=None, +): + if sm_scale is None: + sm_scale = 1.0 / (head_dim**0.5) + + q = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + k = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + v = torch.randn( + batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda" + ) + + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + num_heads, + num_heads, + head_dim, + head_dim_vo=head_dim, + causal=causal, + sm_scale=sm_scale, + q_data_type=dtype, + kv_data_type=dtype, + ) + o = wrapper.run(q, k, v) + measurements = bench_gpu_time( + lambda: wrapper.run(q, k, v), + dry_run_time_ms=100, + repeat_time_ms=1000, + enable_cupti=True, + ) + ms = np.median(measurements) + + def flops(ms): + if causal: + return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9 + else: + return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9 + + def io(ms): + mem_size = ( + q.numel() * q.element_size() + + k.numel() * k.element_size() + + v.numel() * 
v.element_size() + + o.numel() * o.element_size() + ) + return mem_size / ms / 1e6 + + print( + f"bench_fmha_cutedsl (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s" + ) + + +if __name__ == "__main__": + if not is_sm100a_supported(torch.device("cuda")): + print("Skipping: requires SM100+") + sys.exit(0) + + configs = [ + (128, 512, 32, 128, True, torch.bfloat16), + (64, 1024, 32, 128, True, torch.bfloat16), + (32, 2048, 32, 128, True, torch.bfloat16), + (16, 4096, 32, 128, True, torch.bfloat16), + (8, 8192, 32, 128, True, torch.bfloat16), + (4, 16384, 32, 128, True, torch.bfloat16), + (2, 32768, 32, 128, True, torch.bfloat16), + (1, 65536, 32, 128, True, torch.bfloat16), + ] + + print("=== CUTLASS (via BatchPrefillWithRaggedKVCacheWrapper) ===") + for cfg in configs: + bench_fmha_blackwell(*cfg) + print() + print("=== CuTe DSL (via BatchPrefillCuteDSLWrapper) ===") + for cfg in configs: + bench_fmha_cutedsl(*cfg) diff --git a/flashinfer/cute_dsl/attention/__init__.py b/flashinfer/cute_dsl/attention/__init__.py new file mode 100644 index 0000000000..34b43c4bb9 --- /dev/null +++ b/flashinfer/cute_dsl/attention/__init__.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Modular attention kernels for CuTe DSL. + +Kernels live at the top level of this package. +Building blocks (config, tmem_layout, roles, fusion, scheduler, wrappers) are +one level below in subdirectories. 
+""" + +# Kernels +from .prefill import BlackwellFusedMultiHeadAttentionForward +from .mla_decode import BlackwellMultiLatentAttentionForward +from .mla_decode_fp8 import BlackwellMultiLatentAttentionForwardFP8 + +# Building blocks — FMHA prefill +from .config import AttentionConfig, AttentionFusion, HeadMapping, TileBounds +from .tmem_layout import TmemLayout +from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE +from .pipeline_topology import ( + PipelineEdge, + PipelineType, + PipelineTopology, + make_prefill_topology, + make_mla_topology, + make_mla_fp8_topology, +) +from .mainloop_spec import ( + MainloopSpec, + make_prefill_mainloop_spec, + MLAMainloopSpec, + make_mla_mainloop_spec, + make_mla_fp8_mainloop_spec, +) +from .fusion.mask import MaskType +from .fusion.variant import ( + tanh_approx, + AttentionVariant, + StandardAttention, + AttentionWithSink, + SigmoidAttention, + SigmoidTanhAttention, + ALiBiAttention, + RPEAttention, + SoftCappingAttention, +) +from .scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, + create_fmha_static_tile_scheduler_params, +) + +# Building blocks — MLA decode +from .mla_config import MLAConfig +from .mla_warp_schedule import ( + MLAWarpSchedule, + MLA_DECODE_SCHEDULE, + MLAWarpScheduleFP8, + MLA_DECODE_FP8_SCHEDULE, +) +from .scheduler.mla_persistent import ( + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler, + create_mla_static_tile_scheduler_params, + mla_get_split_kv, + mla_get_split_kv_simplified, + mla_get_workspace_size, +) + +# Wrappers +from .wrappers.batch_prefill import ( + BatchPrefillCuteDSLWrapper, +) +from .wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + cute_dsl_mla_decode, +) diff --git a/flashinfer/cute_dsl/attention/collective_builder.py b/flashinfer/cute_dsl/attention/collective_builder.py new file mode 100644 index 0000000000..ae8d4be5b2 --- /dev/null +++ 
b/flashinfer/cute_dsl/attention/collective_builder.py @@ -0,0 +1,930 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""CollectiveBuilder — factory functions for kernel launch infrastructure. + +Analogous to C++ CUTLASS's CollectiveBuilder templates, these functions +select MMA atoms, create SMEM layouts, TMA descriptors, and SharedStorage +structs based on the MainloopSpec and input tensor types. + +This separates "what to compute" (roles, config) from "how to set up +hardware" (MMA atoms, TMA, shared memory), keeping the kernel __call__ +focused on wiring and launch. +""" + +from types import SimpleNamespace + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int64 + +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.cute.nvgpu.cpasync as cpasync + +from .mainloop_spec import MainloopSpec, MLAMainloopSpec +from .mla_warp_schedule import MLAWarpSchedule, MLAWarpScheduleFP8 + + +def build_fmha_launch_params( + mainloop: MainloopSpec, + q: cute.Tensor, + k: cute.Tensor, + v: cute.Tensor, + o: cute.Tensor, + q_dtype, + k_dtype, + v_dtype, + o_dtype, + q_major_mode, + k_major_mode, + v_major_mode, + o_layout, +) -> SimpleNamespace: + """Build all MMA atoms, SMEM layouts, TMA atoms, and SharedStorage for FMHA prefill. + + :param mainloop: Resolved MainloopSpec (resolve() must have been called). + :param q: Query tensor. + :param k: Key tensor. + :param v: Value tensor. + :param o: Output tensor. + :param q_dtype: Element type of Q. + :param k_dtype: Element type of K. + :param v_dtype: Element type of V. + :param o_dtype: Element type of O. + :param q_major_mode: MMA major mode for Q operand. + :param k_major_mode: MMA major mode for K operand. + :param v_major_mode: MMA major mode for V operand. + :param o_layout: Layout enum for output. 
+ :returns: SimpleNamespace with all derived objects needed for kernel launch. + """ + config = mainloop.config + + cta_group = tcgen05.CtaGroup.ONE + p_major_mode = tcgen05.OperandMajorMode.K + p_source = tcgen05.OperandSource.TMEM + + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + q_dtype, + q_major_mode, + k_major_mode, + config.qk_acc_dtype, + cta_group, + config.qk_mma_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + v_dtype, + p_major_mode, + v_major_mode, + config.pv_acc_dtype, + cta_group, + config.pv_mma_tiler[:2], + p_source, + ) + + cluster_shape_mnk = (*config.cluster_shape_mn, 1) + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + epi_tile = config.pv_mma_tiler[:2] + + # SMEM layouts + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + config.qk_mma_tiler, + q_dtype, + mainloop.q_stages, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + config.qk_mma_tiler, + k_dtype, + mainloop.kv_stages, + ) + p_tmem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + config.pv_mma_tiler, + q_dtype, + mainloop.acc_stage, + ) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + config.pv_mma_tiler, + v_dtype, + mainloop.kv_stages, + ) + o_smem_layout_staged = sm100_utils.make_smem_layout_epi( + o_dtype, + o_layout, + epi_tile, + mainloop.epi_stage, + ) + + # TMA atoms + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + config.qk_mma_tiler, + qk_tiled_mma, + cluster_layout_vmnk.shape, + ) + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + 
k_smem_layout, + config.qk_mma_tiler, + qk_tiled_mma, + cluster_layout_vmnk.shape, + ) + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + config.pv_mma_tiler, + pv_tiled_mma, + cluster_layout_vmnk.shape, + ) + o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + o, + o_smem_layout, + epi_tile, + ) + + tma_copy_q_bytes = cute.size_in_bytes(q_dtype, q_smem_layout) + tma_copy_kv_bytes = cute.size_in_bytes(k_dtype, k_smem_layout) + + # SharedStorage struct + align = mainloop.buffer_align_bytes + sched = mainloop.warp_schedule + + # Minimize barrier storage for unused paths + s0_corr_stages = ( + mainloop.softmax_corr_stage if not mainloop.has_logits_transform else 1 + ) + mma_corr_stages = ( + mainloop.mma_corr_stage if not mainloop.has_logits_transform else 1 + ) + s0_epi_stages = mainloop.epi_stage if mainloop.has_logits_transform else 1 + + @cute.struct + class SharedStorage: + load_q_mbar_ptr: cute.struct.MemRange[Int64, mainloop.q_stages * 2] + load_kv_mbar_ptr: cute.struct.MemRange[Int64, mainloop.kv_stages * 2] + mma_s0_mbar_ptr: cute.struct.MemRange[Int64, mainloop.mma_softmax_stage * 2] + mma_s1_mbar_ptr: cute.struct.MemRange[Int64, mainloop.mma_softmax_stage * 2] + s0_corr_mbar_ptr: cute.struct.MemRange[Int64, s0_corr_stages * 2] + s1_corr_mbar_ptr: cute.struct.MemRange[Int64, s0_corr_stages * 2] + s0_s1_sequence_mbar_ptr: cute.struct.MemRange[ + Int64, sched.softmax_warpgroup_count + ] + corr_epi_mbar_ptr: cute.struct.MemRange[Int64, mainloop.epi_stage * 2] + mma_corr_mbar_ptr: cute.struct.MemRange[Int64, mma_corr_stages * 2] + s0_epi_mbar_ptr: cute.struct.MemRange[Int64, s0_epi_stages * 2] + s1_epi_mbar_ptr: cute.struct.MemRange[Int64, s0_epi_stages * 2] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[Int64, 1] + tmem_holding_buf: cutlass.Int32 + 
sO: cute.struct.Align[ + cute.struct.MemRange[o_dtype, cute.cosize(o_smem_layout_staged)], + align, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(q_smem_layout_staged)], + align, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[k_dtype, cute.cosize(k_smem_layout_staged)], + align, + ] + + @classmethod + def size_in_bytes(cls) -> int: ... # noqa: F811 + + return SimpleNamespace( + qk_tiled_mma=qk_tiled_mma, + pv_tiled_mma=pv_tiled_mma, + tma_atom_q=tma_atom_q, + tma_tensor_q=tma_tensor_q, + tma_atom_k=tma_atom_k, + tma_tensor_k=tma_tensor_k, + tma_atom_v=tma_atom_v, + tma_tensor_v=tma_tensor_v, + tma_atom_o=tma_atom_o, + tma_tensor_o=tma_tensor_o, + q_smem_layout_staged=q_smem_layout_staged, + k_smem_layout_staged=k_smem_layout_staged, + p_tmem_layout_staged=p_tmem_layout_staged, + v_smem_layout_staged=v_smem_layout_staged, + o_smem_layout_staged=o_smem_layout_staged, + SharedStorage=SharedStorage, + tma_copy_q_bytes=tma_copy_q_bytes, + tma_copy_kv_bytes=tma_copy_kv_bytes, + cluster_shape_mnk=cluster_shape_mnk, + cluster_layout_vmnk=cluster_layout_vmnk, + epi_tile=epi_tile, + o_layout=o_layout, + ) + + +def make_paged_tiled_tma_atom( + tma_load_op: cpasync.CopyBulkTensorTileG2SOp, + gmem: cute.Tensor, + smem_layout: cute.Layout, + mma_tiler, + tiled_mma: cute.TiledMma, + page_size: int, + is_k_load: bool, +): + """Create a paged TMA atom for tiled memory access with page table indirection. + + Extracted from the monolithic MLA kernel's make_paged_tiled_tma_atom method. + Builds a non-executable TMA descriptor that tiles the global memory tensor + into page-aligned chunks for paged KV cache access. + + :param tma_load_op: TMA copy operation (G2S with CTA group). + :param gmem: Global memory tensor to create the TMA descriptor for. + :param smem_layout: Shared memory layout for the TMA tile. + :param mma_tiler: MMA tile dimensions (M, K) or (M, N). + :param tiled_mma: The TiledMma atom used for CTA partitioning. 
+ :param page_size: Number of tokens per page in the KV cache. + :param is_k_load: True for K-operand loads, False for V-operand loads. + :returns: Tuple of (CopyAtom, TMA tensor descriptor). + """ + ident = cute.make_identity_layout(gmem.shape) + g_tile = cute.composition(ident, mma_tiler) + cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape + cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) + cta_v_map = cute.select(cta_v_map, mode=[0, 2]) + page_tile_size = ( + min(page_size, cta_mn) if is_k_load else min(page_size, mma_tiler[1]) + ) + cta_v_map = cute.zipped_divide( + cta_v_map, + (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), + ) + cta_v_map = cute.select(cta_v_map, mode=[0]) + from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem.value, + smem_layout.value, + cta_v_map, + tma_load_op._to_ir(), + num_multicast=1, + ) + return cute.CopyAtom( + tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) + ), res[1] + + +def build_mla_launch_params( + mainloop: MLAMainloopSpec, + schedule: MLAWarpSchedule, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + c_latent_transpose: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + acc_o: cute.Tensor, + acc_lse: cute.Tensor, + q_dtype, + k_dtype, + v_dtype, + o_dtype, +) -> SimpleNamespace: + """Build all MMA atoms, SMEM layouts, TMA atoms, and SharedStorage for MLA decode. + + :param mainloop: Resolved MLAMainloopSpec. + :param schedule: MLAWarpSchedule with warp role assignments. + :param q_latent: Query latent tensor (reinterpreted as [H, D, S_q, B]). + :param q_rope: Query RoPE tensor (reinterpreted as [H, D, S_q, B]). + :param c_latent: KV latent tensor (reinterpreted as [page_size, D, num_pages]). + :param c_rope: KV RoPE tensor (reinterpreted as [page_size, D, num_pages]). 
+ :param c_latent_transpose: Transposed KV latent (reinterpreted as [D, page_size, num_pages]). + :param page_table: Page table tensor. + :param o: Output tensor. + :param lse: LSE tensor. + :param acc_o: Accumulator output tensor (for split-KV). + :param acc_lse: Accumulator LSE tensor (for split-KV). + :param q_dtype: Element type of Q. + :param k_dtype: Element type of K. + :param v_dtype: Element type of V. + :param o_dtype: Element type of O. + :returns: SimpleNamespace with all derived objects needed for kernel launch. + """ + config = mainloop.config + + cta_group = tcgen05.CtaGroup.TWO + q_major_mode = tcgen05.OperandMajorMode.K + k_major_mode = tcgen05.OperandMajorMode.K + v_major_mode = tcgen05.OperandMajorMode.MN + p_major_mode = tcgen05.OperandMajorMode.K + + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + q_dtype, + q_major_mode, + k_major_mode, + config.acc_dtype, + cta_group, + config.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + v_dtype, + p_major_mode, + v_major_mode, + config.acc_dtype, + cta_group, + config.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(config.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + epi_tile = config.mma_pv_tiler[:2] + + # SMEM layouts + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + config.mma_qk_tiler, + q_dtype, + (config.iterations_qk_latent * config.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, + (None, None, None, config.iterations_qk_latent), + ) + q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + config.mma_qk_rope_tiler, + q_dtype, + config.load_q_stage, + ) + + kc_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + config.mma_qk_tiler, + k_dtype, + config.load_kv_stage, + ) + kc_page_tile_size = min( + config.page_size, + qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape, + ) + 
kc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (config.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, config.mma_qk_tiler[2]), + k_dtype, + config.load_kv_stage, + ) + kc_smem_layout_for_tma = cute.tiled_divide( + kc_smem_layout_for_tma, (kc_page_tile_size, config.mma_qk_tiler[2]) + ) + + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + config.mma_pv_tiler, + q_dtype, + (config.iterations_pv_k * config.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, config.iterations_pv_k) + ) + + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + config.mma_pv_tiler, + v_dtype, + config.load_kv_stage, + ) + vc_page_tile_size = min(config.page_size, config.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (config.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, config.mma_pv_tiler[2]), + v_dtype, + config.load_kv_stage, + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + + # TMA atoms + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_latent_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_latent_smem_layout, + config.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + config.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + + kc_smem_layout = cute.select(kc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = make_paged_tiled_tma_atom( + tma_load_op, + c_latent, + kc_smem_layout, + 
(config.mma_qk_tiler[1], config.mma_qk_tiler[2]), + qk_tiled_mma, + config.page_size, + is_k_load=True, + ) + tma_atom_c_rope, tma_tensor_c_rope = make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_smem_layout, + (config.mma_qk_tiler[1], config.mma_qk_tiler[2]), + qk_tiled_mma, + config.page_size, + is_k_load=True, + ) + + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (config.mma_pv_tiler[1], config.mma_pv_tiler[2]), + pv_tiled_mma, + config.page_size, + is_k_load=False, + ) + ) + + # Copy sizes + q_latent_copy_size = ( + cute.size_in_bytes(q_dtype, q_latent_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_rope + ) + tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size + tma_copy_kc_bytes = cute.size_in_bytes( + k_dtype, cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) + ) * cute.size(qk_tiled_mma.thr_id.shape) + + # SharedStorage struct + align = mainloop.buffer_align_bytes + threads_per_warp = schedule.threads_per_warp + num_compute_warps = config.num_compute_warps + + @cute.struct + class SplitKVKernelSharedStorage: + load_q_mbar_ptr: cute.struct.MemRange[Int64, config.load_q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[Int64, config.load_kv_stage * 2] + mma_s_mbar_ptr: cute.struct.MemRange[Int64, config.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[Int64, config.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[Int64, config.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[Int64, config.mma_o_stage * 2] + load_pt_mbar_ptr: cute.struct.MemRange[Int64, config.load_pt_stage * 2] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: cutlass.Int32 + softmax_smem_exchange: 
cute.struct.MemRange[ + config.acc_dtype, num_compute_warps * threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + config.acc_dtype, num_compute_warps * threads_per_warp + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(q_latent_smem_layout_staged)], + align, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(q_rope_smem_layout_staged)], + align, + ] + smem_kc: cute.struct.Align[ + cute.struct.MemRange[k_dtype, cute.cosize(kc_smem_layout_staged)], + align, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(p_smem_layout_staged)], + align, + ] + smem_page_table: cute.struct.MemRange[ + cutlass.Int32, config.load_pt_stage * config.mma_qk_tiler[1] // 2 + ] + + @classmethod + def size_in_bytes(cls) -> int: ... # noqa: F811 + + return SimpleNamespace( + qk_tiled_mma=qk_tiled_mma, + pv_tiled_mma=pv_tiled_mma, + q_latent_smem_layout_staged=q_latent_smem_layout_staged, + q_rope_smem_layout_staged=q_rope_smem_layout_staged, + kc_smem_layout_staged=kc_smem_layout_staged, + p_smem_layout_staged=p_smem_layout_staged, + vc_smem_layout_staged=vc_smem_layout_staged, + kc_smem_layout_for_tma=kc_smem_layout_for_tma, + vc_smem_layout_for_tma=vc_smem_layout_for_tma, + tma_atom_q_latent=tma_atom_q_latent, + tma_tensor_q_latent=tma_tensor_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_tensor_q_rope=tma_tensor_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_tensor_c_latent=tma_tensor_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + tma_tensor_c_rope=tma_tensor_c_rope, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose=tma_tensor_c_latent_transpose, + kc_page_tile_size=kc_page_tile_size, + vc_page_tile_size=vc_page_tile_size, + tma_copy_q_bytes=tma_copy_q_bytes, + tma_copy_kc_bytes=tma_copy_kc_bytes, + SharedStorage=SplitKVKernelSharedStorage, + cta_layout_vmnk=cta_layout_vmnk, + epi_tile=epi_tile, + 
cluster_shape_mnk=config.cluster_shape_mnk, + ) + + +def build_mla_fp8_launch_params( + mainloop: MLAMainloopSpec, + schedule: MLAWarpScheduleFP8, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + c_latent_transpose: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + acc_o: cute.Tensor, + acc_lse: cute.Tensor, + q_dtype, + k_dtype, + v_dtype, + o_dtype, +) -> SimpleNamespace: + """Build MMA atoms, SMEM layouts, TMA atoms, and SharedStorage for FP8 MLA decode. + + FP8 differs from FP16 in: + - Separate KC-latent, KC-rope, and VC SMEM buffers (no aliasing) + - KC-latent stages use logical_divide for iterations_qk_latent + - VC stages use nested logical_divide for iterations_pv_k * iterations_pv_n + - No page-table SMEM buffer or load_pt barriers + - Separate tma_copy_kc_bytes and tma_copy_vc_bytes + """ + config = mainloop.config + + cta_group = tcgen05.CtaGroup.TWO + q_major_mode = OperandMajorMode.K + k_major_mode = OperandMajorMode.K + v_major_mode = OperandMajorMode.MN + p_major_mode = OperandMajorMode.K + + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + q_dtype, + q_major_mode, + k_major_mode, + config.acc_dtype, + cta_group, + config.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + v_dtype, + p_major_mode, + v_major_mode, + config.acc_dtype, + cta_group, + config.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(config.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + epi_tile = config.mma_pv_tiler[:2] + + # --- Q SMEM layouts (same structure as FP16) --- + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + config.mma_qk_tiler, + q_dtype, + (config.iterations_qk_latent * config.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, + (None, None, None, config.iterations_qk_latent), + ) + q_rope_smem_layout_staged = 
sm100_utils.make_smem_layout_a( + qk_tiled_mma, + config.mma_qk_rope_tiler, + q_dtype, + config.load_q_stage, + ) + + # --- KC-latent SMEM: separate buffer with logical_divide for latent iterations --- + kc_latent_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + config.mma_qk_tiler, + k_dtype, + (config.iterations_qk_latent * config.load_k_stage), + ) + kc_page_tile_size = min( + config.page_size, + qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape, + ) + kc_latent_smem_layout_staged = cute.logical_divide( + kc_latent_smem_layout_staged, + (None, None, None, config.iterations_qk_latent), + ) + + kc_latent_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (config.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, config.mma_qk_tiler[2]), + k_dtype, + (config.iterations_qk_latent * config.load_k_stage), + ) + kc_latent_smem_layout_for_tma = cute.tiled_divide( + kc_latent_smem_layout_for_tma, + (kc_page_tile_size, config.mma_qk_tiler[2]), + ) + kc_latent_smem_layout_for_tma = cute.logical_divide( + kc_latent_smem_layout_for_tma, + (None, None, None, config.iterations_qk_latent), + ) + + # --- KC-rope SMEM: separate buffer --- + kc_rope_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + config.mma_qk_rope_tiler, + k_dtype, + config.load_k_stage, + ) + kc_rope_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + ( + config.mma_qk_rope_tiler[0] // qk_tiled_mma.thr_id.shape, + config.mma_qk_rope_tiler[2], + ), + k_dtype, + (config.iterations_qk_rope * config.load_k_stage), + ) + kc_rope_smem_layout_for_tma = cute.tiled_divide( + kc_rope_smem_layout_for_tma, + (kc_page_tile_size, config.mma_qk_rope_tiler[2]), + ) + + # --- P SMEM layout --- + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + config.mma_pv_tiler, + q_dtype, + (config.iterations_pv_k * config.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, + (None, None, 
None, config.iterations_pv_k), + ) + + # --- VC SMEM: separate buffer with nested logical_divide --- + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + config.mma_pv_tiler, + v_dtype, + (config.iterations_pv_k * config.iterations_pv_n * config.load_v_stage), + ) + vc_smem_layout_staged = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_staged, + (None, None, None, config.iterations_pv_k * config.iterations_pv_n), + ), + (None, None, None, (config.iterations_pv_n, None)), + ) + vc_page_tile_size = min(config.page_size, config.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (config.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, config.mma_pv_tiler[2]), + v_dtype, + (config.iterations_pv_k * config.iterations_pv_n * config.load_v_stage), + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + vc_smem_layout_for_tma = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_for_tma, + (None, None, None, config.iterations_pv_k * config.iterations_pv_n), + ), + (None, None, None, (config.iterations_pv_n, None)), + ) + + # --- TMA atoms --- + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_smem_layout, + config.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + config.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + + kc_smem_layout = cute.select(kc_latent_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = make_paged_tiled_tma_atom( + 
tma_load_op, + c_latent, + kc_smem_layout, + (config.mma_qk_tiler[1], config.mma_qk_tiler[2]), + qk_tiled_mma, + config.page_size, + is_k_load=True, + ) + kc_rope_smem_layout = cute.select(kc_rope_smem_layout_for_tma, mode=[0]) + tma_atom_c_rope, tma_tensor_c_rope = make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_rope_smem_layout, + (config.mma_qk_rope_tiler[1], config.mma_qk_rope_tiler[2]), + qk_tiled_mma, + config.page_size, + is_k_load=True, + ) + + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (config.mma_pv_tiler[1], config.mma_pv_tiler[2]), + pv_tiled_mma, + config.page_size, + is_k_load=False, + ) + ) + + # --- Copy sizes --- + q_latent_copy_size = ( + cute.size_in_bytes(q_dtype, q_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_rope + ) + tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size + + kc_latent_copy_size = ( + cute.size_in_bytes( + k_dtype, + cute.select(kc_latent_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_latent + ) + kc_rope_copy_size = ( + cute.size_in_bytes( + k_dtype, + cute.select(kc_rope_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * config.iterations_qk_rope + ) + tma_copy_kc_bytes = kc_latent_copy_size + kc_rope_copy_size + + tma_copy_vc_bytes = ( + cute.size_in_bytes( + v_dtype, + cute.select(vc_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(pv_tiled_mma.thr_id.shape) + * config.iterations_pv_n + * config.iterations_pv_k + ) + + # --- SharedStorage struct (no page-table buffer) --- + align = mainloop.buffer_align_bytes + threads_per_warp = schedule.threads_per_warp + 
num_compute_warps = config.num_compute_warps + + @cute.struct + class FP8SplitKVKernelSharedStorage: + load_q_mbar_ptr: cute.struct.MemRange[Int64, config.load_q_stage * 2] + load_k_mbar_ptr: cute.struct.MemRange[Int64, config.load_k_stage * 2] + load_v_mbar_ptr: cute.struct.MemRange[Int64, config.load_v_stage * 2] + mma_s_mbar_ptr: cute.struct.MemRange[Int64, config.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[Int64, config.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[Int64, config.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[Int64, config.mma_o_stage * 2] + + smem_p: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(p_smem_layout_staged)], + align, + ] + smem_kc_latent: cute.struct.Align[ + cute.struct.MemRange[k_dtype, cute.cosize(kc_latent_smem_layout_staged)], + align, + ] + smem_kc_rope: cute.struct.Align[ + cute.struct.MemRange[k_dtype, cute.cosize(kc_rope_smem_layout_staged)], + align, + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(q_latent_smem_layout_staged)], + align, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[q_dtype, cute.cosize(q_rope_smem_layout_staged)], + align, + ] + smem_vc: cute.struct.Align[ + cute.struct.MemRange[v_dtype, cute.cosize(vc_smem_layout_staged)], + align, + ] + softmax_smem_exchange: cute.struct.MemRange[ + config.acc_dtype, num_compute_warps * threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + config.acc_dtype, num_compute_warps * threads_per_warp + ] + tmem_dealloc_mbar_ptr: Int64 + tmem_holding_buf: cutlass.Int32 + + @classmethod + def size_in_bytes(cls) -> int: ... 
# noqa: F811
+
+    return SimpleNamespace(
+        qk_tiled_mma=qk_tiled_mma,
+        pv_tiled_mma=pv_tiled_mma,
+        q_latent_smem_layout_staged=q_latent_smem_layout_staged,
+        q_rope_smem_layout_staged=q_rope_smem_layout_staged,
+        kc_latent_smem_layout_staged=kc_latent_smem_layout_staged,
+        kc_rope_smem_layout_staged=kc_rope_smem_layout_staged,
+        p_smem_layout_staged=p_smem_layout_staged,
+        vc_smem_layout_staged=vc_smem_layout_staged,
+        kc_latent_smem_layout_for_tma=kc_latent_smem_layout_for_tma,
+        kc_rope_smem_layout_for_tma=kc_rope_smem_layout_for_tma,
+        vc_smem_layout_for_tma=vc_smem_layout_for_tma,
+        tma_atom_q_latent=tma_atom_q_latent,
+        tma_tensor_q_latent=tma_tensor_q_latent,
+        tma_atom_q_rope=tma_atom_q_rope,
+        tma_tensor_q_rope=tma_tensor_q_rope,
+        tma_atom_c_latent=tma_atom_c_latent,
+        tma_tensor_c_latent=tma_tensor_c_latent,
+        tma_atom_c_rope=tma_atom_c_rope,
+        tma_tensor_c_rope=tma_tensor_c_rope,
+        tma_atom_c_latent_transpose=tma_atom_c_latent_transpose,
+        tma_tensor_c_latent_transpose=tma_tensor_c_latent_transpose,
+        kc_page_tile_size=kc_page_tile_size,
+        vc_page_tile_size=vc_page_tile_size,
+        tma_copy_q_bytes=tma_copy_q_bytes,
+        tma_copy_kc_bytes=tma_copy_kc_bytes,
+        tma_copy_vc_bytes=tma_copy_vc_bytes,
+        SharedStorage=FP8SplitKVKernelSharedStorage,
+        cta_layout_vmnk=cta_layout_vmnk,
+        epi_tile=epi_tile,
+        cluster_shape_mnk=config.cluster_shape_mnk,
+    )
diff --git a/flashinfer/cute_dsl/attention/compat.py b/flashinfer/cute_dsl/attention/compat.py
new file mode 100644
index 0000000000..7358f1ead4
--- /dev/null
+++ b/flashinfer/cute_dsl/attention/compat.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Compatibility shims for cutlass-dsl version differences.
+
+Centralizes version-dependent API lookups so kernel and role files don't
+each carry their own copies.
+"""
+
+import cutlass.cute as cute
+
+
+# setmaxregister_{decrease,increase} added in cutlass-dsl 4.4;
+# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}.
+# NOTE(review): both getattr fallbacks bottom out at None when neither API
+# exists; a later call would then fail with "'NoneType' object is not
+# callable" — confirm the supported cutlass-dsl floor provides one of the two.
+setmaxregister_decrease = getattr(
+    cute.arch,
+    "setmaxregister_decrease",
+    getattr(cute.arch, "warpgroup_reg_dealloc", None),
+)
+
+setmaxregister_increase = getattr(
+    cute.arch,
+    "setmaxregister_increase",
+    getattr(cute.arch, "warpgroup_reg_alloc", None),
+)
+
+# get_max_tmem_alloc_cols added in cutlass-dsl 4.4;
+# older versions don't have it.
+_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512}
+
+
+def get_max_tmem_alloc_cols(compute_capability: str) -> int:
+    """Return the maximum tensor-memory (TMEM) columns allocatable on *compute_capability*.
+
+    Delegates to ``cute.arch.get_max_tmem_alloc_cols`` when the installed
+    cutlass-dsl provides it (4.4+); otherwise falls back to the static table.
+
+    :param compute_capability: Architecture string such as ``"sm_100"``.
+    :raises ValueError: If the capability is absent from the fallback table.
+    """
+    if hasattr(cute.arch, "get_max_tmem_alloc_cols"):
+        return cute.arch.get_max_tmem_alloc_cols(compute_capability)
+    if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP:
+        raise ValueError(f"Unsupported compute capability: {compute_capability}")
+    return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability]
diff --git a/flashinfer/cute_dsl/attention/config.py b/flashinfer/cute_dsl/attention/config.py
new file mode 100644
index 0000000000..003a90ea2a
--- /dev/null
+++ b/flashinfer/cute_dsl/attention/config.py
@@ -0,0 +1,189 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""AttentionConfig and AttentionFusion — single source of truth for attention kernel parameters.
+
+AttentionConfig holds all the configuration needed by the kernel: dtypes, tile shapes,
+execution mode, and feature flags. Derived properties (cta_tiler, pv_mma_tiler) are
+computed from the base parameters.
+
+AttentionFusion bundles an AttentionVariant (the customization point for logits
+transform, softmax statistics, and output normalization) into a single object
+that the kernel consumes.
+"""
+
+from __future__ import annotations
+
+import enum
+from dataclasses import dataclass
+from typing import Tuple, Any
+
+from .fusion.mask import MaskType
+from .fusion.variant import AttentionVariant, StandardAttention
+
+
+class HeadMapping(enum.Enum):
+    """How attention heads map to MMA tile dimensions.
+
+    GRID: Heads in grid/loop dimension (prefill, training). Any head count works.
+    MMA_M: All heads packed into MMA M-dimension (decode).
+    MMA_N: GQA group packed into MMA N-dimension (standard GQA decode).
+    """
+
+    GRID = "grid"
+    MMA_M = "mma_m"
+    MMA_N = "mma_n"
+
+
+@dataclass
+class TileBounds:
+    """Handles partial MMA tile filling when logical data < physical tile size.
+
+    Reserved for decode — unused by prefill. Most critical for decode where
+    num_heads < 128 (the MMA M-tile width). Compiles away via
+    cutlass.const_expr when no masking is needed.
+    """
+
+    m_bound: int | None = None
+    n_bound: int | None = None
+
+    def needs_m_masking(self, tile_m: int) -> bool:
+        return self.m_bound is not None and self.m_bound < tile_m
+
+    def needs_n_masking(self, tile_n: int) -> bool:
+        return self.n_bound is not None and self.n_bound < tile_n
+
+
+@dataclass
+class AttentionConfig:
+    """Single source of truth for attention kernel configuration.
+
+    Replaces the scattered self.xxx attributes in the kernel's __init__.
+    Derived properties (cta_tiler, pv_mma_tiler, etc.) are computed from
+    the base parameters.
+    """
+
+    # Core parameters
+    qk_acc_dtype: Any  # Type[cutlass.Numeric] — using Any to avoid import dependency
+    pv_acc_dtype: Any
+    mma_tiler: Tuple[int, int, int]
+    is_persistent: bool
+    mask_type: MaskType
+    num_repeat_kv_heads: int = 1
+    window_left: int = -1
+
+    # Reserved for decode — unused by prefill. Decode kernels pack heads into
+    # MMA M/N dimensions; prefill maps heads via the grid (HeadMapping.GRID).
+    head_mapping: HeadMapping = HeadMapping.GRID
+    num_heads: int = 0
+    num_kv_heads: int = 0
+
+    SUPPORTED_MMA_TILE_MN = (128, 128)
+    MMA_K_GRANULARITY = {
+        16: 16,
+        8: 32,
+    }  # {dtype_width_bits: K-tile element granularity}
+
+    def can_implement(self, dtype_width: int = 16) -> None:
+        """Validate that this config is implementable on Blackwell SM100.
+
+        Checks hardware-level constraints that are independent of the target GPU's
+        SMEM capacity. SMEM overruns are caught at kernel launch time by CUDA.
+
+        :param dtype_width: Bit width of the input element type (16 for fp16/bf16, 8 for fp8).
+        :raises ValueError: If validation fails, with a descriptive message.
+        """
+        mma_mn = self.mma_tiler[:2]
+        if mma_mn != self.SUPPORTED_MMA_TILE_MN:
+            raise ValueError(
+                f"mma_tiler_mn={mma_mn} is not supported. "
+                f"Must be {self.SUPPORTED_MMA_TILE_MN} for Blackwell SM100 tcgen05"
+            )
+        head_dim = self.mma_tiler[2]
+        k_gran = self.MMA_K_GRANULARITY.get(dtype_width, 16)
+        if head_dim == 0 or head_dim % k_gran != 0:
+            raise ValueError(
+                f"head_dim={head_dim} must be a positive multiple of {k_gran} "
+                f"(MMA K-dimension granularity for {dtype_width}-bit dtype)"
+            )
+        if self.num_repeat_kv_heads < 1:
+            raise ValueError(
+                f"num_repeat_kv_heads={self.num_repeat_kv_heads} must be >= 1"
+            )
+
+    @property
+    def cta_tiler(self) -> Tuple[int, int, int]:
+        """CTA tile: 2 Q tiles per CTA in M-dimension."""
+        return (
+            2 * self.mma_tiler[0],
+            self.mma_tiler[1],
+            self.mma_tiler[2],
+        )
+
+    @property
+    def qk_mma_tiler(self) -> Tuple[int, int, int]:
+        """MMA tile for Q*K^T computation."""
+        return self.mma_tiler
+
+    @property
+    def pv_mma_tiler(self) -> Tuple[int, int, int]:
+        """MMA tile for P*V computation (transposed)."""
+        return (
+            self.mma_tiler[0],
+            self.mma_tiler[2],
+            self.mma_tiler[1],
+        )
+
+    @property
+    def cluster_shape_mn(self) -> Tuple[int, int]:
+        """Cluster shape (always (1,1) for prefill)."""
+        return (1, 1)
+
+    @property
+    def tile_bounds(self) -> TileBounds:
+        """Derive tile bounds from head mapping."""
+        if self.head_mapping == HeadMapping.MMA_M and self.num_heads > 0:
+            return TileBounds(m_bound=self.num_heads)
+        return TileBounds()
+
+
+@dataclass
+class AttentionFusion:
+    """Bundles an AttentionVariant with the kernel.
+
+    The variant object defines all customization hooks (logits transform,
+    statistics update, output transform) as co-defined methods. Compile-time
+    flags on the variant drive dead-code elimination via ``cutlass.const_expr``.
+
+    See :class:`~flashinfer.cute_dsl.attention.fusion.variant.AttentionVariant`
+    for the full API and execution-order documentation.
+    """
+
+    variant: AttentionVariant = None  # type: ignore[assignment]
+
+    def __post_init__(self):
+        if self.variant is None:
+            self.variant = StandardAttention()
+
+    @property
+    def has_params(self) -> bool:
+        """Whether the variant needs runtime tensor data."""
+        return self.variant.extra_params is not None
+
+    @property
+    def params_shape(self) -> tuple | None:
+        """Shape of the variant's runtime tensor, or None."""
+        ep = self.variant.extra_params
+        return tuple(ep.shape) if ep is not None else None
+
+    @property
+    def params_strides(self) -> tuple | None:
+        """Element strides of the variant's runtime tensor, or None.
+
+        Derived from the PyTorch tensor's actual strides so the CuTe layout
+        in the kernel matches the source memory layout. CuTe defaults to
+        column-major; PyTorch is row-major — using explicit strides avoids
+        a silent layout mismatch.
+        """
+        ep = self.variant.extra_params
+        return tuple(ep.stride()) if ep is not None else None
diff --git a/flashinfer/cute_dsl/attention/fusion/__init__.py b/flashinfer/cute_dsl/attention/fusion/__init__.py
new file mode 100644
index 0000000000..4997a0e8c4
--- /dev/null
+++ b/flashinfer/cute_dsl/attention/fusion/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from .mask import (
+    MaskType,
+    apply_mask,
+    get_trip_count,
+    get_masked_trip_count,
+    get_unmasked_trip_count,
+    get_kv_start_block_idx,
+)
+from .variant import (
+    tanh_approx,
+    AttentionVariant,
+    StandardAttention,
+    AttentionWithSink,
+    SigmoidAttention,
+    SigmoidTanhAttention,
+    ALiBiAttention,
+    RPEAttention,
+    SoftCappingAttention,
+)
diff --git a/flashinfer/cute_dsl/attention/fusion/mask.py b/flashinfer/cute_dsl/attention/fusion/mask.py
new file mode 100644
index 0000000000..b16b57827e
--- /dev/null
+++ b/flashinfer/cute_dsl/attention/fusion/mask.py
@@ -0,0 +1,197 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Mask types and masking helper functions for attention kernels.
+
+All helpers are standalone @cute.jit functions that take mask_type and
+window_left as compile-time parameters, so they can be reused across
+different kernel variants (prefill, decode).
+"""
+
+import enum
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.typing import Int32, Float32
+
+
+class MaskType(enum.Enum):
+    NO_MASK = enum.auto()
+    RESIDUAL_MASK = enum.auto()
+    CAUSAL_MASK = enum.auto()
+    SLIDING_WINDOW_MASK = enum.auto()
+
+
+@cute.jit
+def get_trip_count(
+    mask_type: MaskType,
+    window_left: int,
+    blk_coord: cute.Coord,
+    tile_shape: cute.Shape,
+    seqlen_k: Int32,
+    seqlen_q: Int32 = 0,
+) -> Int32:
+    """Number of KV tile blocks to process for this Q tile.
+
+    :param mask_type: Compile-time mask kind selecting the formula below.
+    :param window_left: Sliding-window radius; only read for SLIDING_WINDOW_MASK.
+    :param blk_coord: Q-tile block coordinate; ``blk_coord[0]`` indexes the M tile.
+    :param tile_shape: ``tile_shape[0]`` is the Q tile height, ``tile_shape[1]``
+        the KV tile width.
+    :param seqlen_k: KV sequence length in tokens.
+    :param seqlen_q: Q sequence length; Q is right-aligned against KV via
+        ``causal_offset = seqlen_k - seqlen_q``.
+    """
+    result = 0
+    if mask_type == MaskType.NO_MASK or mask_type == MaskType.RESIDUAL_MASK:
+        result = cute.ceil_div(seqlen_k, tile_shape[1])
+    elif mask_type == MaskType.CAUSAL_MASK:
+        max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1])
+        causal_offset = seqlen_k - seqlen_q
+        max_blocks_q = cute.ceil_div(
+            (blk_coord[0] + 1) * tile_shape[0] + causal_offset, tile_shape[1]
+        )
+        result = cutlass.min(max_blocks_k, max_blocks_q)
+    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
+        qk_offset = seqlen_k - seqlen_q
+        first_q = blk_coord[0] * tile_shape[0] + qk_offset
+        last_q = (blk_coord[0] + 1) * tile_shape[0] - 1 + qk_offset
+        min_kv = cutlass.max(0, first_q - window_left)
+        max_kv = cutlass.min(seqlen_k - 1, last_q + window_left)
+        start_block = min_kv // tile_shape[1]
+        end_block = cute.ceil_div(max_kv + 1, tile_shape[1])
+        result = end_block - start_block
+    return result
+
+
+@cute.jit
+def get_masked_trip_count(
+    mask_type: MaskType,
+    window_left: int,
+    blk_coord: cute.Coord,
+    tile_shape: cute.Shape,
+    seqlen_k: Int32,
+    seqlen_q: Int32 = 0,
+) -> Int32:
+    """Number of masked (boundary) KV tile blocks.
+
+    For sliding window this is conservatively the full trip count — every
+    tile in range is treated as needing per-element masking.
+    """
+    result = 0
+    if mask_type == MaskType.NO_MASK:
+        result = 0
+    elif mask_type == MaskType.RESIDUAL_MASK:
+        if seqlen_k % tile_shape[1] != 0:
+            result = 1
+        else:
+            result = 0
+    elif mask_type == MaskType.CAUSAL_MASK:
+        trip_count = get_trip_count(
+            mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
+        )
+        causal_offset = seqlen_k - seqlen_q
+        first_boundary = (blk_coord[0] * tile_shape[0] + causal_offset) // tile_shape[1]
+        last_boundary = (
+            (blk_coord[0] + 1) * tile_shape[0] - 1 + causal_offset
+        ) // tile_shape[1]
+        result = cutlass.min(
+            trip_count,
+            last_boundary - first_boundary + 1,
+        )
+    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
+        trip_count = get_trip_count(
+            mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
+        )
+        result = trip_count
+    return result
+
+
+@cute.jit
+def get_unmasked_trip_count(
+    mask_type: MaskType,
+    window_left: int,
+    blk_coord: cute.Coord,
+    tile_shape: cute.Shape,
+    seqlen_k: Int32,
+    seqlen_q: Int32 = 0,
+) -> Int32:
+    """Number of fully unmasked KV tile blocks (no per-element mask needed)."""
+    result = 0
+    if mask_type == MaskType.NO_MASK:
+        result = get_trip_count(mask_type, window_left, blk_coord, tile_shape, seqlen_k)
+    elif mask_type == MaskType.RESIDUAL_MASK:
+        if seqlen_k % tile_shape[1] != 0:
+            result = (
+                get_trip_count(mask_type, window_left, blk_coord, tile_shape, seqlen_k)
+                - 1
+            )
+        else:
+            result = get_trip_count(
+                mask_type, window_left, blk_coord, tile_shape, seqlen_k
+            )
+    elif mask_type == MaskType.CAUSAL_MASK:
+        result = get_trip_count(
+            mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
+        ) - get_masked_trip_count(
+            mask_type, window_left, blk_coord, tile_shape, seqlen_k, seqlen_q
+        )
+    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
+        result = 0
+    return result
+
+
+@cute.jit
+def get_kv_start_block_idx(
+    mask_type: MaskType,
+    window_left: int,
+    blk_coord: cute.Coord,
+    tile_shape: cute.Shape,
+    seqlen_k: Int32,
+    seqlen_q: Int32 = 0,
+) -> Int32:
+    """Starting KV block index (nonzero only for sliding window)."""
+    if cutlass.const_expr(mask_type == MaskType.SLIDING_WINDOW_MASK):
+        qk_offset = seqlen_k - seqlen_q
+        first_q = blk_coord[0] * tile_shape[0] + qk_offset
+        min_kv = cutlass.max(0, first_q - window_left)
+        return min_kv // tile_shape[1]
+    else:
+        return 0
+
+
+@cute.jit
+def apply_mask(
+    mask_type: MaskType,
+    window_left: int,
+    acc_qk: cute.Tensor,
+    index_qk: cute.Tensor,
+    seqlen_k: Int32,
+    causal_offset: Int32 = 0,
+):
+    """Set masked-out entries of ``acc_qk`` to ``-inf`` in place.
+
+    :param acc_qk: Register fragment of QK scores, modified in place.
+    :param index_qk: Fragment of (q_pos, kv_pos) coordinates parallel to ``acc_qk``.
+    :param seqlen_k: KV sequence length; positions ``>= seqlen_k`` are padding.
+    :param causal_offset: Right-alignment shift (``seqlen_k - seqlen_q`` in the
+        trip-count helpers).
+    """
+    if mask_type == MaskType.RESIDUAL_MASK:
+        for i in range(cute.size(acc_qk)):
+            pos = index_qk[i]
+            if pos[1] >= seqlen_k:
+                acc_qk[i] = -Float32.inf
+    elif mask_type == MaskType.CAUSAL_MASK:
+        for i in range(cute.size(acc_qk)):
+            pos = index_qk[i]
+            if pos[0] + causal_offset < pos[1] or pos[1] >= seqlen_k:
+                acc_qk[i] = -Float32.inf
+    elif mask_type == MaskType.SLIDING_WINDOW_MASK:
+        for i in range(cute.size(acc_qk)):
+            pos = index_qk[i]
+            if (
+                pos[1] - pos[0] - causal_offset > window_left
+                or pos[0] + causal_offset - pos[1] > window_left
+                or pos[1] >= seqlen_k
+            ):
+                acc_qk[i] = -Float32.inf
diff --git a/flashinfer/cute_dsl/attention/fusion/variant.py b/flashinfer/cute_dsl/attention/fusion/variant.py
new file mode 100644
index 0000000000..47873edd84
--- /dev/null
+++ b/flashinfer/cute_dsl/attention/fusion/variant.py
@@ -0,0 +1,667 @@
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""AttentionVariant — customization point for attention behavior.
+
+Subclass AttentionVariant to create custom attention behaviors. The hooks
+are co-defined on a single object so that coupled invariants are naturally
+enforced.
+
+Execution Order
+===============
+
+For each query row, the kernel iterates over KV tiles. Within each tile:
+
+ 1. ``update_statistics(kv_tile_idx, qo_head_idx, m, d, scale)``
+    Modify the online softmax running statistics *before* the tile's QK
+    scores are processed. Use this to inject virtual tokens (e.g. attention
+    sink) into the softmax denominator.
+
+ 2. Masking (causal / sliding window / residual) is applied to QK scores.
+
+ 3. ``score_mod(score, batch, qo, kv, qo_head, kv_head)`` — optional
+    element-wise modification of QK scores (e.g.
ALiBi bias, soft-capping, + relative positional encoding). Runs *before* the score-to-weight + conversion. **Composes with both** standard softmax and custom + ``transform_logits``. + + 4. Score-to-weight conversion: + - Default (``has_logits_transform = False``): + Row-max reduction, then ``weights = exp2(score * scale - m * scale)``, + row-sum accumulation, correction warp rescaling. + - Custom (``has_logits_transform = True``): + ``weights = transform_logits(score)`` + Your transform **replaces** the entire softmax machinery (row-max, + exp2, row-sum, correction). Must produce non-negative values. + +After all KV tiles (final epilogue): + + 5. Output normalization: + - Default (``has_output_transform = False``): + ``output *= scale_output / d`` + - Custom (``has_output_transform = True``): + ``output = transform_output(output, batch, qo, qo_head, m, rcp_d, scale)`` + +Composability +============= + +``score_mod`` and ``transform_logits`` are **composable**: + +- ``score_mod`` modifies scores (position-dependent bias, capping, etc.) +- ``transform_logits`` replaces the activation function (sigmoid, relu, etc.) + +When both are set, scores flow through ``score_mod`` first, then +``transform_logits``. When only ``score_mod`` is set, scores flow into +the standard softmax (exp2) path. + +Variable Domains +================ + +``m`` (row_max) + Raw logit domain — the actual maximum QK dot-product value, + **not** multiplied by ``scale``. + +``d`` (row_sum) + Accumulated sum of ``exp2((score - m) * scale)`` across all tiles. + When ``update_statistics`` injects virtual tokens, ``d`` includes their + contributions. + +``scale`` + ``log2(e) * sm_scale`` where ``sm_scale = 1 / sqrt(head_dim)``. + The kernel uses base-2 exponentials for hardware efficiency: + ``exp(x * sm_scale) == exp2(x * sm_scale * log2(e)) == exp2(x * scale)``. + +``rcp_d`` + ``1.0 / d`` — reciprocal of the softmax denominator. + +``scale_output`` + Output scaling factor, typically ``1.0``. 
+ +Coupling Rules +============== + +- If ``update_statistics`` modifies ``d`` (adds virtual tokens), then + ``transform_output`` **must** account for the modified denominator. + The default ``output * scale_output / d`` works when ``d`` is unmodified. + With sink tokens, override ``transform_output`` to use + ``output * scale * rcp_d`` (where ``rcp_d = 1/d`` already reflects the + sink contribution). + +- If ``transform_logits`` is provided, it replaces the **entire** + ``exp2(score * scale - m * scale)`` conversion. The correction warp will + **not** rescale intermediate outputs. Your transform must produce + non-negative values for correct accumulation. + +Runtime Parameters +================== + +Variants that need runtime tensor data (e.g. ALiBi slopes, sink values, +RPE tables) expose it via the ``extra_params`` property. The wrapper +converts the tensor to CuTe format and passes it through the kernel; the +variant accesses it as ``self.params`` inside ``@cute.jit`` methods. + +``extra_params`` + Python-side property: returns a ``torch.Tensor`` of any shape, or + ``None``. Read by the wrapper at ``plan()`` time. + +``self.params`` + JIT-side attribute: the kernel binds the CuTe tensor to this name + before invoking any variant hook. Access it with natural indexing + (e.g. ``self.params[head_idx]`` for 1-D, ``self.params[head, offset]`` + for 2-D). + +Compile-time scalars (set in ``__init__``, e.g. ``self.cap = 50.0``) +are traced directly by the JIT compiler — no ``extra_params`` needed. + +Hardware Primitives +=================== + +The following hardware-mapped primitives are available for use in +``transform_logits`` and ``score_mod``:: + + cute.arch.exp2(x) # MUFU.EX2 — base-2 exponential (approx) + cute.arch.rcp_approx(x) # MUFU.RCP — reciprocal (approx) + tanh_approx(x) # MUFU.TANH — hyperbolic tangent (approx) + +Each maps to a single-cycle MUFU instruction. Import ``tanh_approx`` +from this module. 
+ +Examples +======== + +Sigmoid attention (exp2 + rcp, 2 MUFU ops/element):: + + class SigmoidAttention(AttentionVariant): + has_logits_transform = True + def __init__(self, scale=1.0, bias=0.0): + self.scale = scale * math.log2(math.exp(1.0)) + self.bias = bias * math.log2(math.exp(1.0)) + @cute.jit + def transform_logits(self, score): + return cute.arch.rcp_approx( + 1 + cute.arch.exp2(-(score * self.scale + self.bias))) + +Sigmoid attention via tanh (1 MUFU op/element):: + + class SigmoidTanhAttention(AttentionVariant): + has_logits_transform = True + def __init__(self, scale=1.0, bias=0.0): + self.half_scale = scale / 2.0 + self.half_bias = bias / 2.0 + @cute.jit + def transform_logits(self, score): + return 0.5 + 0.5 * tanh_approx( + score * self.half_scale + self.half_bias) + +ALiBi (score_mod with 1-D per-head slopes):: + + class ALiBiAttention(AttentionVariant): + has_score_mod = True + def __init__(self, alibi_slopes): + self._slopes = alibi_slopes + @property + def extra_params(self): + return self._slopes # (H,) + @cute.jit + def score_mod(self, score, batch_idx, qo_idx, kv_idx, + qo_head_idx, kv_head_idx): + return score + self.params[qo_head_idx] * (kv_idx - qo_idx) + +Soft-capping (compile-time scalars only, no extra_params):: + + class SoftCappingAttention(AttentionVariant): + has_score_mod = True + def __init__(self, cap=50.0): + self.cap = cap + self.rcp_cap = 1.0 / cap + @cute.jit + def score_mod(self, score, ...): + return self.cap * tanh_approx(score * self.rcp_cap) +""" + +import math +from typing import Any + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm +from cutlass.cute.typing import Float32 + + +@dsl_user_op +def tanh_approx(a, *, loc=None, ip=None): + """Hardware tanh via MUFU.TANH — single-cycle approximation (SM75+).""" + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + 
has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +class AttentionVariant: + """Base class for attention variants. Subclass to customize behavior. + + Set the class-level flags to enable compile-time dead code elimination + via ``cutlass.const_expr``. Only override the methods corresponding to + flags you set to ``True``. + + Attributes + ---------- + has_score_mod : bool + ``True`` when ``score_mod`` is overridden. Composes with both + standard softmax and custom ``transform_logits``. + has_logits_transform : bool + ``True`` when ``transform_logits`` is overridden. Replaces the + entire softmax machinery (row-max, exp2, row-sum, correction). + has_vectorized_logits_transform : bool + ``True`` when ``transform_logits_vec`` is overridden. Implies + ``has_logits_transform = True``. The kernel calls + ``transform_logits_vec`` instead of the per-element + ``transform_logits``, enabling stride-2 iteration and packed + f32x2 operations for higher throughput. + has_statistics_update : bool + ``True`` when ``update_statistics`` is overridden. + has_output_transform : bool + ``True`` when ``transform_output`` is overridden. + """ + + has_score_mod: bool = False + has_logits_transform: bool = False + has_vectorized_logits_transform: bool = False + has_statistics_update: bool = False + has_output_transform: bool = False + + params: Any = None + + @property + def extra_params(self): + """Return a PyTorch tensor of runtime data for JIT methods, or None. + + The tensor can be any shape. It is converted to a CuTe tensor and + bound to ``self.params`` before JIT methods are called. Inside + ``@cute.jit`` methods, access it as ``self.params[...]``. + + Override this in subclasses that need runtime tensor data. + """ + return None + + @cute.jit + def score_mod(self, score, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx): + """Element-wise modification of QK scores. 
+ + Composes with both standard softmax and custom ``transform_logits``. + The modified score feeds into whichever score-to-weight conversion + is active. + + Typical uses: ALiBi bias, relative positional encoding, soft-capping. + + Parameters + ---------- + score : float32 + Raw QK dot product for one (query, key) pair. + batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx : int32 + Position and head indices. + + Returns + ------- + float32 + Modified score. + """ + return score + + @cute.jit + def transform_logits(self, score): + """Coordinate-free activation replacing softmax. + + When overridden (with ``has_logits_transform = True``), this replaces + the standard ``exp2(score * scale - max * scale)`` computation and + the entire online softmax machinery (row-max, row-sum, correction). + + The kernel calls this in a tight loop over register elements without + coordinate lookups. Position-dependent modifications belong in + ``score_mod``, which runs before this and composes naturally. + + Parameters + ---------- + score : float32 + QK dot product (possibly modified by ``score_mod``). + + Returns + ------- + float32 + Transformed score (must be non-negative for correct accumulation). + """ + return score + + @cute.jit + def transform_logits_vec(self, scores): + """Vectorized logits transform over a register fragment. + + Optional performance override. When provided (with + ``has_vectorized_logits_transform = True``), the kernel calls this + instead of the per-element ``transform_logits``, enabling stride-2 + iteration and packed f32x2 operations. + + Like ``transform_logits``, this is coordinate-free: the fragment + contains raw scores (possibly already modified by ``score_mod``). + + **Why this exists:** The CuTe DSL compiler (as of v4.3.5) does not + auto-pack adjacent scalar ``fma.rn.f32`` instructions into + ``fma.rz.ftz.f32x2`` packed ALU ops. 
This override lets you + explicitly use ``cute.arch.fma_packed_f32x2`` for ~2x ALU + throughput on the arithmetic surrounding MUFU calls. Once the + compiler learns to pack scalar FMAs automatically, this override + will become unnecessary and can be deprecated. + + Parameters + ---------- + scores : cute.Tensor (register fragment) + Mutable fragment of scores. Modify elements in-place. + """ + pass + + @cute.jit + def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale): + """Modify online softmax running statistics before processing a KV tile. + + Called once per KV tile, before the tile's scores are loaded. Use to + inject virtual tokens into the softmax computation (e.g. attention + sink). + + Parameters + ---------- + kv_tile_idx : int32 + Index of the current KV tile (0-based). + qo_head_idx : int32 + Query/output head index. + m : float32 + Current row maximum (raw logit domain, **not** scaled). + d : float32 + Current row sum of exponentiated scores. + scale : float32 + ``log2(e) * sm_scale``, the base-2 scaling factor. + + Returns + ------- + tuple[float32, float32] + ``(m_new, d_new)`` — updated running statistics. + """ + return m, d + + @cute.jit + def transform_output(self, output, batch_idx, qo_idx, qo_head_idx, m, rcp_d, scale): + """Element-wise transform on final output values. + + Called once per output element after all KV tiles are processed, + replacing the default ``output *= scale_output / d`` normalization. + + Parameters + ---------- + output : float32 + Accumulated attention output value (unnormalized). + batch_idx, qo_idx, qo_head_idx : int32 + Batch, position, and head indices. + m : float32 + Final row maximum (raw logit domain). + rcp_d : float32 + ``1.0 / row_sum`` (reciprocal of softmax denominator). + scale : float32 + Output scaling factor (``scale_output``, typically ``1.0``). + + Returns + ------- + float32 + Transformed output value. 
+ """ + return output * rcp_d + + +class StandardAttention(AttentionVariant): + """Standard softmax attention — no customization.""" + + pass + + +class AttentionWithSink(AttentionVariant): + """Attention with a virtual sink token per head. + + Adds a learnable per-head bias to the softmax denominator on the first + KV tile, preventing attention collapse in long sequences. + + ``update_statistics`` injects the sink into the running ``(m, d)`` + and ``transform_output`` normalises with the modified denominator. + + Parameters + ---------- + sink : torch.Tensor + 1-D tensor of shape ``(num_qo_heads,)`` with per-head sink values. + + Usage:: + + wrapper.plan(..., variant=AttentionWithSink(sink_tensor)) + o = wrapper.run(q, k, v) + """ + + has_statistics_update = True + has_output_transform = True + + def __init__(self, sink): + self._sink = sink + + @property + def extra_params(self): + return self._sink + + @cute.jit + def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale): + # Guard: on non-first tiles, return (m, d) unchanged. Computing + # with sink_raw = -inf when m is also -inf (initial state on split-KV + # CTAs that don't own tile 0) would produce -inf - (-inf) = NaN. + m_new = m + d_new = d + if kv_tile_idx == 0: + log2_e = math.log2(math.exp(1.0)) + sink_raw = self.params[qo_head_idx] * log2_e / scale + m_new = sink_raw if sink_raw > m else m + rescale = cute.arch.exp2((m - m_new) * scale) + d_new = cute.arch.exp2((sink_raw - m_new) * scale) + d * rescale + return m_new, d_new + + @cute.jit + def transform_output(self, output, batch_idx, qo_idx, qo_head_idx, m, rcp_d, scale): + return output * scale * rcp_d + + +class SigmoidAttention(AttentionVariant): + """Sigmoid logits transform — replaces softmax with element-wise sigmoid. + + Uses ``rcp_approx(1 + exp2(-x))`` (2 MUFU ops per element). + For a faster variant using tanh (1 MUFU op), see ``SigmoidTanhAttention``. 
+ + Composes with ``score_mod`` — set both ``has_score_mod`` and + ``has_logits_transform`` on a subclass to combine position-dependent + score modification with sigmoid activation. + + Parameters + ---------- + scale : float + Multiplicative scale applied to QK scores before sigmoid. + Typically ``1.0`` (or ``sm_scale`` if you want to bake it in). + bias : float + Additive bias applied to scaled scores before sigmoid. + + Usage:: + + wrapper.plan(..., sm_scale=1.0, variant=SigmoidAttention(scale=1.0)) + """ + + has_logits_transform = True + has_vectorized_logits_transform = True + + def __init__(self, scale: float = 1.0, bias: float = 0.0): + self.scale = scale * math.log2(math.exp(1.0)) + self.bias = bias * math.log2(math.exp(1.0)) + + @cute.jit + def transform_logits(self, score): + return cute.arch.rcp_approx( + 1 + cute.arch.exp2(-(score * self.scale + self.bias)) + ) + + @cute.jit + def transform_logits_vec(self, scores): + for i in cutlass.range_constexpr(0, cute.size(scores), 2): + scores[i] = cute.arch.rcp_approx( + 1 + cute.arch.exp2(-(scores[i] * self.scale + self.bias)) + ) + scores[i + 1] = cute.arch.rcp_approx( + 1 + cute.arch.exp2(-(scores[i + 1] * self.scale + self.bias)) + ) + + +class SigmoidTanhAttention(AttentionVariant): + """Sigmoid via tanh — MUFU-efficient sigmoid that replaces softmax. + + Uses the identity ``sigmoid(x) = 0.5 + 0.5 * tanh(x / 2)`` to replace + the exp2 + rcp_approx pair (2 MUFU ops/element) with a single tanh_approx + (1 MUFU op/element), matching softmax's MUFU budget. + + Parameters + ---------- + scale : float + Multiplicative scale applied to QK scores before sigmoid. + Typically ``1.0`` (or ``sm_scale`` if you want to bake it in). + bias : float + Additive bias applied to scaled scores before sigmoid. 
class SigmoidTanhAttention(AttentionVariant):
    """Sigmoid via tanh — MUFU-efficient sigmoid that replaces softmax.

    Uses the identity ``sigmoid(x) = 0.5 + 0.5 * tanh(x / 2)`` to replace
    the exp2 + rcp_approx pair (2 MUFU ops/element) with a single tanh_approx
    (1 MUFU op/element), matching softmax's MUFU budget.

    Parameters
    ----------
    scale : float
        Multiplicative scale applied to QK scores before sigmoid.
        Typically ``1.0`` (or ``sm_scale`` if you want to bake it in).
    bias : float
        Additive bias applied to scaled scores before sigmoid.

    Usage::

        wrapper.plan(..., sm_scale=1.0, variant=SigmoidTanhAttention(scale=1.0))
    """

    has_logits_transform = True
    has_vectorized_logits_transform = True

    def __init__(self, scale: float = 1.0, bias: float = 0.0):
        # Halved because the sigmoid identity feeds x/2 into tanh; no
        # log2(e) factor here since tanh_approx works in natural space.
        self.half_scale = scale / 2.0
        self.half_bias = bias / 2.0

    @cute.jit
    def transform_logits(self, score):
        return 0.5 + 0.5 * tanh_approx(score * self.half_scale + self.half_bias)

    @cute.jit
    def transform_logits_vec(self, scores):
        # Packed f32x2 FMA processes two elements per instruction; assumes an
        # even element count — TODO confirm.
        for i in cutlass.range_constexpr(0, cute.size(scores), 2):
            # x' = x * scale/2 + bias/2 (pairwise)
            scores[i], scores[i + 1] = cute.arch.fma_packed_f32x2(
                (scores[i], scores[i + 1]),
                (self.half_scale, self.half_scale),
                (self.half_bias, self.half_bias),
            )
            scores[i] = tanh_approx(scores[i])
            scores[i + 1] = tanh_approx(scores[i + 1])
            # sigmoid = 0.5 * tanh + 0.5 (pairwise)
            scores[i], scores[i + 1] = cute.arch.fma_packed_f32x2(
                (scores[i], scores[i + 1]),
                (0.5, 0.5),
                (0.5, 0.5),
            )


class ALiBiAttention(AttentionVariant):
    """ALiBi (Attention with Linear Biases) — adds position-dependent bias.

    Adds ``slope * (kv_pos - qo_pos)`` to each QK score before softmax.
    This composes with the standard exp2 softmax path via ``score_mod``,
    so the kernel's online softmax and correction logic remain unchanged.

    Parameters
    ----------
    alibi_slopes : torch.Tensor
        1-D tensor of shape ``(num_qo_heads,)`` with per-head slopes.

    Usage::

        slopes = ALiBiAttention.get_slopes(num_heads).cuda()
        wrapper.plan(..., variant=ALiBiAttention(slopes))
        o = wrapper.run(q, k, v)

    Reference: https://arxiv.org/abs/2108.12409
    """

    has_score_mod = True

    def __init__(self, alibi_slopes):
        # Per-head slopes; surfaced to the kernel via extra_params.
        self._slopes = alibi_slopes

    @property
    def extra_params(self):
        return self._slopes

    @cute.jit
    def score_mod(self, score, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx):
        # Linear bias on the relative position; slope is indexed per QO head.
        return score + self.params[qo_head_idx] * (kv_idx - qo_idx)

    @staticmethod
    def get_slopes(num_heads: int):
        """Return the standard ALiBi slope schedule for ``num_heads`` heads.

        When ``num_heads`` is a power of 2, slopes are
        ``2^{-8/n}, 2^{-16/n}, ..., 2^{-8}`` where ``n = num_heads``.
        Otherwise, the nearest larger power-of-2 slopes are interpolated.
        """
        import torch

        def _get_slopes_power_of_2(n):
            # First slope is 2^(-8/n); the geometric ratio equals the first
            # slope, giving 2^(-8k/n) for k = 1..n.
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        if math.log2(num_heads).is_integer():
            return torch.tensor(_get_slopes_power_of_2(num_heads), dtype=torch.float32)
        else:
            # Non-power-of-2: take the schedule of the power of 2 below, then
            # fill the remainder with every other slope of the doubled schedule.
            closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
            slopes = _get_slopes_power_of_2(closest_power_of_2)
            extra = _get_slopes_power_of_2(2 * closest_power_of_2)
            slopes += extra[0::2][: num_heads - closest_power_of_2]
            return torch.tensor(slopes[:num_heads], dtype=torch.float32)
class RPEAttention(AttentionVariant):
    """Relative Positional Encoding via a learned per-head bias table.

    Adds ``rpe_table[head, clamp(kv - qo + max_rel_dist, 0, 2*max_rel_dist)]``
    to each QK score before softmax. Uses ``score_mod`` so it composes
    with the standard softmax path.

    Parameters
    ----------
    rpe_table : torch.Tensor
        2-D tensor of shape ``(num_qo_heads, 2 * max_rel_dist + 1)``.
    max_rel_dist : int
        Maximum relative distance. Positions beyond this are clamped.

    Usage::

        wrapper.plan(..., variant=RPEAttention(rpe_table, max_rel_dist=64))
        o = wrapper.run(q, k, v)
    """

    has_score_mod = True

    def __init__(self, rpe_table, max_rel_dist: int):
        # Table is surfaced to the kernel via extra_params; offset shifts
        # the signed relative position into the [0, 2*max_rel_dist] index range.
        self._rpe_table = rpe_table
        self._offset = max_rel_dist
        self._table_size = 2 * max_rel_dist + 1

    @property
    def extra_params(self):
        return self._rpe_table

    @cute.jit
    def score_mod(self, score, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx):
        # Clamp the shifted relative position to the table bounds
        # (branch-free conditional expressions for the DSL tracer).
        rel_pos = kv_idx - qo_idx + self._offset
        rel_pos_clamped = rel_pos if rel_pos >= 0 else 0
        rel_pos_clamped = (
            rel_pos_clamped
            if rel_pos_clamped < self._table_size
            else self._table_size - 1
        )
        return score + self.params[qo_head_idx, rel_pos_clamped]


class SoftCappingAttention(AttentionVariant):
    """Soft-capping — prevents logits from growing excessively large.

    Applies ``cap * tanh(score / cap)`` before softmax, bounding scores
    to ``[-cap, +cap]``. Uses ``score_mod`` so it composes with the
    standard softmax path.

    Parameters
    ----------
    cap : float
        Soft-capping value (e.g. 50.0 for Gemma-2).

    Usage::

        wrapper.plan(..., variant=SoftCappingAttention(cap=50.0))
        o = wrapper.run(q, k, v)

    Reference: Gemma-2 (https://arxiv.org/abs/2408.00118)
    """

    has_score_mod = True

    def __init__(self, cap: float = 50.0):
        # Precompute the reciprocal so the jit'd score_mod multiplies
        # instead of dividing per element.
        self.cap = cap
        self.rcp_cap = 1.0 / cap

    @cute.jit
    def score_mod(self, score, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx):
        return self.cap * tanh_approx(score * self.rcp_cap)
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""MainloopSpec — the unit of composition for attention kernels.

Analogous to C++ CUTLASS's CollectiveMainloop (e.g.
Sm100FmhaFwdMainloopTmaWarpspecialized), this bundles:
- PipelineTopology (which pipelines connect which warps)
- TmemLayout (TMEM allocation map)
- WarpSchedule (warp role assignment and register budgets)
- Stage counts and buffer sizes
"""

from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Dict, Union

from .config import AttentionConfig
from .mla_config import MLAConfig
from .tmem_layout import TmemLayout
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .mla_warp_schedule import (
    MLAWarpSchedule,
    MLA_DECODE_SCHEDULE,
    MLAWarpScheduleFP8,
    MLA_DECODE_FP8_SCHEDULE,
)
from .pipeline_topology import (
    PipelineTopology,
    make_prefill_topology,
    make_prefill_topology_transform,
    make_mla_topology,
    make_mla_fp8_topology,
)


@dataclass
class MainloopSpec:
    """One bundle of pipeline topology, TMEM layout, warp schedule, and stages.

    Python-side counterpart of a C++ CUTLASS CollectiveMainloop template:
    the kernel consumes a MainloopSpec and derives all pipelines, tensors,
    and warp dispatch from it.

    Stage counts that depend on the input dtype (kv_stages) are filled in
    by ``resolve(dtype_width)`` before the spec is used.
    """

    config: AttentionConfig
    warp_schedule: WarpSchedule
    tmem_layout: TmemLayout
    pipeline_topology: PipelineTopology

    has_logits_transform: bool = False

    q_stages: int = 2
    kv_stages: int = 3
    acc_stage: int = 1
    softmax_corr_stage: int = 1
    mma_corr_stage: int = 2
    mma_softmax_stage: int = 1
    epi_stage: int = 2
    buffer_align_bytes: int = 1024

    def resolve(self, dtype_width: int) -> MainloopSpec:
        """Return a copy of this spec with dtype-dependent stage counts fixed.

        Invoked once the input dtype is known (inside __call__), before
        SharedStorage or pipeline creation. ``self`` is left untouched.

        :param dtype_width: Bit width of the input element type.
        :returns: A new MainloopSpec with resolved kv_stages and pipeline_topology.
        """
        # FP8 inputs (8-bit) get one extra KV stage.
        resolved_kv_stages = 4 if dtype_width == 8 else 3
        if self.has_logits_transform:
            # Transform topology: no correction warp, hence no corr pipelines.
            resolved_topology = make_prefill_topology_transform(
                self.warp_schedule,
                q_stages=self.q_stages,
                kv_stages=resolved_kv_stages,
                mma_softmax_stages=self.mma_softmax_stage,
                epi_stages=self.epi_stage,
            )
        else:
            resolved_topology = make_prefill_topology(
                self.warp_schedule,
                q_stages=self.q_stages,
                kv_stages=resolved_kv_stages,
                mma_softmax_stages=self.mma_softmax_stage,
                softmax_corr_stages=self.softmax_corr_stage,
                mma_corr_stages=self.mma_corr_stage,
                epi_stages=self.epi_stage,
            )
        return replace(
            self,
            kv_stages=resolved_kv_stages,
            pipeline_topology=resolved_topology,
        )

    def barrier_stage_counts(self) -> Dict[str, int]:
        """Map each topology edge name to its barrier slot count.

        Used to size the barrier storage arrays in the SharedStorage struct.
        """
        return {
            edge.name: edge.barrier_stages
            for edge in self.pipeline_topology.edges
        }


def make_prefill_mainloop_spec(
    config: AttentionConfig,
    warp_schedule: WarpSchedule | None = None,
    has_logits_transform: bool = False,
) -> MainloopSpec:
    """Build a MainloopSpec for FMHA prefill.

    :param config: Core attention configuration.
    :param warp_schedule: Optional warp schedule override (defaults to PREFILL_SCHEDULE).
    :param has_logits_transform: If True, uses transform topology (no correction warp).
    """
    schedule = PREFILL_SCHEDULE if warp_schedule is None else warp_schedule
    topology_factory = (
        make_prefill_topology_transform
        if has_logits_transform
        else make_prefill_topology
    )
    return MainloopSpec(
        config=config,
        warp_schedule=schedule,
        tmem_layout=TmemLayout.from_config(config),
        pipeline_topology=topology_factory(schedule),
        has_logits_transform=has_logits_transform,
    )


# ---------------------------------------------------------------------------
# MLA Decode
# ---------------------------------------------------------------------------


@dataclass
class MLAMainloopSpec:
    """Pipeline topology + warp schedule bundle for MLA decode kernels.

    Counterpart of MainloopSpec built on MLAConfig and MLAWarpSchedule.
    MLA stage counts are fixed (not dtype-dependent), so ``resolve()`` is a
    no-op kept only so the kernel can call it uniformly.
    """

    config: MLAConfig
    warp_schedule: Union[MLAWarpSchedule, MLAWarpScheduleFP8]
    pipeline_topology: PipelineTopology

    buffer_align_bytes: int = 1024

    def resolve(self, dtype_width: int) -> MLAMainloopSpec:
        """No-op for MLA — stage counts are fixed; returns self unchanged."""
        return self

    def barrier_stage_counts(self) -> Dict[str, int]:
        """Map each topology edge name to its barrier slot count."""
        counts: Dict[str, int] = {}
        for edge in self.pipeline_topology.edges:
            counts[edge.name] = edge.barrier_stages
        return counts


def make_mla_mainloop_spec(
    config: MLAConfig,
    warp_schedule: MLAWarpSchedule | None = None,
) -> MLAMainloopSpec:
    """Build an MLAMainloopSpec for MLA decode.

    :param config: MLA kernel configuration.
    :param warp_schedule: Optional warp schedule override (defaults to MLA_DECODE_SCHEDULE).
    """
    schedule = MLA_DECODE_SCHEDULE if warp_schedule is None else warp_schedule
    topology = make_mla_topology(
        schedule,
        load_q_stages=config.load_q_stage,
        load_kv_stages=config.load_kv_stage,
        mma_s_stages=config.mma_s_stage,
        p_mma_stages=config.p_mma_stage,
        p_cor_stages=config.p_cor_stage,
        mma_o_stages=config.mma_o_stage,
        load_pt_stages=config.load_pt_stage,
        cluster_scale=config.cluster_shape_mnk[0],
    )
    return MLAMainloopSpec(
        config=config,
        warp_schedule=schedule,
        pipeline_topology=topology,
    )


def make_mla_fp8_mainloop_spec(
    config: MLAConfig,
    warp_schedule: MLAWarpScheduleFP8 | None = None,
) -> MLAMainloopSpec:
    """Build an MLAMainloopSpec for FP8 MLA decode.

    Uses the FP8-specific pipeline topology with separate load_k/load_v
    pipelines and no load_pt pipeline.

    :param config: MLA kernel configuration (must have is_fp8=True).
    :param warp_schedule: Optional warp schedule override (defaults to MLA_DECODE_FP8_SCHEDULE).
    """
    schedule = MLA_DECODE_FP8_SCHEDULE if warp_schedule is None else warp_schedule
    topology = make_mla_fp8_topology(
        schedule,
        load_q_stages=config.load_q_stage,
        load_k_stages=config.load_k_stage,
        load_v_stages=config.load_v_stage,
        mma_s_stages=config.mma_s_stage,
        p_mma_stages=config.p_mma_stage,
        p_cor_stages=config.p_cor_stage,
        mma_o_stages=config.mma_o_stage,
        cluster_scale=config.cluster_shape_mnk[0],
    )
    return MLAMainloopSpec(
        config=config,
        warp_schedule=schedule,
        pipeline_topology=topology,
    )
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""MLAConfig — configuration dataclass for Multi-Head Latent Attention decode kernels.

Separate concrete type from AttentionConfig, following the C++ CUTLASS pattern where
each mainloop variant has its own config type. The problem shapes, tile sizes, and
feature flags are fundamentally different between FMHA prefill and MLA decode.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Tuple, Type

if TYPE_CHECKING:
    import cutlass


@dataclass(frozen=True)
class MLAConfig:
    """Immutable configuration for MLA decode kernels.

    Collects everything the monolithic kernel took through __init__ and
    _setup_attributes into one object. The ``is_fp8`` switch selects the
    FP8 variant, which changes MMA tiler K-dims, pipeline stage counts,
    and register budgets via the derived properties below.
    """

    # --- Problem shape ---
    latent_dim: int = 512
    rope_dim: int = 64

    # --- Data types ---
    acc_dtype: Type[cutlass.Numeric] = None  # type: ignore[assignment]
    lse_dtype: Type[cutlass.Numeric] = None  # type: ignore[assignment]

    # --- MMA tile shapes (M, N only; K derived) ---
    mma_qk_tiler_mn: Tuple[int, int] = (128, 128)
    mma_pv_tiler_mn: Tuple[int, int] = (128, 256)

    # --- Execution parameters ---
    max_active_clusters: int = 1
    page_size: int = 1
    skip_correction_threshold: float = 0.0
    is_persistent: bool = False
    is_var_seq: bool = False
    is_var_split_kv: bool = False
    enable_pdl: bool = False

    # --- Cluster configuration ---
    cluster_shape_mnk: Tuple[int, int, int] = (2, 1, 1)
    use_2cta_instrs: bool = True
    warps_in_n: int = 2

    # FP8 flag — selects FP8-specific tiler, stages, and register budgets
    is_fp8: bool = False

    # Pipeline stage counts shared between FP16 and FP8
    load_q_stage: int = 1
    mma_s_stage: int = 2
    p_mma_stage: int = 2
    p_cor_stage: int = 2
    mma_o_stage: int = 1  # FP16: 1, FP8: 2

    # FP16-only stage counts (unified K+V pipeline plus page-table pipeline)
    load_kv_stage: int = 15
    load_pt_stage: int = 4

    # FP8-only stage counts (separate K and V pipelines, no page table)
    load_k_stage: int = 3
    load_v_stage: int = 2

    # --- Derived properties ---

    @property
    def mma_qk_tiler(self) -> Tuple[int, int, int]:
        """QK MMA tile (M, N, K). FP8 packs latent into a doubled K-dim."""
        tile_m, tile_n = self.mma_qk_tiler_mn
        tile_k = 2 * self.rope_dim if self.is_fp8 else self.rope_dim
        return (tile_m, tile_n, tile_k)

    @property
    def mma_qk_rope_tiler(self) -> Tuple[int, int, int]:
        """Rope-part QK tile — its K-dim is always rope_dim (never doubled)."""
        tile_m, tile_n = self.mma_qk_tiler_mn
        return (tile_m, tile_n, self.rope_dim)

    @property
    def mma_pv_tiler(self) -> Tuple[int, int, int]:
        """PV MMA tile; K-dim preserves the QK tile's N*K volume."""
        qk_n, qk_k = self.mma_qk_tiler[1], self.mma_qk_tiler[2]
        pv_m, pv_n = self.mma_pv_tiler_mn
        return (pv_m, pv_n, qk_n * qk_k // pv_n)

    @property
    def iterations_qk_latent(self) -> int:
        """Number of K-dim steps to cover the latent part of QK."""
        return self.latent_dim // self.mma_qk_tiler[2]

    @property
    def iterations_qk_rope(self) -> int:
        """Number of K-dim steps for the rope part (one for FP8's wide tile)."""
        if self.is_fp8:
            return 1
        return self.rope_dim // self.mma_qk_tiler[2]

    @property
    def iterations_qk(self) -> int:
        """Total QK mainloop iterations (latent + rope)."""
        return self.iterations_qk_latent + self.iterations_qk_rope

    @property
    def iterations_pv_k(self) -> int:
        """PV K-dim steps per QK tile column."""
        return self.mma_qk_tiler[1] // self.mma_pv_tiler[2]

    @property
    def iterations_pv_n(self) -> int:
        """PV N-dim steps to cover the latent output width."""
        return self.latent_dim // self.mma_pv_tiler[1]

    @property
    def tmem_o_offset(self) -> int:
        """TMEM column offset of the O accumulator, after the S stages."""
        return self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n

    @property
    def correction_factor_offset(self) -> int:
        """TMEM column offset of the correction factors, after O."""
        return self.tmem_o_offset + self.latent_dim // self.warps_in_n

    @property
    def num_compute_warps(self) -> int:
        return 4

    @property
    def per_iteration_mma_o(self) -> bool:
        """FP8 uses mma_o_stage=2 with per-iteration pipeline wait/release."""
        return self.mma_o_stage > 1

    @property
    def correction_reg_num(self) -> int:
        """Register budget for correction warps (FP8 gets more)."""
        return 256 if self.is_fp8 else 208

    @property
    def other_reg_num(self) -> int:
        """Register budget for the remaining (lightweight) warps."""
        return 48 if self.is_fp8 else 96

    @staticmethod
    def can_implement(
        B: int,
        S: int,
        K: int,
        H: int,
        L: int,
        R: int,
        in_dtype,
        out_dtype,
        acc_dtype,
        lse_dtype,
        mma_qk_tiler_mn: Tuple[int, int],
        mma_pv_tiler_mn: Tuple[int, int],
        split_kv: int,
        is_persistent: bool,
        is_var_seq: bool,
        is_var_split_kv: bool,
        page_size: int,
    ) -> bool:
        """Return True iff the FP16/BF16 MLA kernel supports this problem.

        B/S/K/H/L/R: batch, q tokens per seq, kv length, heads, latent dim,
        rope dim.
        """
        import cutlass as ctl

        if L != 512 or R != 64:
            return False
        if in_dtype not in (ctl.Float16, ctl.BFloat16):
            return False
        if out_dtype not in (ctl.Float16, ctl.BFloat16):
            return False
        if acc_dtype != ctl.Float32 or lse_dtype != ctl.Float32:
            return False
        # NOTE(review): page_size == 1 is rejected here although MLAConfig
        # defaults page_size=1 — confirm this guard is intended.
        if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1:
            return False
        if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128:
            return False
        if is_var_split_kv and not is_var_seq:
            return False
        # Fewer than 128 heads forbids split-KV.
        if H > 128 or (H < 128 and split_kv != 1):
            return False
        if S < 1 or S > 4:
            return False
        if K <= 0:
            return False
        return True

    @staticmethod
    def can_implement_fp8(
        B: int,
        S: int,
        K: int,
        H: int,
        L: int,
        R: int,
        in_dtype,
        out_dtype,
        acc_dtype,
        lse_dtype,
        mma_qk_tiler_mn: Tuple[int, int],
        mma_pv_tiler_mn: Tuple[int, int],
        split_kv: int,
        is_persistent: bool,
        is_var_seq: bool,
        is_var_split_kv: bool,
        page_size: int,
    ) -> bool:
        """Return True iff the FP8 MLA kernel supports this problem."""
        import cutlass as ctl

        if L != 512 or R != 64:
            return False
        if in_dtype not in (ctl.Float8E4M3FN,):
            return False
        if out_dtype not in (ctl.Float8E4M3FN, ctl.BFloat16):
            return False
        if acc_dtype != ctl.Float32 or lse_dtype != ctl.Float32:
            return False
        # NOTE(review): same page_size == 1 rejection as the FP16 path.
        if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1:
            return False
        if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128:
            return False
        if is_var_split_kv and not is_var_seq:
            return False
        if H > 128 or (H < 128 and split_kv != 1):
            return False
        if S <= 0 or S > 4:
            return False
        if K <= 0:
            return False
        return True
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Modular MLA decode kernel — composes role-based building blocks.

This is the top-level kernel that wires together the modular MLA building blocks
(config, schedule, mainloop spec, collective builder, roles) into a launchable
attention kernel. It follows the same pattern as the FMHA prefill kernel in
prefill.py, but for Multi-Head Latent Attention decode with paged KV cache.
"""

from typing import Type, Tuple, Optional
from types import SimpleNamespace

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.cpasync as cpasync
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait

from .mla_config import MLAConfig
from .config import AttentionFusion
from .mla_warp_schedule import MLAWarpSchedule, MLA_DECODE_SCHEDULE
from .mainloop_spec import make_mla_mainloop_spec
from .collective_builder import build_mla_launch_params
from .roles.mla_pt_loader import MLAPageTableLoaderRole
from .roles.mla_loader import MLALoaderRole
from .roles.mla_mma import MLAMmaRole
from .roles.mla_compute import MLAComputeRole
from .roles.mla_correction import MLACorrectionRole
from .scheduler.mla_persistent import (
    LOG2_E,
    MAX_SPLITS,
    MLAStaticTileScheduler,
    MLAStaticTileSchedulerParams,
    create_mla_static_tile_scheduler_params,
    mla_get_split_kv,
    mla_get_split_kv_simplified,
    mla_get_workspace_size,
)

import warnings

warnings.filterwarnings(
    "ignore",
    message="This loop is no longer unrolled and may cause performance regression",
)


from .compat import (
    setmaxregister_decrease as _setmaxregister_decrease,
    setmaxregister_increase as _setmaxregister_increase,
    get_max_tmem_alloc_cols as _get_max_tmem_alloc_cols,
)


class BlackwellMultiLatentAttentionForward:
    """Modular MLA decode kernel composing role-based building blocks.

    Follows the same compositional pattern as BlackwellFusedMultiHeadAttentionForward
    in prefill.py, but for Multi-Head Latent Attention decode with paged KV cache
    and split-KV reduction.
    """

    def __init__(
        self,
        config: MLAConfig,
        fusion: AttentionFusion | None = None,
        schedule: MLAWarpSchedule | None = None,
    ):
        self.config = config
        self.fusion = fusion if fusion is not None else AttentionFusion()
        self.schedule = schedule if schedule is not None else MLA_DECODE_SCHEDULE
        self.mainloop = make_mla_mainloop_spec(config, self.schedule)
        # Named barriers for cross-warp synchronization: tmem pointer
        # handoff, softmax exchange, and epilogue exchange.
        (
            self.tmem_ptr_sync_bar,
            self.softmax_exchange_sync_bar,
            self.epilogue_exchange_sync_bar,
        ) = self.schedule.make_named_barriers()

    @cute.jit
    def __call__(
        self,
        q_latent: cute.Tensor,
        q_rope: cute.Tensor,
        c_latent: cute.Tensor,
        c_rope: cute.Tensor,
        page_table: cute.Tensor,
        o: cute.Tensor,
        lse: cute.Tensor,
        workspace: cute.Tensor,
        split_kv: cutlass.Int32,
        cache_seqs: Optional[cute.Tensor],
        block_split_kvs: Optional[cute.Tensor],
        softmax_scale: cutlass.Float32,
        output_scale: cutlass.Float32,
        params_in: Optional[cute.Tensor],
        stream,
    ):
        """Host-side entry: builds launch params and launches the kernel(s).

        Reinterprets the user-facing tensor layouts into kernel-facing ones,
        resolves the mainloop spec against the input dtype, instantiates the
        warp roles, launches split_kv_kernel and (when a split-KV workspace
        was allocated) the reduction kernel.
        """
        # Q/K/V must share one element type; checked at trace time.
        self.q_dtype: Type[cutlass.Numeric] = q_latent.element_type
        self.k_dtype: Type[cutlass.Numeric] = c_latent.element_type
        self.v_dtype: Type[cutlass.Numeric] = c_latent.element_type
        self.o_dtype: Type[cutlass.Numeric] = o.element_type

        if cutlass.const_expr(
            self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype
        ):
            raise TypeError(
                f"Type mismatch: {self.q_dtype} != {self.k_dtype} "
                f"or {self.q_dtype} != {self.v_dtype}"
            )

        # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B]
        def _reinterpret_4d(t):
            return cute.make_tensor(
                t.iterator,
                cute.make_layout(
                    (t.shape[2], t.shape[3], t.shape[1], t.shape[0]),
                    stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]),
                ),
            )

        q_latent = _reinterpret_4d(q_latent)
        q_rope = _reinterpret_4d(q_rope)
        o = _reinterpret_4d(o)

        # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages]
        def _reinterpret_3d_kv(t):
            return cute.make_tensor(
                t.iterator,
                cute.make_layout(
                    (t.shape[1], t.shape[2], t.shape[0]),
                    stride=(t.stride[1], t.stride[2], t.stride[0]),
                ),
            )

        c_latent = _reinterpret_3d_kv(c_latent)
        c_rope = _reinterpret_3d_kv(c_rope)

        # Reinterpret contiguous [B, page_count] as [page_count, B]
        page_table = cute.make_tensor(
            page_table.iterator,
            cute.make_layout(
                (page_table.shape[1], page_table.shape[0]),
                stride=(page_table.stride[1], page_table.stride[0]),
            ),
        )

        # Reinterpret contiguous [B, S_q, H] as [H, S_q, B]
        lse = cute.make_tensor(
            lse.iterator,
            cute.make_layout(
                (lse.shape[2], lse.shape[1], lse.shape[0]),
                stride=(lse.stride[2], lse.stride[1], lse.stride[0]),
            ),
        )

        # Carve per-split accumulators (acc_o, acc_lse) out of the workspace.
        # initialize_workspace is defined elsewhere in this class.
        acc_o, acc_lse = self.initialize_workspace(
            q_latent.shape[0],
            q_latent.shape[1],
            q_latent.shape[2],
            q_latent.shape[3],
            split_kv,
            self.config.acc_dtype,
            workspace,
        )

        # Transposed latent view for the PV MMA (swap modes 0 and 1).
        c_latent_transpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2])
        c_latent_transpose = cute.make_tensor(
            c_latent.iterator, c_latent_transpose_layout
        )

        # Fix dtype-dependent stage counts now that the input dtype is known
        # (a no-op for the MLA spec, kept for interface uniformity).
        self.mainloop = self.mainloop.resolve(self.q_dtype.width)

        # Optional fusion parameter tensor, reshaped per the fusion spec.
        params = (
            cute.make_tensor(
                params_in.iterator,
                cute.make_layout(
                    self.fusion.params_shape,
                    stride=self.fusion.params_strides,
                ),
            )
            if cutlass.const_expr(self.fusion.has_params)
            else None
        )

        # Instantiate one role object per warp group.
        self.pt_loader_role = MLAPageTableLoaderRole(self.config)
        self.loader_role = MLALoaderRole(self.config)
        self.mma_role = MLAMmaRole(self.config, self.mainloop)
        self.compute_role = MLAComputeRole(self.config, fusion=self.fusion)
        self.compute_role.set_dtypes(self.q_dtype)
        self.compute_role.set_barriers(self.softmax_exchange_sync_bar)
        self.correction_role = MLACorrectionRole(
            self.config,
            fusion=self.fusion,
            v_dtype=self.v_dtype,
            o_dtype=self.o_dtype,
        )
        self.correction_role.set_barriers(self.epilogue_exchange_sync_bar)

        # Collective builder produces TMA atoms/tensors, SMEM layouts,
        # tiled MMAs, and the SharedStorage type.
        lp = build_mla_launch_params(
            self.mainloop,
            self.schedule,
            q_latent,
            q_rope,
            c_latent,
            c_rope,
            c_latent_transpose,
            page_table,
            o,
            lse,
            acc_o,
            acc_lse,
            self.q_dtype,
            self.k_dtype,
            self.v_dtype,
            self.o_dtype,
        )
        self.shared_storage = lp.SharedStorage
        self.tma_copy_q_bytes = lp.tma_copy_q_bytes
        self.tma_copy_kc_bytes = lp.tma_copy_kc_bytes

        tile_sched_params, grid = self._compute_grid(
            o,
            split_kv,
            self.config.cluster_shape_mnk,
            self.config.max_active_clusters,
            self.config.is_persistent,
        )

        # Kernel works in exp2 space, so fold log2(e) into the scale.
        softmax_scale_log2 = softmax_scale * LOG2_E
        self.split_kv_kernel(
            lp.qk_tiled_mma,
            lp.pv_tiled_mma,
            lp.tma_atom_q_latent,
            lp.tma_tensor_q_latent,
            lp.tma_atom_q_rope,
            lp.tma_tensor_q_rope,
            lp.tma_atom_c_latent,
            lp.tma_tensor_c_latent,
            lp.tma_atom_c_rope,
            lp.tma_tensor_c_rope,
            lp.tma_atom_c_latent_transpose,
            lp.tma_tensor_c_latent_transpose,
            page_table,
            o,
            lse,
            acc_o,
            acc_lse,
            split_kv,
            cache_seqs,
            block_split_kvs,
            softmax_scale_log2,
            output_scale,
            lp.q_latent_smem_layout_staged,
            lp.q_rope_smem_layout_staged,
            lp.kc_smem_layout_staged,
            lp.p_smem_layout_staged,
            lp.vc_smem_layout_staged,
            lp.kc_smem_layout_for_tma,
            lp.vc_smem_layout_for_tma,
            lp.cta_layout_vmnk,
            tile_sched_params,
            lp.SharedStorage,
            params,
        ).launch(
            grid=grid,
            block=[self.schedule.threads_per_cta, 1, 1],
            cluster=self.config.cluster_shape_mnk,
            smem=lp.SharedStorage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
            use_pdl=self.config.enable_pdl,
        )
        # Second pass: combine per-split partial results (only when a
        # split-KV workspace was materialized).
        if cutlass.const_expr(acc_o is not None):
            self.reduction_kernel(
                o,
                lse,
                acc_o,
                acc_lse,
                split_kv,
                cache_seqs,
                block_split_kvs,
            ).launch(
                grid=(
                    q_latent.shape[0],
                    q_latent.shape[2],
                    q_latent.shape[3],
                ),
                block=[
                    self.schedule.threads_per_warp * self.config.num_compute_warps,
                    1,
                    1,
                ],
                smem=MAX_SPLITS * self.config.acc_dtype.width // 8,
                stream=stream,
                min_blocks_per_mp=1,
                use_pdl=self.config.enable_pdl,
            )

    @cute.jit
    def _create_pipelines(self, storage, cta_layout_vmnk):
        """Create all inter-warp pipelines from the topology and storage barriers."""
        barrier_ptrs = {
            edge.name: getattr(storage, edge.barrier_field_name).data_ptr()
            for edge in self.mainloop.pipeline_topology.edges
        }
        # TMA transaction byte counts per producer edge (set in __call__).
        tx_counts = {"q": self.tma_copy_q_bytes, "kv": self.tma_copy_kc_bytes}
        return self.mainloop.pipeline_topology.create_pipelines(
            barrier_ptrs,
            tx_counts,
            self.schedule.threads_per_warp,
            cta_layout_vmnk,
        )

    @cute.kernel
    def split_kv_kernel(
        self,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        tma_atom_q_latent: Optional[cute.CopyAtom],
        mQL: cute.Tensor,
        tma_atom_q_rope: Optional[cute.CopyAtom],
        mQR: cute.Tensor,
        tma_atom_c_latent: Optional[cute.CopyAtom],
        mCL: cute.Tensor,
        tma_atom_c_rope: Optional[cute.CopyAtom],
        mKR: cute.Tensor,
        tma_atom_c_latent_transpose: Optional[cute.CopyAtom],
        mCLT: cute.Tensor,
        mPT: cute.Tensor,
        mO: Optional[cute.Tensor],
        mLSE: Optional[cute.Tensor],
        mAccO: Optional[cute.Tensor],
        mAccLSE: Optional[cute.Tensor],
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        softmax_scale_log2: cutlass.Float32,
        output_scale: cutlass.Float32,
        q_latent_smem_layout_staged: cute.ComposedLayout,
        q_rope_smem_layout_staged: cute.ComposedLayout,
        kc_smem_layout_staged: cute.ComposedLayout,
        p_smem_layout_staged: cute.ComposedLayout,
        vc_smem_layout_staged: cute.ComposedLayout,
        kc_smem_layout_for_tma: cute.ComposedLayout,
        vc_smem_layout_for_tma: cute.ComposedLayout,
        cta_layout_vmnk: cute.Layout,
        tile_sched_params: MLAStaticTileSchedulerParams,
        SharedStorage: cutlass.Constexpr,
        params: Optional[cute.Tensor] = None,
    ):
        """Device kernel: warp-specialized split-KV MLA decode mainloop.

        Each warp group (page-table loader, TMA loader, MMA, compute,
        correction, plus empty filler warps) is dispatched by warp index to
        its role object; pipelines created from the mainloop topology connect
        them.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        bidx, _, _ = cute.arch.block_idx()
        # With 2-CTA MMA instructions the even CTA in the pair leads.
        mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape)
        is_leader_cta = mma_tile_coord_v == 0

        if warp_idx == self.schedule.mma_warp_id:
            # Prefetch TMA descriptors early from the MMA warp.
            cpasync.prefetch_descriptor(tma_atom_q_latent)
            cpasync.prefetch_descriptor(tma_atom_q_rope)
            cpasync.prefetch_descriptor(tma_atom_c_latent)
            cpasync.prefetch_descriptor(tma_atom_c_rope)
            cpasync.prefetch_descriptor(tma_atom_c_latent_transpose)

        smem = utils.SmemAllocator()
        storage = smem.allocate(SharedStorage)

        tmem = utils.TmemAllocator(
            storage.tmem_holding_buf,
            barrier_for_retrieve=self.tmem_ptr_sync_bar,
            allocator_warp_id=self.schedule.mma_warp_id,
            is_two_cta=self.config.use_2cta_instrs,
            two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
        )

        # Producer/consumer pipeline pairs, keyed by topology edge name.
        pipes = self._create_pipelines(storage, cta_layout_vmnk)
        load_q_prod, load_q_cons = pipes["load_q"]
        load_kv_prod, load_kv_cons = pipes["load_kv"]
        mma_s_prod, mma_s_cons = pipes["mma_s"]
        p_mma_prod, p_mma_cons = pipes["p_mma"]
        p_cor_prod, p_cor_cons = pipes["p_cor"]
        mma_o_prod, mma_o_cons = pipes["mma_o"]
        load_pt_prod, load_pt_cons = pipes["load_pt"]

        pipeline_init_arrive(
            cluster_shape_mn=self.config.cluster_shape_mnk,
            is_relaxed=True,
        )

        # SMEM tensor views
        sQ = storage.smem_q_latent.get_tensor(
            q_latent_smem_layout_staged.outer,
            swizzle=q_latent_smem_layout_staged.inner,
        )
        sQ_rope = storage.smem_q_rope.get_tensor(
            q_rope_smem_layout_staged.outer,
            swizzle=q_rope_smem_layout_staged.inner,
        )
        sKC = storage.smem_kc.get_tensor(
            kc_smem_layout_staged.outer,
            swizzle=kc_smem_layout_staged.inner,
        )
        sKC_for_tma = storage.smem_kc.get_tensor(
            kc_smem_layout_for_tma.outer,
            swizzle=kc_smem_layout_for_tma.inner,
        )
        # V view aliases the KC buffer through a recast pointer (the unified
        # K+V pipeline shares one SMEM allocation).
        sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner)
        sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer)
        sVC_for_tma = cute.make_tensor(sVC_ptr, vc_smem_layout_for_tma.outer)
        sP = storage.smem_p.get_tensor(
            p_smem_layout_staged.outer,
            swizzle=p_smem_layout_staged.inner,
        )
        sPT = storage.smem_page_table.get_tensor(
            cute.make_layout(
                (self.config.mma_qk_tiler[1] // 2, self.config.load_pt_stage)
            )
        )
        softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor(
            cute.make_layout(
                self.config.num_compute_warps * self.schedule.threads_per_warp
            )
        )
        epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor(
            cute.make_layout(
                self.config.num_compute_warps * self.schedule.threads_per_warp
            )
        )

        pipeline_init_wait(cluster_shape_mn=self.config.cluster_shape_mnk)

        if cutlass.const_expr(self.config.enable_pdl):
            cute.arch.griddepcontrol_wait()

        # /////////////////////////////////////////////////////////////////////
        # Empty warps
        # /////////////////////////////////////////////////////////////////////
        if (
            warp_idx >= self.schedule.empty_warp_ids[0]
            and warp_idx <= self.schedule.empty_warp_ids[-1]
        ):
            _setmaxregister_decrease(self.schedule.other_reg_num)

        # /////////////////////////////////////////////////////////////////////
        # Page table loader warp
        # /////////////////////////////////////////////////////////////////////
        if warp_idx == self.schedule.load_pt_warp_id:
            _setmaxregister_decrease(self.schedule.other_reg_num)
            self.pt_loader_role.run(
                split_kv,
                cache_seqs,
                block_split_kvs,
                load_pt_prod,
                mPT,
                sPT,
                tile_sched_params,
            )

        # /////////////////////////////////////////////////////////////////////
        # TMA loader warp
        # /////////////////////////////////////////////////////////////////////
        if warp_idx == self.schedule.load_tma_warp_id:
            _setmaxregister_decrease(self.schedule.other_reg_num)
            tma_common_params = SimpleNamespace(
                mPT=mPT,
                sPT=sPT,
            )
            tma_qk_params = SimpleNamespace(
                tiled_mma_qk=tiled_mma_qk,
                tma_atom_q_latent=tma_atom_q_latent,
                tma_atom_q_rope=tma_atom_q_rope,
                tma_atom_c_latent=tma_atom_c_latent,
                tma_atom_c_rope=tma_atom_c_rope,
                mQL=mQL,
                mQR=mQR,
                mCL=mCL,
                mKR=mKR,
                sQ=sQ,
                sQ_rope=sQ_rope,
                sKC=sKC_for_tma,
            )
            tma_v_params = SimpleNamespace(
                tiled_mma_pv=tiled_mma_pv,
                tma_atom_c_latent_transpose=tma_atom_c_latent_transpose,
                mCL=mCL,
                mKR=mKR,
                mCLT=mCLT,
                sVC=sVC_for_tma,
            )
            self.loader_role.run(
                tma_common_params,
                tma_qk_params,
                tma_v_params,
                split_kv,
                cache_seqs,
                block_split_kvs,
                load_q_prod,
                load_kv_prod,
                load_pt_cons,
                tile_sched_params,
            )

        # /////////////////////////////////////////////////////////////////////
        # MMA warp
        # /////////////////////////////////////////////////////////////////////
        if warp_idx == self.schedule.mma_warp_id:
            _setmaxregister_decrease(self.schedule.other_reg_num)
            # The MMA warp owns the TMEM allocation for the whole CTA.
            tmem.allocate(_get_max_tmem_alloc_cols("sm_100"))
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.config.acc_dtype)

            self.mma_role.run(
                tiled_mma_qk,
                tiled_mma_pv,
                load_q_cons,
                load_kv_cons,
                mma_s_prod,
                p_mma_cons,
                mma_o_prod,
                split_kv,
                cache_seqs,
                block_split_kvs,
                tile_sched_params,
                sQ,
                sQ_rope,
                sKC,
                sP,
                sVC,
                tmem_ptr,
                is_leader_cta,
                mCL.shape[1],
            )

            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
            if cutlass.const_expr(self.config.enable_pdl):
                cute.arch.griddepcontrol_launch_dependents()

        # /////////////////////////////////////////////////////////////////////
        # Compute (softmax) warps
        # /////////////////////////////////////////////////////////////////////
        if (
            warp_idx >= self.schedule.compute_warp_ids[0]
            and warp_idx <= self.schedule.compute_warp_ids[-1]
        ):
            self.compute_role.run(
                split_kv,
                cache_seqs,
                block_split_kvs,
                tile_sched_params,
                tmem_ptr=None,
                mma_s_consumer=mma_s_cons,
                p_mma_producer=p_mma_prod,
                p_cor_producer=p_cor_prod,
                softmax_smem_exchange=softmax_smem_exchange,
                mAccO=mAccO,
                mO=mO,
                mCL=mCL,
                K=None,
                L=mCL.shape[1],
                tiled_mma_qk=tiled_mma_qk,
                sP=sP,
                softmax_scale_log2=softmax_scale_log2,
                tmem=tmem,
                params=params,
            )

        # /////////////////////////////////////////////////////////////////////
        # Correction (rescale + epilogue) warps
        # /////////////////////////////////////////////////////////////////////
        if (
            warp_idx >= self.schedule.correction_warp_ids[0]
            and warp_idx <= self.schedule.correction_warp_ids[-1]
        ):
            _setmaxregister_increase(self.schedule.correction_reg_num)
            tmem.wait_for_alloc()
            tmem_ptr_corr = tmem.retrieve_ptr(self.config.acc_dtype)

            # Row offset of this CTA's slice within the 2-CTA MMA tile.
            cta_m_offset = (bidx % cute.size(tiled_mma_qk.thr_id.shape)) * (
                self.config.mma_qk_tiler[0] // self.config.cluster_shape_mnk[0]
            )
            corr_common_params = SimpleNamespace(
                smem_exchange=epilogue_smem_exchange,
                mAccO=mAccO,
                mO=mO,
                L=mCL.shape[1],
                H=mQL.shape[0],
                cta_m_offset=cta_m_offset,
            )
            corr_epilogue_params = SimpleNamespace(
                output_scale=output_scale,
                softmax_scale_log2=softmax_scale_log2,
                mAccLSE=mAccLSE,
                mLSE=mLSE,
            )
            self.correction_role.run(
                split_kv,
                cache_seqs,
                block_split_kvs,
                tile_sched_params,
                tmem_ptr_corr,
                p_cor_consumer=p_cor_cons,
                mma_o_consumer=mma_o_cons,
                compute_common_params=corr_common_params,
                epilogue_params=corr_epilogue_params,
                params=params,
            )

        return

    @cute.kernel
    def reduction_kernel(
        self,
        mO: cute.Tensor,
        mLSE: cute.Tensor,
        mAccO: cute.Tensor,
        mAccLSE: cute.Tensor,
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
    ):
        """Reduction kernel that combines intermediate results from split-KV blocks."""
        bidx, bidy, bidz = cute.arch.block_idx()
        tidx, _, _ = cute.arch.thread_idx()
        blk_coord = (bidx, bidy, bidz)
        local_split_kv = (
            block_split_kvs[blk_coord[2]] if self.config.is_var_split_kv else split_kv
        )
        # NOTE: the remainder of this kernel continues beyond this chunk.
        k_tile_total = 
cute.ceil_div( + cache_seqs[blk_coord[2]], self.config.mma_qk_tiler[1] + ) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.config.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.config.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_wait() + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.schedule.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.config.lse_dtype + ) + lse_max = -self.config.lse_dtype.inf + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.schedule.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.config.lse_dtype.inf + ) + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.config.lse_dtype.inf else 0.0 + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.config.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 + else self.config.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.schedule.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + 
local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.config.latent_dim, + self.schedule.threads_per_warp * self.config.num_compute_warps, + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.config.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = ( + tidx + + j * self.schedule.threads_per_warp * self.config.num_compute_warps + ) + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = ( + tidx + + j * self.schedule.threads_per_warp * self.config.num_compute_warps + ) + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + return + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + S: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize workspace tensors acc_o and acc_lse for split-KV.""" + acc_o, 
acc_lse = None, None + if cutlass.const_expr(workspace is not None): + align = 256 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (H, split_kv, D, S, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(split_kv * H * D, align), + cute.assume(H * split_kv * S * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (H, split_kv, S, B), + stride=(split_kv, 1, H * split_kv, H * split_kv * S), + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def get_split_kv( + B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + return mla_get_split_kv(B, S, K, mma_qk_tiler_mn, max_active_blocks) + + @staticmethod + def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: + return mla_get_split_kv_simplified(B, S, max_active_blocks) + + @staticmethod + def get_workspace_size( + H: int, + S: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + return mla_get_workspace_size(H, S, D, B, split_kv, acc_dtype.width) + + @staticmethod + def can_implement( + B: int, + S: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + ) -> bool: + return MLAConfig.can_implement( + B, + S, + K, + H, + L, + R, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + split_kv, + is_persistent, + is_var_seq, + 
is_var_split_kv, + page_size, + ) diff --git a/flashinfer/cute_dsl/attention/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/mla_decode_fp8.py new file mode 100644 index 0000000000..6e66f7e391 --- /dev/null +++ b/flashinfer/cute_dsl/attention/mla_decode_fp8.py @@ -0,0 +1,847 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Modular FP8 MLA decode kernel — composes role-based building blocks. + +This is the FP8 variant of the MLA decode kernel. It differs from the FP16 +variant (mla_decode.py) in warp assignments, pipeline topology, and MMA loop +structure, while sharing compute (softmax) and correction roles. + +Key structural differences from FP16: +- Two TMA loader warps (K and V) instead of TMA + page-table loaders +- Separate load_k and load_v pipelines instead of unified load_kv + load_pt +- Separate SMEM buffers for KC-latent, KC-rope, and VC (no aliasing) +- mma_o_stage=2 with per-iteration pipeline wait/release +- QK rope tiler K-dim doubled (128 vs 64), iterations_qk_rope=1 + +AI-assisted port from monolithic mla_decode_fp8.py to the modular framework. 
+""" + +from typing import Type, Tuple, Optional +from types import SimpleNamespace + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from .mla_config import MLAConfig +from .config import AttentionFusion +from .mla_warp_schedule import MLAWarpScheduleFP8, MLA_DECODE_FP8_SCHEDULE +from .mainloop_spec import make_mla_fp8_mainloop_spec +from .collective_builder import build_mla_fp8_launch_params +from .roles.mla_loader_fp8 import MLAFP8LoaderKRole, MLAFP8LoaderVRole +from .roles.mla_mma_fp8 import MLAMmaFP8Role +from .roles.mla_compute import MLAComputeRole +from .roles.mla_correction import MLACorrectionRole +from .scheduler.mla_persistent import ( + LOG2_E, + MAX_SPLITS, + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler_params, + mla_get_split_kv, + mla_get_split_kv_simplified, + mla_get_workspace_size, +) + +import warnings + +warnings.filterwarnings( + "ignore", + message="This loop is no longer unrolled and may cause performance regression", +) + +from .compat import ( + setmaxregister_decrease as _setmaxregister_decrease, + setmaxregister_increase as _setmaxregister_increase, + get_max_tmem_alloc_cols as _get_max_tmem_alloc_cols, +) + + +class BlackwellMultiLatentAttentionForwardFP8: + """Modular FP8 MLA decode kernel composing role-based building blocks. + + Follows the same compositional pattern as the FP16 MLA decode kernel + but uses FP8-specific roles, pipeline topology, and warp schedule. 
+ """ + + def __init__( + self, + config: MLAConfig, + fusion: AttentionFusion | None = None, + schedule: MLAWarpScheduleFP8 | None = None, + ): + self.config = config + self.fusion = fusion if fusion is not None else AttentionFusion() + self.schedule = schedule if schedule is not None else MLA_DECODE_FP8_SCHEDULE + self.mainloop = make_mla_fp8_mainloop_spec(config, self.schedule) + ( + self.tmem_ptr_sync_bar, + self.softmax_exchange_sync_bar, + self.epilogue_exchange_sync_bar, + ) = self.schedule.make_named_barriers() + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + params_in: Optional[cute.Tensor], + stream, + ): + self.q_dtype: Type[cutlass.Numeric] = q_latent.element_type + self.k_dtype: Type[cutlass.Numeric] = c_latent.element_type + self.v_dtype: Type[cutlass.Numeric] = c_latent.element_type + self.o_dtype: Type[cutlass.Numeric] = o.element_type + + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} " + f"or {self.q_dtype} != {self.v_dtype}" + ) + + def _reinterpret_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), + stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), + ), + ) + + q_latent = _reinterpret_4d(q_latent) + q_rope = _reinterpret_4d(q_rope) + o = _reinterpret_4d(o) + + def _reinterpret_3d_kv(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[1], t.shape[2], t.shape[0]), + stride=(t.stride[1], t.stride[2], t.stride[0]), + ), + ) + + c_latent = _reinterpret_3d_kv(c_latent) + c_rope = 
_reinterpret_3d_kv(c_rope) + + page_table = cute.make_tensor( + page_table.iterator, + cute.make_layout( + (page_table.shape[1], page_table.shape[0]), + stride=(page_table.stride[1], page_table.stride[0]), + ), + ) + + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[2], lse.shape[1], lse.shape[0]), + stride=(lse.stride[2], lse.stride[1], lse.stride[0]), + ), + ) + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + q_latent.shape[3], + split_kv, + self.config.acc_dtype, + workspace, + ) + + c_latent_transpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_transpose_layout + ) + + self.mainloop = self.mainloop.resolve(self.q_dtype.width) + + params = ( + cute.make_tensor( + params_in.iterator, + cute.make_layout( + self.fusion.params_shape, + stride=self.fusion.params_strides, + ), + ) + if cutlass.const_expr(self.fusion.has_params) + else None + ) + + self.loader_k_role = MLAFP8LoaderKRole(self.config) + self.loader_v_role = MLAFP8LoaderVRole(self.config) + self.mma_role = MLAMmaFP8Role(self.config, self.mainloop) + self.compute_role = MLAComputeRole(self.config, fusion=self.fusion) + self.compute_role.set_dtypes(self.q_dtype) + self.compute_role.set_barriers(self.softmax_exchange_sync_bar) + self.correction_role = MLACorrectionRole( + self.config, fusion=self.fusion, v_dtype=self.v_dtype, o_dtype=self.o_dtype + ) + self.correction_role.set_barriers(self.epilogue_exchange_sync_bar) + + lp = build_mla_fp8_launch_params( + self.mainloop, + self.schedule, + q_latent, + q_rope, + c_latent, + c_rope, + c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + self.q_dtype, + self.k_dtype, + self.v_dtype, + self.o_dtype, + ) + self.shared_storage = lp.SharedStorage + self.tma_copy_q_bytes = lp.tma_copy_q_bytes + self.tma_copy_kc_bytes = lp.tma_copy_kc_bytes + self.tma_copy_vc_bytes = 
lp.tma_copy_vc_bytes + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.config.cluster_shape_mnk, + self.config.max_active_clusters, + self.config.is_persistent, + ) + + softmax_scale_log2 = softmax_scale * LOG2_E + self.split_kv_kernel( + lp.qk_tiled_mma, + lp.pv_tiled_mma, + lp.tma_atom_q_latent, + lp.tma_tensor_q_latent, + lp.tma_atom_q_rope, + lp.tma_tensor_q_rope, + lp.tma_atom_c_latent, + lp.tma_tensor_c_latent, + lp.tma_atom_c_rope, + lp.tma_tensor_c_rope, + lp.tma_atom_c_latent_transpose, + lp.tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + lp.q_latent_smem_layout_staged, + lp.q_rope_smem_layout_staged, + lp.kc_latent_smem_layout_staged, + lp.kc_rope_smem_layout_staged, + lp.p_smem_layout_staged, + lp.vc_smem_layout_staged, + lp.kc_latent_smem_layout_for_tma, + lp.kc_rope_smem_layout_for_tma, + lp.vc_smem_layout_for_tma, + lp.cta_layout_vmnk, + tile_sched_params, + lp.SharedStorage, + params, + ).launch( + grid=grid, + block=[self.schedule.threads_per_cta, 1, 1], + cluster=self.config.cluster_shape_mnk, + smem=lp.SharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.config.enable_pdl, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), + block=[ + self.schedule.threads_per_warp * self.config.num_compute_warps, + 1, + 1, + ], + smem=MAX_SPLITS * self.config.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.config.enable_pdl, + ) + + @cute.jit + def _create_pipelines(self, storage, cta_layout_vmnk): + barrier_ptrs = { + edge.name: getattr(storage, edge.barrier_field_name).data_ptr() + for edge in self.mainloop.pipeline_topology.edges + } + tx_counts = { + "q": self.tma_copy_q_bytes, + "kv": 
self.tma_copy_kc_bytes, + "vc": self.tma_copy_vc_bytes, + } + return self.mainloop.pipeline_topology.create_pipelines( + barrier_ptrs, + tx_counts, + self.schedule.threads_per_warp, + cta_layout_vmnk, + ) + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_latent_smem_layout_staged: cute.ComposedLayout, + q_rope_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_staged: cute.ComposedLayout, + kc_rope_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_for_tma: cute.ComposedLayout, + kc_rope_smem_layout_for_tma: cute.ComposedLayout, + vc_smem_layout_for_tma: cute.ComposedLayout, + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + params: Optional[cute.Tensor] = None, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + if warp_idx == self.schedule.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + 
cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_bar, + allocator_warp_id=self.schedule.mma_warp_id, + is_two_cta=self.config.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + pipes = self._create_pipelines(storage, cta_layout_vmnk) + load_q_prod, load_q_cons = pipes["load_q"] + load_k_prod, load_k_cons = pipes["load_k"] + load_v_prod, load_v_cons = pipes["load_v"] + mma_s_prod, mma_s_cons = pipes["mma_s"] + p_mma_prod, p_mma_cons = pipes["p_mma"] + p_cor_prod, p_cor_cons = pipes["p_cor"] + mma_o_prod, mma_o_cons = pipes["mma_o"] + + pipeline_init_arrive( + cluster_shape_mn=self.config.cluster_shape_mnk, + is_relaxed=True, + ) + + # SMEM tensor views + sQ = storage.smem_q_latent.get_tensor( + q_latent_smem_layout_staged.outer, + swizzle=q_latent_smem_layout_staged.inner, + ) + sQ_rope = storage.smem_q_rope.get_tensor( + q_rope_smem_layout_staged.outer, + swizzle=q_rope_smem_layout_staged.inner, + ) + sKC = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_staged.outer, + swizzle=kc_latent_smem_layout_staged.inner, + ) + sKC_rope = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_staged.outer, + swizzle=kc_rope_smem_layout_staged.inner, + ) + sKC_for_tma = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_for_tma.outer, + swizzle=kc_latent_smem_layout_for_tma.inner, + ) + sKC_rope_for_tma = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_for_tma.outer, + swizzle=kc_rope_smem_layout_for_tma.inner, + ) + sVC = storage.smem_vc.get_tensor( + vc_smem_layout_staged.outer, + swizzle=vc_smem_layout_staged.inner, + ) + sVC_for_tma = storage.smem_vc.get_tensor( + vc_smem_layout_for_tma.outer, + swizzle=vc_smem_layout_for_tma.inner, + 
) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, + swizzle=p_smem_layout_staged.inner, + ) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout( + self.config.num_compute_warps * self.schedule.threads_per_warp + ) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout( + self.config.num_compute_warps * self.schedule.threads_per_warp + ) + ) + + pipeline_init_wait(cluster_shape_mn=self.config.cluster_shape_mnk) + + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_wait() + + # ///////////////////////////////////////////////////////////////////// + # Empty warps + # ///////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.schedule.empty_warp_ids[0] + and warp_idx <= self.schedule.empty_warp_ids[-1] + ): + _setmaxregister_decrease(self.schedule.other_reg_num) + + # ///////////////////////////////////////////////////////////////////// + # K loader warp (Q + K loading) + # ///////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.load_tma_k_warp_id: + _setmaxregister_decrease(self.schedule.other_reg_num) + tma_common_params = SimpleNamespace(mPT=mPT) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC_for_tma, + sKC_rope=sKC_rope_for_tma, + ) + self.loader_k_role.run( + tma_common_params, + tma_qk_params, + split_kv, + cache_seqs, + block_split_kvs, + load_q_prod, + load_k_prod, + tile_sched_params, + ) + + # ///////////////////////////////////////////////////////////////////// + # V loader warp + # ///////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.load_tma_v_warp_id: + 
_setmaxregister_decrease(self.schedule.other_reg_num) + tma_common_params = SimpleNamespace(mPT=mPT) + tma_v_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCLT=mCLT, + sVC=sVC_for_tma, + ) + self.loader_v_role.run( + tma_common_params, + tma_v_params, + split_kv, + cache_seqs, + block_split_kvs, + load_v_prod, + tile_sched_params, + ) + + # ///////////////////////////////////////////////////////////////////// + # MMA warp + # ///////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.mma_warp_id: + _setmaxregister_decrease(self.schedule.other_reg_num) + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.config.acc_dtype) + + self.mma_role.run( + tiled_mma_qk, + tiled_mma_pv, + load_q_cons, + load_k_cons, + load_v_cons, + mma_s_prod, + p_mma_cons, + mma_o_prod, + split_kv, + cache_seqs, + block_split_kvs, + tile_sched_params, + sQ, + sQ_rope, + sKC, + sKC_rope, + sP, + sVC, + tmem_ptr, + is_leader_cta, + mCL.shape[1], + ) + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + # ///////////////////////////////////////////////////////////////////// + # Compute (softmax) warps + # ///////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.schedule.compute_warp_ids[0] + and warp_idx <= self.schedule.compute_warp_ids[-1] + ): + # Fresh TiledMma avoids SSA dominance conflict with the MMA warp's + # .set(ACCUMULATE) mutations on the same tiled_mma_qk variable. 
+ compute_tiled_mma_qk = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.config.acc_dtype, + tcgen05.CtaGroup.TWO, + self.config.mma_qk_tiler[:2], + ) + self.compute_role.run( + split_kv, + cache_seqs, + block_split_kvs, + tile_sched_params, + tmem_ptr=None, + mma_s_consumer=mma_s_cons, + p_mma_producer=p_mma_prod, + p_cor_producer=p_cor_prod, + softmax_smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + mCL=mCL, + K=None, + L=mCL.shape[1], + tiled_mma_qk=compute_tiled_mma_qk, + sP=sP, + softmax_scale_log2=softmax_scale_log2, + tmem=tmem, + params=params, + ) + + # ///////////////////////////////////////////////////////////////////// + # Correction (rescale + epilogue) warps + # ///////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.schedule.correction_warp_ids[0] + and warp_idx <= self.schedule.correction_warp_ids[-1] + ): + _setmaxregister_increase(self.schedule.correction_reg_num) + tmem.wait_for_alloc() + tmem_ptr_corr = tmem.retrieve_ptr(self.config.acc_dtype) + + cta_m_offset = (bidx % cute.size(tiled_mma_qk.thr_id.shape)) * ( + self.config.mma_qk_tiler[0] // self.config.cluster_shape_mnk[0] + ) + corr_common_params = SimpleNamespace( + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + L=mCL.shape[1], + H=mQL.shape[0], + cta_m_offset=cta_m_offset, + ) + corr_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + self.correction_role.run( + split_kv, + cache_seqs, + block_split_kvs, + tile_sched_params, + tmem_ptr_corr, + p_cor_consumer=p_cor_cons, + mma_o_consumer=mma_o_cons, + compute_common_params=corr_common_params, + epilogue_params=corr_epilogue_params, + params=params, + ) + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: 
cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """Reduction kernel — identical to FP16 version.""" + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, bidy, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.config.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div( + cache_seqs[blk_coord[2]], self.config.mma_qk_tiler[1] + ) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.config.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.config.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_wait() + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.schedule.threads_per_warp) + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.config.lse_dtype + ) + lse_max = -self.config.lse_dtype.inf + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.schedule.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.config.lse_dtype.inf + ) + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.config.lse_dtype.inf else 0.0 + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.config.lse_dtype(0.0) or sum_lse 
!= sum_lse # noqa: SIM201 + else self.config.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.schedule.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.config.latent_dim, + self.schedule.threads_per_warp * self.config.num_compute_warps, + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.config.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = ( + tidx + + j * self.schedule.threads_per_warp * self.config.num_compute_warps + ) + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = ( + tidx + + j * self.schedule.threads_per_warp * self.config.num_compute_warps + ) + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.config.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + return + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return 
        tile_sched_params, grid

    @cute.jit
    def initialize_workspace(
        self,
        H: cutlass.Int32,
        D: cutlass.Int32,
        S: cutlass.Int32,
        B: cutlass.Int32,
        split_kv: cutlass.Int32,
        acc_dtype: Type[cutlass.Numeric],
        workspace: cute.Tensor,
    ) -> tuple[cute.Tensor, cute.Tensor]:
        """Initialize workspace tensors acc_o and acc_lse for split-KV.

        Carves the caller-provided ``workspace`` into two views: a partial
        output accumulator ``acc_o`` with shape ``(H, split_kv, D, S, B)``
        placed at the start of the workspace, and a partial LSE accumulator
        ``acc_lse`` with shape ``(H, split_kv, S, B)`` placed immediately
        after it (byte offset ``cosize(acc_o_layout) * acc_dtype.width // 8``).
        Returns ``(None, None)`` when no workspace tensor was supplied.
        """
        acc_o, acc_lse = None, None
        if cutlass.const_expr(workspace is not None):
            # Elements of the Q dtype spanning 256 bits; passed to cute.assume
            # on the strides — presumably a divisibility/alignment hint for the
            # layout algebra (NOTE(review): confirm cute.assume semantics).
            align = 256 // self.q_dtype.width
            acc_o_layout = cute.make_layout(
                (H, split_kv, D, S, B),
                stride=(
                    cute.assume(split_kv * D, align),
                    cute.assume(D, align),
                    1,
                    cute.assume(split_kv * H * D, align),
                    cute.assume(H * split_kv * S * D, align),
                ),
            )
            # Reinterpret the raw workspace pointer as the accumulation dtype.
            acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype)
            acc_o = cute.make_tensor(acc_o_iter, acc_o_layout)
            acc_lse_layout = cute.make_layout(
                (H, split_kv, S, B),
                stride=(split_kv, 1, H * split_kv, H * split_kv * S),
            )
            # LSE region starts right after the acc_o region (offset in bytes).
            acc_lse_iter = cute.recast_ptr(
                workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8,
                dtype=acc_dtype,
            )
            acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout)
        return acc_o, acc_lse

    @staticmethod
    def get_split_kv(
        B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int
    ) -> int:
        # Thin wrapper over the module-level heuristic.
        return mla_get_split_kv(B, S, K, mma_qk_tiler_mn, max_active_blocks)

    @staticmethod
    def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int:
        # Simplified variant that does not consider K or the QK tile shape.
        return mla_get_split_kv_simplified(B, S, max_active_blocks)

    @staticmethod
    def get_workspace_size(
        H: int,
        S: int,
        D: int,
        B: int,
        split_kv: int,
        acc_dtype: Type[cutlass.Numeric],
    ) -> int:
        # Byte count must match the two regions laid out in initialize_workspace.
        return mla_get_workspace_size(H, S, D, B, split_kv, acc_dtype.width)

    @staticmethod
    def can_implement(
        B: int,
        S: int,
        K: int,
        H: int,
        L: int,
        R: int,
        in_dtype: Type[cutlass.Numeric],
        out_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        lse_dtype: Type[cutlass.Numeric],
        mma_qk_tiler_mn: Tuple[int,
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""MLAWarpSchedule — warp role assignment and register budgets for MLA decode.

Separate concrete type from WarpSchedule (FMHA prefill). The MLA decode kernel
uses 12 warps with a fundamentally different role layout:
- 4 compute warps (softmax + exchange)
- 4 correction warps (rescale + epilogue)
- 1 MMA warp
- 1 TMA load warp (FP8: one K-loader and one V-loader warp)
- 1 page table load warp (FP16/BF16 only)
- 1 empty warp
"""

from dataclasses import dataclass
from typing import Tuple


class _MLAWarpScheduleBase:
    """Derived properties and barrier construction shared by MLA warp schedules.

    Subclasses are frozen dataclasses declaring the warp-id fields, register
    budgets, ``threads_per_warp``, the named-barrier IDs, and the
    ``all_warp_ids`` property; everything derivable from those lives here once
    so the FP16 and FP8 schedules cannot drift apart.
    """

    @property
    def num_warps(self) -> int:
        # Total warps in the CTA (12 for both concrete schedules).
        return len(self.all_warp_ids)

    @property
    def threads_per_cta(self) -> int:
        return self.threads_per_warp * self.num_warps

    @property
    def num_compute_warps(self) -> int:
        return len(self.compute_warp_ids)

    def make_named_barriers(self) -> "Tuple[pipeline.NamedBarrier, ...]":
        """Create the named barriers used by the MLA decode kernel.

        Returns (tmem_ptr_sync_bar, softmax_exchange_sync_bar,
        epilogue_exchange_sync_bar).
        """
        # Imported locally so this module stays importable (for inspection or
        # testing) without the CuTe DSL runtime installed.
        import cutlass.pipeline as pipeline

        n_compute = self.num_compute_warps
        tpw = self.threads_per_warp

        # MMA warp + compute warps + correction warps synchronize TMEM pointer
        tmem_ptr_sync = pipeline.NamedBarrier(
            barrier_id=self.tmem_ptr_sync_bar_id,
            num_threads=tpw + tpw * n_compute * 2,
        )
        # Compute warps exchange row-max during softmax
        softmax_exchange = pipeline.NamedBarrier(
            barrier_id=self.softmax_exchange_bar_id,
            num_threads=tpw * n_compute,
        )
        # Correction warps exchange row-sum during epilogue
        epilogue_exchange = pipeline.NamedBarrier(
            barrier_id=self.epilogue_exchange_bar_id,
            num_threads=tpw * n_compute,
        )
        return tmem_ptr_sync, softmax_exchange, epilogue_exchange


@dataclass(frozen=True)
class MLAWarpSchedule(_MLAWarpScheduleBase):
    """Warp role assignment and register budgets for MLA decode kernels."""

    compute_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    correction_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    mma_warp_id: int = 8
    load_tma_warp_id: int = 9
    load_pt_warp_id: int = 10
    empty_warp_ids: Tuple[int, ...] = (11,)

    softmax_reg_num: int = 192
    correction_reg_num: int = 208
    other_reg_num: int = 96

    threads_per_warp: int = 32

    # Named barrier IDs
    tmem_ptr_sync_bar_id: int = 1
    softmax_exchange_bar_id: int = 2
    epilogue_exchange_bar_id: int = 3

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (
            *self.compute_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            self.load_tma_warp_id,
            self.load_pt_warp_id,
            *self.empty_warp_ids,
        )


MLA_DECODE_SCHEDULE = MLAWarpSchedule()


@dataclass(frozen=True)
class MLAWarpScheduleFP8(_MLAWarpScheduleBase):
    """Warp role assignment and register budgets for FP8 MLA decode kernels.

    FP8 replaces the page-table loader warp with a second TMA loader warp
    (separate K and V loading), eliminating the load_pt pipeline entirely.
    """

    compute_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    correction_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    mma_warp_id: int = 8
    load_tma_k_warp_id: int = 9
    load_tma_v_warp_id: int = 10
    empty_warp_ids: Tuple[int, ...] = (11,)

    softmax_reg_num: int = 192
    correction_reg_num: int = 256
    other_reg_num: int = 48

    threads_per_warp: int = 32

    # Named barrier IDs (same as FP16)
    tmem_ptr_sync_bar_id: int = 1
    softmax_exchange_bar_id: int = 2
    epilogue_exchange_bar_id: int = 3

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        return (
            *self.compute_warp_ids,
            *self.correction_warp_ids,
            self.mma_warp_id,
            self.load_tma_k_warp_id,
            self.load_tma_v_warp_id,
            *self.empty_warp_ids,
        )


MLA_DECODE_FP8_SCHEDULE = MLAWarpScheduleFP8()
+ + Thread count rules per type: + - TMA_UMMA: leader-only (len(warps)) for both producer and consumer + - UMMA_ASYNC: leader-only for producer, all-threads for consumer + - ASYNC_UMMA: all-threads for producer, leader-only for consumer (reverse of UMMA_ASYNC) + - ASYNC: all-threads for both producer and consumer + - CP_ASYNC: all-threads for both (cpasync-based, no cta_layout_vmnk/tx_count) + """ + + TMA_UMMA = "PipelineTmaUmma" + UMMA_ASYNC = "PipelineUmmaAsync" + ASYNC_UMMA = "PipelineAsyncUmma" + ASYNC = "PipelineAsync" + CP_ASYNC = "PipelineCpAsync" + + @property + def cutlass_type(self): + _map = { + PipelineType.TMA_UMMA: pipeline.PipelineTmaUmma, + PipelineType.UMMA_ASYNC: pipeline.PipelineUmmaAsync, + PipelineType.ASYNC_UMMA: pipeline.PipelineAsyncUmma, + PipelineType.ASYNC: pipeline.PipelineAsync, + PipelineType.CP_ASYNC: pipeline.PipelineCpAsync, + } + return _map[self] + + @property + def needs_cta_layout(self) -> bool: + """Whether this pipeline type accepts cta_layout_vmnk for multi-CTA clusters.""" + return self in ( + PipelineType.TMA_UMMA, + PipelineType.UMMA_ASYNC, + PipelineType.ASYNC_UMMA, + ) + + def producer_thread_count(self, num_warps: int, threads_per_warp: int) -> int: + if self in (PipelineType.TMA_UMMA, PipelineType.UMMA_ASYNC): + return num_warps + return threads_per_warp * num_warps + + def consumer_thread_count(self, num_warps: int, threads_per_warp: int) -> int: + if self in (PipelineType.TMA_UMMA, PipelineType.ASYNC_UMMA): + return num_warps + return threads_per_warp * num_warps + + +@dataclass(frozen=True) +class PipelineEdge: + """Describes a single pipeline in the topology. + + Each edge connects a producer role to a consumer role with a specific + pipeline type and stage count. + + When cluster_scale > 1, the all-thread side of + UMMA_ASYNC / ASYNC_UMMA pipelines multiplies its thread count by + cluster_scale. TMA_UMMA pipelines are unaffected (leader-only on both sides). 
+ """ + + name: str + pipeline_type: PipelineType + stages: int + producer_warp_ids: Tuple[int, ...] + consumer_warp_ids: Tuple[int, ...] + tx_count_key: str | None = None + cluster_scale: int = 1 + + @property + def barrier_field_name(self) -> str: + return f"{self.name}_mbar_ptr" + + @property + def barrier_stages(self) -> int: + """Number of barrier slots needed. PipelineAsync uses 2x stages for the + phase bit; others also use 2x for their internal bookkeeping.""" + if self.pipeline_type == PipelineType.ASYNC: + return self.stages * 2 + return self.stages * 2 + + +@dataclass +class PipelineTopology: + """Declarative specification of the pipeline graph for an attention kernel. + + Contains all the PipelineEdge definitions. A factory method creates + all pipeline participants from barrier storage and tx_count values. + """ + + edges: List[PipelineEdge] = field(default_factory=list) + + def edge_names(self) -> List[str]: + return [e.name for e in self.edges] + + def get_edge(self, name: str) -> PipelineEdge: + for e in self.edges: + if e.name == name: + return e + raise KeyError(f"No pipeline edge named '{name}'") + + def create_pipelines( + self, + barrier_ptrs: Dict[str, Any], + tx_counts: Dict[str, int], + threads_per_warp: int, + cta_layout_vmnk: Any = None, + ) -> Dict[str, Tuple[PipelineProducer, PipelineConsumer]]: + """Create all pipeline producer/consumer pairs from the topology. + + :param barrier_ptrs: Map from edge name to barrier storage pointer. + :param tx_counts: Map from tx_count_key to byte count (for TMA pipelines). + :param threads_per_warp: Threads per warp (typically 32). + :param cta_layout_vmnk: CTA layout for multi-CTA clusters (None for single-CTA). + Required by TMA_UMMA, UMMA_ASYNC, and ASYNC_UMMA when cluster_shape > 1. + :returns: Dict mapping edge name to (producer, consumer) tuple. 
+ """ + result = {} + for edge in self.edges: + prod_threads = edge.pipeline_type.producer_thread_count( + len(edge.producer_warp_ids), threads_per_warp + ) + cons_threads = edge.pipeline_type.consumer_thread_count( + len(edge.consumer_warp_ids), threads_per_warp + ) + + # Apply cluster_scale to the all-threads side of asymmetric pipelines + if edge.cluster_scale > 1: + pt = edge.pipeline_type + if pt == PipelineType.UMMA_ASYNC: + cons_threads *= edge.cluster_scale + elif pt == PipelineType.ASYNC_UMMA: + prod_threads *= edge.cluster_scale + elif pt in (PipelineType.ASYNC, PipelineType.CP_ASYNC): + prod_threads *= edge.cluster_scale + cons_threads *= edge.cluster_scale + + create_kwargs = { + "barrier_storage": barrier_ptrs[edge.name], + "num_stages": edge.stages, + "producer_group": CooperativeGroup(Agent.Thread, prod_threads), + "consumer_group": CooperativeGroup(Agent.Thread, cons_threads), + "defer_sync": True, + } + if edge.tx_count_key is not None: + create_kwargs["tx_count"] = tx_counts[edge.tx_count_key] + if cta_layout_vmnk is not None and edge.pipeline_type.needs_cta_layout: + create_kwargs["cta_layout_vmnk"] = cta_layout_vmnk + + pipe = edge.pipeline_type.cutlass_type.create(**create_kwargs) + result[edge.name] = pipe.make_participants() + return result + + +def make_prefill_topology( + schedule: WarpSchedule, + q_stages: int = 2, + kv_stages: int = 3, + mma_softmax_stages: int = 1, + softmax_corr_stages: int = 1, + mma_corr_stages: int = 2, + epi_stages: int = 2, +) -> PipelineTopology: + """Build the pipeline topology for FMHA prefill. + + The prefill kernel has 9 pipelines connecting 6 warp roles:: + + Load --[load_q]--> MMA --[mma_s0]--> Softmax0 --[s0_corr]--> Correction --[corr_epi]--> Epilogue + --[load_kv]-> --[mma_s1]--> Softmax1 --[s1_corr]--> + --[mma_corr]------------------------------> + Softmax0 --[s0_s1_seq]--> Softmax1 + + :param schedule: Warp schedule defining warp role assignments. 
def make_prefill_topology(
    schedule: "WarpSchedule",
    q_stages: int = 2,
    kv_stages: int = 3,
    mma_softmax_stages: int = 1,
    softmax_corr_stages: int = 1,
    mma_corr_stages: int = 2,
    epi_stages: int = 2,
) -> PipelineTopology:
    """Pipeline graph for the FMHA prefill kernel.

    Nine pipelines connect the six warp roles::

        Load --[load_q]--> MMA --[mma_s0]--> Softmax0 --[s0_corr]--> Correction --[corr_epi]--> Epilogue
             --[load_kv]->     --[mma_s1]--> Softmax1 --[s1_corr]-->
                               --[mma_corr]------------------------------>
        Softmax0 --[s0_s1_seq]--> Softmax1

    :param schedule: Warp schedule defining warp role assignments.
    :param kv_stages: Stage count for KV pipeline (3 for fp16/bf16, 4 for fp8).
    """
    loader = (schedule.load_warp_id,)
    mma_warp = (schedule.mma_warp_id,)
    softmax0 = schedule.softmax0_warp_ids
    softmax1 = schedule.softmax1_warp_ids
    correction = schedule.correction_warp_ids
    epilogue = (schedule.epilogue_warp_id,)

    graph = [
        PipelineEdge(
            name="load_q",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=q_stages,
            producer_warp_ids=loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="q",
        ),
        PipelineEdge(
            name="load_kv",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=kv_stages,
            producer_warp_ids=loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="kv",
        ),
        PipelineEdge(
            name="mma_s0",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_softmax_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=softmax0,
        ),
        PipelineEdge(
            name="mma_s1",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_softmax_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=softmax1,
        ),
        PipelineEdge(
            name="s0_corr",
            pipeline_type=PipelineType.ASYNC,
            stages=softmax_corr_stages,
            producer_warp_ids=softmax0,
            consumer_warp_ids=correction,
        ),
        PipelineEdge(
            name="s1_corr",
            pipeline_type=PipelineType.ASYNC,
            stages=softmax_corr_stages,
            producer_warp_ids=softmax1,
            consumer_warp_ids=correction,
        ),
        PipelineEdge(
            name="corr_epi",
            pipeline_type=PipelineType.ASYNC,
            stages=epi_stages,
            producer_warp_ids=correction,
            consumer_warp_ids=epilogue,
        ),
        PipelineEdge(
            name="mma_corr",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_corr_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=correction,
        ),
        # Softmax1 must wait for softmax0's row-max update before processing
        # its KV tile — online softmax requires sequential row-max propagation
        # between the two softmax warpgroups.
        PipelineEdge(
            name="s0_s1_sequence",
            pipeline_type=PipelineType.ASYNC,
            stages=1,
            producer_warp_ids=softmax0,
            consumer_warp_ids=softmax1,
        ),
    ]
    return PipelineTopology(edges=graph)


def make_prefill_topology_transform(
    schedule: "WarpSchedule",
    q_stages: int = 2,
    kv_stages: int = 3,
    mma_softmax_stages: int = 1,
    epi_stages: int = 2,
) -> PipelineTopology:
    """Pipeline graph for FMHA prefill with logits_transform variants.

    No correction warp: softmax warps perform the epilog (TMEM->scale->SMEM)
    after their KV loop, then signal the epilogue warp directly.

    Seven pipelines connect the five warp roles::

        Load --[load_q]--> MMA --[mma_s0]--> Softmax0 --[s0_epi]--> Epilogue
             --[load_kv]->     --[mma_s1]--> Softmax1 --[s1_epi]-->
        Softmax0 --[s0_s1_seq]--> Softmax1

    :param schedule: Warp schedule defining warp role assignments.
    :param kv_stages: Stage count for KV pipeline (3 for fp16/bf16, 4 for fp8).
    """
    loader = (schedule.load_warp_id,)
    mma_warp = (schedule.mma_warp_id,)
    softmax0 = schedule.softmax0_warp_ids
    softmax1 = schedule.softmax1_warp_ids
    epilogue = (schedule.epilogue_warp_id,)

    graph = [
        PipelineEdge(
            name="load_q",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=q_stages,
            producer_warp_ids=loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="q",
        ),
        PipelineEdge(
            name="load_kv",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=kv_stages,
            producer_warp_ids=loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="kv",
        ),
        PipelineEdge(
            name="mma_s0",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_softmax_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=softmax0,
        ),
        PipelineEdge(
            name="mma_s1",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_softmax_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=softmax1,
        ),
        PipelineEdge(
            name="s0_epi",
            pipeline_type=PipelineType.ASYNC,
            stages=epi_stages,
            producer_warp_ids=softmax0,
            consumer_warp_ids=epilogue,
        ),
        PipelineEdge(
            name="s1_epi",
            pipeline_type=PipelineType.ASYNC,
            stages=epi_stages,
            producer_warp_ids=softmax1,
            consumer_warp_ids=epilogue,
        ),
        PipelineEdge(
            name="s0_s1_sequence",
            pipeline_type=PipelineType.ASYNC,
            stages=1,
            producer_warp_ids=softmax0,
            consumer_warp_ids=softmax1,
        ),
    ]
    return PipelineTopology(edges=graph)


def make_mla_topology(
    schedule: "MLAWarpSchedule",
    load_q_stages: int = 1,
    load_kv_stages: int = 15,
    mma_s_stages: int = 2,
    p_mma_stages: int = 2,
    p_cor_stages: int = 2,
    mma_o_stages: int = 1,
    load_pt_stages: int = 4,
    cluster_scale: int = 2,
) -> PipelineTopology:
    """Pipeline graph for MLA decode.

    Seven pipelines connect the five warp roles::

        PT_Load --[load_pt]--> TMA_Load --[load_q]---> MMA --[mma_s]--> Compute --[p_cor]--> Correction
                                        --[load_kv]-->     <--[p_mma]--
                                                           --[mma_o]--------------->

    :param schedule: MLA warp schedule defining warp role assignments.
    :param cluster_scale: Multiplier for cluster-scaled consumer/producer thread counts.
    """
    tma_loader = (schedule.load_tma_warp_id,)
    pt_loader = (schedule.load_pt_warp_id,)
    mma_warp = (schedule.mma_warp_id,)
    compute_warps = schedule.compute_warp_ids
    correction_warps = schedule.correction_warp_ids

    graph = [
        PipelineEdge(
            name="load_q",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=load_q_stages,
            producer_warp_ids=tma_loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="q",
        ),
        PipelineEdge(
            name="load_kv",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=load_kv_stages,
            producer_warp_ids=tma_loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="kv",
        ),
        PipelineEdge(
            name="mma_s",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_s_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=compute_warps,
            cluster_scale=cluster_scale,
        ),
        PipelineEdge(
            name="p_mma",
            pipeline_type=PipelineType.ASYNC_UMMA,
            stages=p_mma_stages,
            producer_warp_ids=compute_warps,
            consumer_warp_ids=mma_warp,
            cluster_scale=cluster_scale,
        ),
        PipelineEdge(
            name="p_cor",
            pipeline_type=PipelineType.ASYNC,
            stages=p_cor_stages,
            producer_warp_ids=compute_warps,
            consumer_warp_ids=correction_warps,
        ),
        PipelineEdge(
            name="mma_o",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_o_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=correction_warps,
            cluster_scale=cluster_scale,
        ),
        PipelineEdge(
            name="load_pt",
            pipeline_type=PipelineType.CP_ASYNC,
            stages=load_pt_stages,
            producer_warp_ids=pt_loader,
            consumer_warp_ids=tma_loader,
        ),
    ]
    return PipelineTopology(edges=graph)


def make_mla_fp8_topology(
    schedule: "MLAWarpScheduleFP8",
    load_q_stages: int = 1,
    load_k_stages: int = 3,
    load_v_stages: int = 2,
    mma_s_stages: int = 2,
    p_mma_stages: int = 2,
    p_cor_stages: int = 2,
    mma_o_stages: int = 2,
    cluster_scale: int = 2,
) -> PipelineTopology:
    """Pipeline graph for FP8 MLA decode.

    Seven pipelines connect the five warp roles (no page-table pipeline)::

        TMA_K_Load --[load_q]---> MMA --[mma_s]--> Compute --[p_cor]--> Correction
                   --[load_k]-->      <--[p_mma]--
        TMA_V_Load --[load_v]-->      --[mma_o]--------------->

    FP8 splits the unified load_kv into separate load_k and load_v pipelines
    with dedicated TMA loader warps, and removes the page-table pipeline
    (page indices are read directly from global memory).

    :param schedule: FP8 MLA warp schedule defining warp role assignments.
    :param cluster_scale: Multiplier for cluster-scaled consumer/producer thread counts.
    """
    k_loader = (schedule.load_tma_k_warp_id,)
    v_loader = (schedule.load_tma_v_warp_id,)
    mma_warp = (schedule.mma_warp_id,)
    compute_warps = schedule.compute_warp_ids
    correction_warps = schedule.correction_warp_ids

    graph = [
        PipelineEdge(
            name="load_q",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=load_q_stages,
            producer_warp_ids=k_loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="q",
        ),
        PipelineEdge(
            name="load_k",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=load_k_stages,
            producer_warp_ids=k_loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="kv",
        ),
        PipelineEdge(
            name="load_v",
            pipeline_type=PipelineType.TMA_UMMA,
            stages=load_v_stages,
            producer_warp_ids=v_loader,
            consumer_warp_ids=mma_warp,
            tx_count_key="vc",
        ),
        PipelineEdge(
            name="mma_s",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_s_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=compute_warps,
            cluster_scale=cluster_scale,
        ),
        PipelineEdge(
            name="p_mma",
            pipeline_type=PipelineType.ASYNC_UMMA,
            stages=p_mma_stages,
            producer_warp_ids=compute_warps,
            consumer_warp_ids=mma_warp,
            cluster_scale=cluster_scale,
        ),
        PipelineEdge(
            name="p_cor",
            pipeline_type=PipelineType.ASYNC,
            stages=p_cor_stages,
            producer_warp_ids=compute_warps,
            consumer_warp_ids=correction_warps,
        ),
        PipelineEdge(
            name="mma_o",
            pipeline_type=PipelineType.UMMA_ASYNC,
            stages=mma_o_stages,
            producer_warp_ids=mma_warp,
            consumer_warp_ids=correction_warps,
            cluster_scale=cluster_scale,
        ),
    ]
    return PipelineTopology(edges=graph)
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils as utils +from cutlass.cute.typing import Int32, Float32 + +from .config import AttentionConfig, AttentionFusion +from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE, PREFILL_TRANSFORM_SCHEDULE +from .mainloop_spec import make_prefill_mainloop_spec +from .collective_builder import build_fmha_launch_params +from .scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler_params, +) +from .roles.softmax import SoftmaxRole +from .roles.correction import CorrectionRole +from .roles.epilogue import EpilogueRole +from .roles.loader_tma import LoaderRole +from .roles.mma import MmaRole + +import warnings + +warnings.filterwarnings( + "ignore", + message="This loop is no longer unrolled and may cause performance regression", +) + +"""Blackwell SM100 fused multi-head attention (FMHA) kernel using CuTe DSL. + +Warp-specialized persistent kernel with TMA loads/stores, pipelined QK and PV +MMA stages, online softmax with correction, and optional causal/sliding-window +masking. Supports fp16, bf16, and fp8 input types. +""" + + +class BlackwellFusedMultiHeadAttentionForward: + def __init__( + self, + config: AttentionConfig, + fusion: AttentionFusion | None = None, + warp_schedule: WarpSchedule | None = None, + ): + """Initializes a Blackwell Fused Multi-Head Attention (FMHA) kernel. + + :param config: Core attention configuration (dtypes, tile shapes, mode). + :param fusion: Optional customization callbacks (logits/output transforms, sinks). + :param warp_schedule: Warp role assignment and register budgets. Defaults to PREFILL_SCHEDULE. 
+ """ + + self.config = config + self.fusion = fusion if fusion is not None else AttentionFusion() + self.has_logits_transform = self.fusion.variant.has_logits_transform + + if warp_schedule is not None: + self.schedule = warp_schedule + elif self.has_logits_transform: + self.schedule = PREFILL_TRANSFORM_SCHEDULE + else: + self.schedule = PREFILL_SCHEDULE + + self.mainloop = make_prefill_mainloop_spec( + config, + self.schedule, + self.has_logits_transform, + ) + self.tmem = self.mainloop.tmem_layout + + @cute.jit + def __call__( + self, + q_in: cute.Tensor, + k_in: cute.Tensor, + v_in: cute.Tensor, + o_in: cute.Tensor, + problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32], + cum_seqlen_q: cute.Tensor | None, + s_q_all: Int32, + cum_seqlen_k: cute.Tensor | None, + s_k_all: Int32, + scale_softmax_log2: Float32, + scale_output: Float32, + params_in: cute.Tensor | None, + stream, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + :param q_in: The query tensor (NHD layout) + :param k_in: The key tensor (NHD layout) + :param v_in: The value tensor (NHD layout) + :param o_in: The output tensor (NHD layout, with padding before data pointer) + :param problem_size: ``(b, s_q, s_k, h_q, h_k, d)`` + :param cum_seqlen_q: Cumulative query sequence lengths, or None + :param cum_seqlen_k: Cumulative KV sequence lengths, or None + :param scale_softmax_log2: ``log2(e) * sm_scale`` + :param scale_output: Output scaling factor + :param params_in: Variant runtime data tensor, or None + :param stream: CUDA stream + """ + b, s_q, s_k, h_q, h_k, d = problem_size + h_r = h_q // h_k + + o_offset = -s_q * d * h_r * h_k + b_q = 1 + b_kv = 1 + b_o = s_q * (1 + b) + stride_b_q = 0 + stride_b_kv = 0 + stride_b_o = d * h_r * h_k + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q_all, d, ((h_r, h_k), b_q)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_q)), + ) + q = cute.make_tensor(q_in.iterator, q_layout) + # (s, d, ((h_r, 
h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k_all, d, ((h_r, h_k), b_kv)), + stride=(d * h_k, 1, ((0, d), stride_b_kv)), + ) + k = cute.make_tensor(k_in.iterator, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k_all, ((h_r, h_k), b_kv)), + stride=(1, d * h_k, ((0, d), stride_b_kv)), + ) + v = cute.make_tensor(v_in.iterator, v_layout) + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_o)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_o)), + ) + o = cute.make_tensor(o_in.iterator + o_offset, o_layout) + + params = ( + cute.make_tensor( + params_in.iterator, + cute.make_layout( + self.fusion.params_shape, + stride=self.fusion.params_strides, + ), + ) + if self.fusion.has_params + else None + ) + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + + self.tile_sched_params, grid = self._compute_grid( + cute.shape((s_q, d, ((h_r, h_k), b))), + self.config.cta_tiler, + self.config.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.o_layout = utils.LayoutEnum.from_tensor(o) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if 
cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self.mainloop = self.mainloop.resolve(self.q_dtype.width) + + self.softmax_role = SoftmaxRole( + self.config, + self.fusion, + self.tmem, + softmax0_warp_ids=self.schedule.softmax0_warp_ids, + softmax1_warp_ids=self.schedule.softmax1_warp_ids, + threads_per_warp=self.schedule.threads_per_warp, + ) + if cutlass.const_expr(not self.has_logits_transform): + self.correction_role = CorrectionRole( + self.config, + self.fusion, + self.tmem, + correction_warp_ids=self.schedule.correction_warp_ids, + threads_per_warp=self.schedule.threads_per_warp, + ) + self.epilogue_role = EpilogueRole(self.config) + self.loader_role = LoaderRole(self.config) + self.mma_role = MmaRole( + self.config, + tmem_alloc_cols=self.tmem.alloc_cols, + tmem_alloc_sync_bar_id=self.schedule.tmem_alloc_sync_bar_id, + threads_per_warp=self.schedule.threads_per_warp, + has_logits_transform=self.has_logits_transform, + ) + self.softmax_role.set_dtypes(self.q_dtype, self.o_dtype) + + lp = build_fmha_launch_params( + self.mainloop, + q, + k, + v, + o, + self.q_dtype, + self.k_dtype, + self.v_dtype, + self.o_dtype, + self.q_major_mode, + self.k_major_mode, + self.v_major_mode, + self.o_layout, + ) + self.shared_storage = lp.SharedStorage + + smem_bytes = lp.SharedStorage.size_in_bytes() + smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + if cutlass.const_expr(smem_bytes > smem_capacity): + head_dim = self.config.mma_tiler[2] + raise ValueError( + f"SharedStorage requires {smem_bytes} bytes but SM100 provides " + f"{smem_capacity} bytes. Reduce head_dim (currently {head_dim}) " + f"or tile size." 
            )

        # TMA transaction byte counts, consumed by _create_pipelines below
        # as tx_counts for the load_q / load_kv pipelines.
        self.tma_copy_q_bytes = lp.tma_copy_q_bytes
        self.tma_copy_kv_bytes = lp.tma_copy_kv_bytes

        # The transform path has no correction warp, so the per-path role
        # that performs the epilog receives the output layout attributes.
        if cutlass.const_expr(not self.has_logits_transform):
            self.correction_role.set_call_attrs(self.o_dtype, lp.o_layout, lp.epi_tile)
        else:
            self.softmax_role.set_call_attrs(lp.o_layout, lp.epi_tile)

        # Launch the device kernel; grid, cluster shape, and SMEM size were
        # all computed earlier in this method from the launch params `lp`.
        self.kernel(
            lp.qk_tiled_mma,
            lp.pv_tiled_mma,
            lp.tma_atom_q,
            lp.tma_tensor_q,
            lp.tma_atom_k,
            lp.tma_tensor_k,
            lp.tma_atom_v,
            lp.tma_tensor_v,
            lp.tma_atom_o,
            lp.tma_tensor_o,
            cum_seqlen_q,
            cum_seqlen_k,
            scale_softmax_log2,
            scale_output,
            params,
            lp.q_smem_layout_staged,
            lp.k_smem_layout_staged,
            lp.p_tmem_layout_staged,
            lp.v_smem_layout_staged,
            lp.o_smem_layout_staged,
            self.tile_sched_params,
        ).launch(
            grid=grid,
            block=[self.schedule.threads_per_cta, 1, 1],
            cluster=lp.cluster_shape_mnk,
            smem=lp.SharedStorage.size_in_bytes(),
            stream=stream,
            min_blocks_per_mp=1,
        )

    @cute.jit
    def _create_pipelines(self, storage):
        """Create all inter-warp pipelines from the topology and storage barriers.

        Returns the dict produced by PipelineTopology.create_pipelines,
        mapping edge name to a (producer, consumer) participant pair.
        """
        # One mbarrier storage field per topology edge, named "<edge>_mbar_ptr".
        barrier_ptrs = {
            edge.name: getattr(storage, edge.barrier_field_name).data_ptr()
            for edge in self.mainloop.pipeline_topology.edges
        }
        # Byte counts for the TMA edges keyed by their tx_count_key.
        tx_counts = {"q": self.tma_copy_q_bytes, "kv": self.tma_copy_kv_bytes}
        return self.mainloop.pipeline_topology.create_pipelines(
            barrier_ptrs,
            tx_counts,
            self.schedule.threads_per_warp,
        )

    @cute.jit
    def _create_mma_fragments(
        self,
        qk_tiled_mma,
        pv_tiled_mma,
        sQ,
        sK,
        sV,
        p_tmem_layout_staged,
    ):
        """Partition MMA operands and create TMEM offset tensors for double-buffered accumulators."""
        qk_thr_mma = qk_tiled_mma.get_slice(0)
        pv_thr_mma = pv_tiled_mma.get_slice(0)
        # Operand fragments: Q and K feed the QK MMA; V feeds the PV MMA.
        tSrQ = qk_thr_mma.make_fragment_A(sQ)
        tSrK = qk_thr_mma.make_fragment_B(sK)
        tOrV = pv_thr_mma.make_fragment_B(sV)

        # Accumulator fragment for the S = Q@K^T tile.
        tStS = qk_thr_mma.make_fragment_C(
            qk_thr_mma.partition_shape_C(
                (self.config.qk_mma_tiler[0], self.config.qk_mma_tiler[1])
            )
        )
        tOtO = 
pv_thr_mma.make_fragment_C( + pv_thr_mma.partition_shape_C( + (self.config.pv_mma_tiler[0], self.config.pv_mma_tiler[1]) + ) + ) + + tStS0 = cute.make_tensor(tStS.iterator + self.tmem.s0_offset, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.tmem.s1_offset, tStS.layout) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem.o0_offset, tOtO.layout) + tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem.o1_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + p_scale = self.config.qk_acc_dtype.width // self.q_dtype.width + tOrP0 = cute.make_tensor( + tOrP.iterator + p_scale * self.tmem.p0_offset, tOrP.layout + ) + tOrP1 = cute.make_tensor( + tOrP.iterator + p_scale * self.tmem.p1_offset, tOrP.layout + ) + + return ( + qk_thr_mma, + pv_thr_mma, + tSrQ, + tSrK, + tOrV, + tStS, + tStS0, + tStS1, + tOtO0, + tOtO1, + tOrP0, + tOrP1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + params: cute.Tensor | None, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + o_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """FMHA device kernel: warp-specialized attention with pipelined TMA loads.""" + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.schedule.load_warp_id: + 
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o) + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + pipes = self._create_pipelines(storage) + load_q_producer, load_q_consumer = pipes["load_q"] + load_kv_producer, load_kv_consumer = pipes["load_kv"] + mma_s0_producer, mma_s0_consumer = pipes["mma_s0"] + mma_s1_producer, mma_s1_consumer = pipes["mma_s1"] + s0_s1_sequence_producer, s0_s1_sequence_consumer = pipes["s0_s1_sequence"] + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + + # Standard path pipelines (correction warp) + s0_corr_producer = s0_corr_consumer = None + s1_corr_producer = s1_corr_consumer = None + corr_epi_producer = corr_epi_consumer = None + mma_corr_producer = mma_corr_consumer = None + if cutlass.const_expr(not self.has_logits_transform): + s0_corr_producer, s0_corr_consumer = pipes["s0_corr"] + s1_corr_producer, s1_corr_consumer = pipes["s1_corr"] + corr_epi_producer, corr_epi_consumer = pipes["corr_epi"] + mma_corr_producer, mma_corr_consumer = pipes["mma_corr"] + + # Transform path pipelines (softmax -> epilogue) + s0_epi_producer = s0_epi_consumer = None + s1_epi_producer = s1_epi_consumer = None + if cutlass.const_expr(self.has_logits_transform): + s0_epi_producer, s0_epi_consumer = pipes["s0_epi"] + s1_epi_producer, s1_epi_consumer = pipes["s1_epi"] + + if warp_idx == self.schedule.empty_warp_id: + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, + self.schedule.tmem_dealloc_arrive_count, + ) + cute.arch.mbarrier_init_fence() + + sQ = storage.sQ.get_tensor( + q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner + ) + sK = storage.sK.get_tensor( + k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner + ) + sV = cute.make_tensor( + cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner), + 
v_smem_layout_staged.outer, + ) + sO = storage.sO.get_tensor( + o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner + ) + + ( + qk_thr_mma, + pv_thr_mma, + tSrQ, + tSrK, + tOrV, + tStS, + tStS0, + tStS1, + tOtO0, + tOtO1, + tOrP0, + tOrP1, + ) = self._create_mma_fragments( + qk_tiled_mma, + pv_tiled_mma, + sQ, + sK, + sV, + p_tmem_layout_staged, + ) + + cute.arch.barrier( + barrier_id=self.schedule.cta_sync_bar_id, + number_of_threads=self.schedule.threads_per_cta, + ) + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_other) + self.loader_role.run( + qk_thr_mma, + pv_thr_mma, + tma_atom_q, + tma_atom_k, + tma_atom_v, + mQ_qdl, + mK_kdl, + mV_dkl, + sQ, + sK, + sV, + cum_seqlen_q, + cum_seqlen_k, + load_q_producer, + load_kv_producer, + tile_sched_params, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_other) + self.mma_role.run( + qk_tiled_mma, + pv_tiled_mma, + tStS0, + tStS1, + tOtO0, + tOtO1, + tSrQ, + tSrK, + tOrP0, + tOrP1, + tOrV, + mQ_qdl.shape[0], + mK_kdl.shape[0], + cum_seqlen_q, + cum_seqlen_k, + load_q_consumer, + load_kv_consumer, + mma_s0_producer, + mma_s1_producer, + mma_corr_producer, # None for transform path + tile_sched_params, + storage, + tmem_dealloc_mbar_ptr, + ) + + # 
/////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.schedule.epilogue_warp_id: + cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_other) + self.epilogue_role.run( + tma_atom_o, + mO_qdl, + sO, + cum_seqlen_q, + corr_epi_consumer, # None for transform path + s0_epi_consumer, # None for standard path + s1_epi_consumer, # None for standard path + tile_sched_params, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax0 + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.schedule.softmax1_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.schedule.num_regs_softmax) + + self.softmax_role.run( + stage=0, + seqlen_q=mQ_qdl.shape[0], + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + scale_output=scale_output, + qk_thr_mma=qk_thr_mma, + pv_thr_mma=pv_thr_mma, + tStS=tStS, + tStSi=tStS0, + tOtO=tOtO0, + sO=sO[None, None, 0] if self.has_logits_transform else None, + params=params, + mma_si_consumer=mma_s0_consumer, + si_corr_producer=s0_corr_producer, + si_epi_producer=s0_epi_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax1 + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.schedule.softmax1_warp_ids[0] + and warp_idx < self.schedule.softmax1_upper_warp_id + ): + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.schedule.num_regs_softmax) + + self.softmax_role.run( + stage=1, + 
seqlen_q=mQ_qdl.shape[0], + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + scale_output=scale_output, + qk_thr_mma=qk_thr_mma, + pv_thr_mma=pv_thr_mma, + tStS=tStS, + tStSi=tStS1, + tOtO=tOtO1, + sO=sO[None, None, 1] if self.has_logits_transform else None, + params=params, + mma_si_consumer=mma_s1_consumer, + si_corr_producer=s1_corr_producer, + si_epi_producer=s1_epi_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if cutlass.const_expr(not self.has_logits_transform): + if ( + warp_idx >= self.schedule.softmax1_upper_warp_id + and warp_idx < self.schedule.mma_warp_id + ): + cute.arch.warpgroup_reg_dealloc(self.schedule.num_regs_correction) + self.correction_role.run( + qk_thr_mma, + pv_thr_mma, + tStS, + tOtO0, + tOtO1, + sO, + mQ_qdl.shape[0], + mK_kdl.shape[0], + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + s0_corr_consumer, + s1_corr_consumer, + mma_corr_consumer, + corr_epi_producer, + tile_sched_params, + tmem_dealloc_mbar_ptr, + ) + return + + @staticmethod + def _compute_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, + ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + tile_sched_params = create_fmha_static_tile_scheduler_params( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid diff --git a/flashinfer/cute_dsl/attention/roles/__init__.py 
b/flashinfer/cute_dsl/attention/roles/__init__.py new file mode 100644 index 0000000000..8c08330c7f --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# FMHA prefill roles +from .softmax import SoftmaxRole +from .correction import CorrectionRole +from .epilogue import EpilogueRole +from .loader_tma import LoaderRole +from .mma import MmaRole + +# MLA decode roles +from .mla_pt_loader import MLAPageTableLoaderRole +from .mla_loader import MLALoaderRole +from .mla_mma import MLAMmaRole +from .mla_compute import MLAComputeRole +from .mla_correction import MLACorrectionRole diff --git a/flashinfer/cute_dsl/attention/roles/correction.py b/flashinfer/cute_dsl/attention/roles/correction.py new file mode 100644 index 0000000000..824269e0d0 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/correction.py @@ -0,0 +1,471 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""CorrectionRole — output rescaling, epilogue, and orchestration for attention kernels. + +Handles: +- Orchestration loop: pipeline sync with softmax/MMA, scale computation +- Rescaling partial output when row-max changes across KV tiles +- Final scaling, type conversion, optional output transform, SMEM write + +Extracted from BlackwellFusedMultiHeadAttentionForward correction warp section. 
+""" + +from typing import Optional, Type + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Float32 + +from cutlass.pipeline import PipelineProducer, PipelineConsumer + +from ..config import AttentionConfig, AttentionFusion +from ..tmem_layout import TmemLayout +from ..fusion.mask import get_trip_count +from ..scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, +) + + +class CorrectionRole: + """Correction warp group for attention kernels. + + Created from AttentionConfig + AttentionFusion + TmemLayout in the kernel's __init__. + Tensor-type attributes (o_dtype, o_layout, epi_tile) are set later via + set_call_attrs() in __call__. + """ + + def __init__( + self, + config: AttentionConfig, + fusion: AttentionFusion, + tmem: TmemLayout, + correction_warp_ids, + threads_per_warp, + ): + # From config + self.qk_acc_dtype = config.qk_acc_dtype + self.qk_mma_tiler = config.qk_mma_tiler + self.pv_mma_tiler = config.pv_mma_tiler + self.pv_acc_dtype = config.pv_acc_dtype + self.cta_tiler = config.cta_tiler + self.mask_type = config.mask_type + self.window_left = config.window_left + + # From TMEM layout + self.tmem_vec0_offset = tmem.vec0_offset + self.tmem_vec1_offset = tmem.vec1_offset + + # From fusion variant + self.variant = fusion.variant + self.has_logits_transform = fusion.variant.has_logits_transform + self.has_output_transform = fusion.variant.has_output_transform + + # Warp config + self.correction_warp_ids = correction_warp_ids + self.threads_per_warp = threads_per_warp + + # Set later via set_call_attrs() + self.o_dtype: Optional[Type[cutlass.Numeric]] = None + self.o_layout = None + self.epi_tile = None + + def set_call_attrs(self, o_dtype, o_layout, epi_tile): + """Set tensor-type attributes known only at call time.""" + self.o_dtype = o_dtype + 
self.o_layout = o_layout + self.epi_tile = epi_tile + + @cute.jit + def rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + """ + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition( + tOtO.layout, cute.make_layout((128, corr_tile_size)) + ) + tOcO_i_layout = cute.composition( + tOcO.layout, cute.make_layout((128, corr_tile_size)) + ) + + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i) + + tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i) + + tTMrO = cute.make_fragment( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype + ) + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = 
cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + ) + + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i) + for j in range(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], tTMrO_i[j + 1]), + (scale, scale), + ) + cute.copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i) + + @cute.jit + def epilog( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + m: Float32, + d: Float32, + sO: cute.Tensor, + batch_coord: Int32, + head_coord: Int32, + qo_idx_offset: Int32, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + Performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. 
Preparation for efficient TMA store operations + """ + assert self.o_dtype is not None + assert self.epi_tile is not None + + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + cO_custom = cute.make_identity_tensor(pv_tiled_mma_shape) + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + tOcO_custom = thr_mma.partition_C(cO_custom) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + tOcO_custom_i = cute.logical_divide( + tOcO_custom, cute.make_layout((128, corr_tile_size)) + ) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils.get_tmem_load_op( + self.pv_mma_tiler, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_copy_atom, tOtO_i[(None, None), 0] + ) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADoO = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + tTMEM_LOADcO_custom = thr_tmem_load.partition_D( + tOcO_custom_i[(None, None), None] + ) + + scale_rcp_d = scale / d if not self.has_logits_transform else scale + rcp_d = 1 / d if m != -Float32.inf else 0.0 + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMEM_LOADtO_i = 
tTMEM_LOADtO[None, 0, 0, i] + tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] + tTMrO = cute.make_fragment( + tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype + ) + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) + if cutlass.const_expr(not self.has_output_transform): + for j in range(0, cute.size(tTMrO), 2): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale_rcp_d, scale_rcp_d), + ) + else: + tTMcO_custom = tTMEM_LOADcO_custom[None, 0, 0, i] + for j in range(0, cute.size(tTMrO)): + qo_idx = qo_idx_offset + tTMcO_custom[j][0] + tTMrO[j] = self.variant.transform_output( + tTMrO[j], + batch_coord, + qo_idx, + head_coord, + m, + rcp_d, + scale, + ) + tSMrO = cute.make_fragment(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + @cute.jit + def run( + self, + qk_thr_mma: cute.core.ThrMma, + pv_thr_mma: cute.core.ThrMma, + tStS: cute.Tensor, + tOtO0: cute.Tensor, + tOtO1: cute.Tensor, + sO: cute.Tensor, + seqlen_q_global: Int32, + seqlen_k_global: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + s0_corr_consumer: PipelineConsumer, + s1_corr_consumer: PipelineConsumer, + mma_corr_consumer: PipelineConsumer, + corr_epi_producer: PipelineProducer, + tile_sched_params: FmhaStaticTileSchedulerParams, + tmem_dealloc_mbar_ptr: Int32, + ): + """Correction warp orchestration loop. + + For each work tile, synchronizes with softmax (vec buffers) and MMA + (output partials) pipelines, computes rescaling factors from row-max + changes, delegates to rescale() and epilog(), and signals the epilogue + warp when output is ready. 
+ """ + tidx, _, _ = cute.arch.thread_idx() + + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr_mma.partition_C(cS) + + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + + tStS_vec0 = cute.make_tensor( + tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout + ) + tStS_vec1 = cute.make_tensor( + tStS.iterator + self.tmem_vec1_offset, tStS_vec_layout + ) + + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStS_vec0) + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(thread_idx) + + tTMEM_LOAD_VECtS0 = thr_tmem_load_vec.partition_S(tStS_vec0) + tTMEM_LOAD_VECtS1 = thr_tmem_load_vec.partition_S(tStS_vec1) + tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + head_coord = curr_block_coord[2][0] + qo_idx_offset = curr_block_coord[0] * self.cta_tiler[0] + + seqlen_q_ = seqlen_q_global + seqlen_k = seqlen_k_global + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q_ = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q_, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = 
cum_seqlen_k[batch_coord + 1] - cuseqlen_k + # Ignore first signal from softmax as no correction is required + vec0_handle = s0_corr_consumer.wait_and_advance() + vec0_handle.release() + vec1_handle = s1_corr_consumer.wait_and_advance() + seqlen_kv_loop_steps = ( + get_trip_count( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k, + seqlen_q_, + ) + - 1 + ) + for _i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # wait for vec0 (row_wise current max & previous max) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.arch.exp2(scale_) + + # wait for o0 + o0_handle_consumer = mma_corr_consumer.wait_and_advance() + if cutlass.const_expr(not self.has_logits_transform): + self.rescale(pv_thr_mma, tOtO0, scale) + # release vec1 & o0 + vec1_handle.release() + cute.arch.fence_view_async_tmem_store() + o0_handle_consumer.release() + + # wait for vec1 (row_wise current max & previous max) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.arch.exp2(scale_) + + o1_handle_consumer = mma_corr_consumer.wait_and_advance() + if cutlass.const_expr(not self.has_logits_transform): + self.rescale(pv_thr_mma, tOtO1, scale) + vec0_handle.release() + cute.arch.fence_view_async_tmem_store() + o1_handle_consumer.release() + # End of seqlen_corr_loop_steps + vec1_handle.release() + + # wait for vec0 (row_wise global sum) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) 
+ cute.arch.fence_view_async_tmem_load() + vec0_handle.release() + # wait for o0 + o0_handle_consumer = mma_corr_consumer.wait_and_advance() + o0_final_handle = corr_epi_producer.acquire_and_advance() + + epilogue_scale = scale_output + d = tTMEM_LOAD_VECrS[0] # row sum + m = tTMEM_LOAD_VECrS[1] # row max + self.epilog( + pv_thr_mma, + tOtO0, + epilogue_scale, + m, + d, + sO[None, None, 0], + batch_coord, + head_coord, + qo_idx_offset, + ) + o0_handle_consumer.release() + o0_final_handle.commit() + + # wait for vec1 (row_wise global sum) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + vec1_handle.release() + # wait for o1 + o1_handle_consumer = mma_corr_consumer.wait_and_advance() + o1_final_handle = corr_epi_producer.acquire_and_advance() + + epilogue_scale = scale_output + d = tTMEM_LOAD_VECrS[0] # row sum + m = tTMEM_LOAD_VECrS[1] # row max + self.epilog( + pv_thr_mma, + tOtO1, + epilogue_scale, + m, + d, + sO[None, None, 1], + batch_coord, + head_coord, + qo_idx_offset + self.qk_mma_tiler[0], + ) + o1_handle_consumer.release() + o1_final_handle.commit() + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) diff --git a/flashinfer/cute_dsl/attention/roles/epilogue.py b/flashinfer/cute_dsl/attention/roles/epilogue.py new file mode 100644 index 0000000000..6084988c3f --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/epilogue.py @@ -0,0 +1,181 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""EpilogueRole — TMA store primitives and orchestration for attention output.
+ +Reusable primitives (pipeline-unaware, for composing new kernel variants): +- partition_output(): partition output global tensor for TMA stores +- store_tile(): issue a single TMA store + commit group + +Orchestration (prefill-specific, uses raw CuTe ops for JIT compatibility): +- run(): O0/O1 double-buffered TMA stores with pipeline sync +""" + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32 + +from cutlass.pipeline import PipelineConsumer + +from ..config import AttentionConfig +from ..scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, +) + + +class EpilogueRole: + """Epilogue warp for attention kernels — TMA stores output to global memory. + + Created from AttentionConfig in the kernel's __init__. + """ + + def __init__(self, config: AttentionConfig): + self.pv_mma_tiler = config.pv_mma_tiler + self.cta_tiler = config.cta_tiler + + # ========================================================================= + # Reusable primitives — for composing new kernel variants + # + # NOTE on CuTe DSL JIT limitations: + # - partition_output(): Returns tensor tuples — CuTe DSL JIT does not + # reliably handle returning tensors from @cute.jit methods. + # - store_tile(): SAFE — takes pre-sliced tensors as arguments, no + # runtime indexing or return values. Used in run() successfully. + # ========================================================================= + + @cute.jit + def partition_output( + self, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + sO: cute.Tensor, + block_coord: tuple, + ): + """Partition output global tensor for TMA stores. 
Returns (tOsO, tOgO).""" + gO_qdl = cute.flat_divide(mO_qdl, cute.select(self.pv_mma_tiler, mode=[0, 1])) + gO = gO_qdl[None, None, None, 0, block_coord[2]] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + return tOsO, tOgO + + @cute.jit + def store_tile( + self, + tma_atom_o: cute.CopyAtom, + tOsO_slice: cute.Tensor, + tOgO_slice: cute.Tensor, + ): + """Issue a single TMA store from SMEM to GMEM + commit group.""" + cute.copy(tma_atom_o, tOsO_slice, tOgO_slice) + cute.arch.cp_async_bulk_commit_group() + + # ========================================================================= + # Prefill orchestration — proven-correct inline implementation + # ========================================================================= + + @cute.jit + def run( + self, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + sO: cute.Tensor, + cum_seqlen_q: cute.Tensor | None, + corr_epi_consumer: PipelineConsumer | None, + s0_epi_consumer: PipelineConsumer | None, + s1_epi_consumer: PipelineConsumer | None, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """Epilogue warp orchestration loop (prefill-specific). + + O0/O1 double-buffered TMA stores with pipeline synchronization. + + Standard path: consumes from corr_epi (correction -> epilogue). + Transform path: consumes from s0_epi/s1_epi (softmax -> epilogue). 
+ """ + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mO_qdl.shape[0] + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + curr_block_coord_o = curr_block_coord + mO_qdl_ = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mO = ( + mO_qdl_.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mO_qdl_ = cute.domain_offset(logical_offset_mO, mO_qdl_) + curr_block_coord_o = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], 0), + ) + + o0_coord = 2 * curr_block_coord_o[0] + o1_coord = o0_coord + 1 + gO_qdl = cute.flat_divide( + mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1]) + ) + gO = gO_qdl[None, None, None, 0, curr_block_coord_o[2]] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + + if cutlass.const_expr(corr_epi_consumer is not None): + # Standard path: O0/O1 from correction warp + o0_handle_consumer = corr_epi_consumer.wait_and_advance() + self.store_tile(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) + + o1_handle_consumer = corr_epi_consumer.wait_and_advance() + self.store_tile(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) + + cute.arch.cp_async_bulk_wait_group(1, read=True) + o0_handle_consumer.release() + cute.arch.cp_async_bulk_wait_group(0, read=True) + o1_handle_consumer.release() + else: + # Transform path: O0 from softmax0, O1 
from softmax1 + o0_handle_consumer = s0_epi_consumer.wait_and_advance() + self.store_tile(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) + + o1_handle_consumer = s1_epi_consumer.wait_and_advance() + self.store_tile(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) + + cute.arch.cp_async_bulk_wait_group(1, read=True) + o0_handle_consumer.release() + cute.arch.cp_async_bulk_wait_group(0, read=True) + o1_handle_consumer.release() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() diff --git a/flashinfer/cute_dsl/attention/roles/loader_tma.py b/flashinfer/cute_dsl/attention/roles/loader_tma.py new file mode 100644 index 0000000000..0932881db1 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/loader_tma.py @@ -0,0 +1,334 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""LoaderOps — TMA load primitives and orchestration for attention kernels. + +Reusable primitives (pipeline-unaware, for composing new kernel variants): +- partition_q(): partition Q global tensor for TMA loads +- partition_k(): partition K global tensor for TMA loads +- partition_v(): partition V global tensor for TMA loads +- load_tile(): issue a single TMA load with barrier + +Orchestration (prefill-specific, uses raw CuTe ops for JIT compatibility): +- run(): Q0/Q1 double-buffered loads with KV streaming +""" + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32 + +from cutlass.pipeline import PipelineProducer + +from ..config import AttentionConfig +from ..fusion.mask import get_trip_count, get_kv_start_block_idx +from ..scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, +) + + +class LoaderRole: + """Loader warp for attention kernels — TMA loads Q, K, V into SMEM. + + Created from AttentionConfig in the kernel's __init__. 
+ """ + + def __init__(self, config: AttentionConfig): + self.cta_tiler = config.cta_tiler + self.qk_mma_tiler = config.qk_mma_tiler + self.pv_mma_tiler = config.pv_mma_tiler + self.mask_type = config.mask_type + self.window_left = config.window_left + + # ========================================================================= + # Reusable primitives — for composing new kernel variants + # + # NOTE on CuTe DSL JIT limitations: + # - partition_q/k/v(): Return tensor tuples — CuTe DSL JIT does not + # reliably handle returning tensors from @cute.jit methods. + # - load_tile(): Uses runtime indexing (handle.index) to create tensor + # views internally — causes correctness issues in CuTe DSL JIT. + # These primitives document the intended decomposition but cannot be + # used inside run() until CuTe DSL JIT support improves. Use the + # inline patterns in run() as the working reference. + # ========================================================================= + + @cute.jit + def partition_q( + self, + qk_thr_mma: cute.core.ThrMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + sQ: cute.Tensor, + block_coord: tuple, + ): + """Partition Q global tensor for TMA loads. Returns (tQsQ, tQgQ).""" + gQ_qdl = cute.flat_divide(mQ_qdl, cute.select(self.qk_mma_tiler, mode=[0, 2])) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tQgQ = tQgQ_qdl[None, None, 0, block_coord[2]] + return tQsQ, tQgQ + + @cute.jit + def partition_k( + self, + qk_thr_mma: cute.core.ThrMma, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + sK: cute.Tensor, + block_coord: tuple, + ): + """Partition K global tensor for TMA loads. 
Returns (tKsK, tKgK).""" + gK_kdl = cute.flat_divide(mK_kdl, cute.select(self.qk_mma_tiler, mode=[1, 2])) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tKgK = tKgK_kdl[None, None, 0, block_coord[2]] + return tKsK, tKgK + + @cute.jit + def partition_v( + self, + pv_thr_mma: cute.core.ThrMma, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + sV: cute.Tensor, + block_coord: tuple, + ): + """Partition V global tensor for TMA loads. Returns (tVsV, tVgV).""" + gV_dkl = cute.flat_divide(mV_dkl, cute.select(self.pv_mma_tiler, mode=[1, 2])) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tVgV = tVgV_dkl[None, 0, None, block_coord[2]] + return tVsV, tVgV + + @cute.jit + def load_tile( + self, + tma_atom: cute.CopyAtom, + src_global: cute.Tensor, + dst_smem: cute.Tensor, + producer: PipelineProducer, + ): + """Issue a single TMA load into SMEM with pipeline barrier.""" + handle = producer.acquire_and_advance() + cute.copy( + tma_atom, + src_global, + dst_smem[None, handle.index], + tma_bar_ptr=handle.barrier, + ) + + # ========================================================================= + # Prefill orchestration — proven-correct inline implementation + # ========================================================================= + + @cute.jit + def run( + self, + qk_thr_mma: cute.core.ThrMma, + pv_thr_mma: cute.core.ThrMma, + tma_atom_q: cute.CopyAtom, + tma_atom_k: cute.CopyAtom, + tma_atom_v: cute.CopyAtom, + mQ_qdl: cute.Tensor, + mK_kdl: cute.Tensor, + mV_dkl: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sV: cute.Tensor, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + load_q_producer: PipelineProducer, + 
load_kv_producer: PipelineProducer, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """Loader warp orchestration loop (prefill-specific). + + Q0/Q1 double-buffered loads with KV tile streaming. + """ + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + mQ_qdl_ = mQ_qdl + mK_kdl_ = mK_kdl + mV_dkl_ = mV_dkl + seqlen_k = mK_kdl.shape[0] + curr_block_coord_q = curr_block_coord + curr_block_coord_kv = curr_block_coord + + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mQ = (cuseqlen_q, 0, (0, 0)) + mQ_qdl_ = cute.domain_offset(logical_offset_mQ, mQ_qdl) + curr_block_coord_q = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + logical_offset_mK = (cuseqlen_k, 0, (0, 0)) + logical_offset_mV = (0, cuseqlen_k, (0, 0)) + mK_kdl_ = cute.domain_offset(logical_offset_mK, mK_kdl) + mV_dkl_ = cute.domain_offset(logical_offset_mV, mV_dkl) + curr_block_coord_kv = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + # Local tile partition global tensors + gQ_qdl = cute.flat_divide( + mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2]) + ) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = 
cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tQgQ = tQgQ_qdl[None, None, 0, curr_block_coord_q[2]] + + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tKgK = tKgK_kdl[None, None, 0, curr_block_coord_kv[2]] + + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tVgV = tVgV_dkl[None, 0, None, curr_block_coord_kv[2]] + + # Q0 + q0_coord = 2 * curr_block_coord_q[0] + q0_handle_producer = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q0_coord], + tQsQ[None, q0_handle_producer.index], + tma_bar_ptr=q0_handle_producer.barrier, + ) + # K0 + kv_coord = get_kv_start_block_idx( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k, + seqlen_q, + ) + k_handle_producer = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle_producer.index], + tma_bar_ptr=k_handle_producer.barrier, + ) + # Q1 + q1_coord = q0_coord + 1 + q1_handle_producer = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q1_coord], + tQsQ[None, q1_handle_producer.index], + tma_bar_ptr=q1_handle_producer.barrier, + ) + # V0 + v_handle_producer = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle_producer.index], + tma_bar_ptr=v_handle_producer.barrier, + ) + kv_coord += 1 + + seqlen_kv_loop_steps = ( + 
get_trip_count( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k, + seqlen_q, + ) + - 1 + ) + for _i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + k_handle_producer = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle_producer.index], + tma_bar_ptr=k_handle_producer.barrier, + ) + # Vi + v_handle_producer = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle_producer.index], + tma_bar_ptr=v_handle_producer.barrier, + ) + kv_coord += 1 + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() diff --git a/flashinfer/cute_dsl/attention/roles/mla_compute.py b/flashinfer/cute_dsl/attention/roles/mla_compute.py new file mode 100644 index 0000000000..f0f41c5c06 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_compute.py @@ -0,0 +1,620 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MLAComputeRole — compute (softmax) warp role for MLA decode kernels. + +Handles: +- Online softmax with row-max tracking, exp2, correction factor +- SM100 vs SM103 architecture dispatch for TMEM load (plain vs fused-reduce) +- SMEM exchange for row-max reduction across warps +- P quantization and SMEM store to feed the PV MMA stage +- Row-sum accumulation using packed f32x2 adds +- Tile scheduler loop with mask applied on the last tile +- Correction metadata exchange to TMEM for correction warps + +Uses handle-passing pattern: state transitions (acquire_and_advance, +wait_and_advance) stay in run(), while ImmutableResourceHandles are passed +to sub-methods where side effects (commit, release) are co-located with +their paired fences. + +Extracted from MLADecodeFP16Kernel.compute / softmax / exchange_p_cor_metadata. 
+""" + +import math +from typing import Type + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.pipeline import PipelineProducer, PipelineConsumer +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL +from types import SimpleNamespace + +from ..mla_config import MLAConfig +from ..config import AttentionFusion +from ..scheduler.mla_persistent import ( + create_mla_static_tile_scheduler, + MLAStaticTileSchedulerParams, +) + +from ..compat import setmaxregister_increase as _setmaxregister_increase + + +class MLAComputeRole: + """Compute (softmax) warp role for MLA decode kernels. + + Owns the tile-scheduler loop and performs online softmax, producing + quantized P tiles in SMEM and correction metadata in TMEM. + + Optionally integrates AttentionVariant hooks (score_mod, update_statistics) + via the fusion parameter, mirroring the prefill SoftmaxRole pattern. + """ + + def __init__(self, config: MLAConfig, fusion: AttentionFusion): + self.acc_dtype = config.acc_dtype + self.mma_qk_tiler = config.mma_qk_tiler + self.mma_pv_tiler = config.mma_pv_tiler + self.cluster_shape_mnk = config.cluster_shape_mnk + self.warps_in_n = config.warps_in_n + self.num_compute_warps = config.num_compute_warps + self.threads_per_warp = 32 + self.mma_s_stage = config.mma_s_stage + self.p_mma_stage = config.p_mma_stage + self.p_cor_stage = config.p_cor_stage + self.skip_correction_threshold = config.skip_correction_threshold + self.tmem_o_offset = config.tmem_o_offset + self.correction_factor_offset = config.correction_factor_offset + self.is_var_split_kv = config.is_var_split_kv + + self.variant = fusion.variant + self.has_score_mod = fusion.variant.has_score_mod + self.has_statistics_update = fusion.variant.has_statistics_update + self.has_params = fusion.has_params + + self.softmax_reg_num = 192 + self.softmax_exchange_sync_bar = None + + def set_dtypes(self, q_dtype: Type[cutlass.Numeric]) -> None: + 
"""Set tensor element types discovered at call time.""" + self.q_dtype: Type[cutlass.Numeric] = q_dtype + + def set_barriers(self, softmax_exchange_sync_bar): + """Set named barriers owned by the kernel.""" + self.softmax_exchange_sync_bar = softmax_exchange_sync_bar + + # ------------------------------------------------------------------ + # Tile count helper + # ------------------------------------------------------------------ + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get k_index, k_tile_count, and local split_kv for a work tile.""" + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + # ------------------------------------------------------------------ + # Correction metadata exchange + # ------------------------------------------------------------------ + + @cute.jit + def exchange_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + correction_factor: cutlass.Float32, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + row_max_new: cutlass.Float32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + p_cor_handle, + ): + """Write correction metadata to TMEM for the correction warps. + + Commits the p_cor handle after fence_view_async_tmem_store, + co-locating the fence+commit pair. Returns the updated row_max_new. 
+ """ + no_correction = 0 + if ( + row_max_new - row_max + ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: + no_correction = 1 + row_max_new = row_max + + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_handle.index], + ) + cute.arch.fence_view_async_tmem_store() + p_cor_handle.commit() + return row_max_new + + # ------------------------------------------------------------------ + # Softmax (single tile) + # ------------------------------------------------------------------ + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_handle, + p_mma_handle, + p_cor_handle, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + is_local_last_tile: cutlass.Boolean, + params: cute.Tensor = None, + ) -> tuple: + """Online softmax for one k-tile. 
+ + Contains the SM100 vs SM103 architecture dispatch for TMEM load, + masking, exp2, row-max reduction with SMEM exchange, P quantization + and SMEM store, row-sum accumulation. + + When an AttentionVariant is configured: + - update_statistics runs before TMEM load (e.g. attention sink) + - score_mod runs after TMEM load + masking, before row_max (e.g. ALiBi) + + Side effects co-located with their paired fences: + - fence_view_async_shared → p_mma_handle.commit() + - fence_view_async_tmem_store → p_cor_handle.commit() (via exchange) + - fence_view_async_tmem_load → mma_s_handle.release() + + Returns (row_max_new, row_sum, correction_factor). + """ + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_handle.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + # Inject virtual tokens into running (m, d) before loading scores + if cutlass.const_expr(self.has_statistics_update): + if cutlass.const_expr(self.has_params): + self.variant.params = params + qo_head_idx_for_stats = tTR_tS[0][0] + common_params.cta_m_offset + row_max, row_sum = self.variant.update_statistics( + 
k_index, + qo_head_idx_for_stats, + row_max, + row_sum, + softmax_params.softmax_scale_log2, + ) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + row_max_new = row_max + arch = BaseDSL._get_dsl().get_arch_enum() + if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if is_last_tile: + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + + elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + tmem_load_red_atom = cute.make_copy_atom( + tcgen05.copy.LdRed32x32bOp( + tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX + ), + self.acc_dtype, + ) + tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) + tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) + tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) + tTR_tS_red = tmem_red_thr_copy.partition_D(cS) + tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) + tTR_rMax = cute.make_rmem_tensor( + cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), + self.acc_dtype, + ) + cute.copy( + tmem_red_tiled_copy, + tTR_tAcc_red, + (tTR_rAcc_red, tTR_rMax), + ) + tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + row_max_new = tTR_rAcc.load().reduce( + cute.ReductionOp.MAX, row_max_new, 0 + ) + else: + row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + + # Per-element score modification (e.g. ALiBi bias, soft-capping). + # Applied after masking, before row_max finalization. 
When active, + # row_max must be recomputed from the modified scores. + if cutlass.const_expr(self.has_score_mod): + if cutlass.const_expr(self.has_params and not self.has_statistics_update): + self.variant.params = params + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + qo_head_idx = tTR_tS[i][0] + common_params.cta_m_offset + kv_idx = tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index + tTR_rAcc[i] = self.variant.score_mod( + tTR_rAcc[i], + common_params.batch_idx, + common_params.qo_idx, + kv_idx, + qo_head_idx, + 0, + ) + # Re-apply masking: score_mod may map -inf to a finite value + # (e.g. SoftCapping: cap*tanh(-inf/cap) = -cap). Restore -inf + # for out-of-bounds positions so they get zero softmax weight. + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if not cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ): + tTR_rAcc[i] = -self.acc_dtype.inf + row_max_new = row_max + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + + # reduce row_max across warps via SMEM exchange when warps_in_n == 2 + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + assert self.softmax_exchange_sync_bar is not None + self.softmax_exchange_sync_bar.arrive_and_wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + # split kv case: exchange metadata before last tile + if cutlass.const_expr(not is_local_last_tile): + row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_handle, + ) + + # exp2 + quantize + fma_b = softmax_params.softmax_scale_log2 + fma_c = (0.0 - row_max_new) * 
softmax_params.softmax_scale_log2 + + for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): + tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # store P to SMEM + sP = softmax_params.sP[None, None, None, (None, p_mma_handle.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + cute.arch.fence_view_async_shared() + p_mma_handle.commit() + + # row_sum accumulation using packed f32x2 to reduce instruction count + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # split kv case: exchange metadata on 
last tile + if cutlass.const_expr(is_local_last_tile): + row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_handle, + ) + + cute.arch.fence_view_async_tmem_load() + mma_s_handle.release() + + return ( + row_max_new, + row_sum, + correction_factor, + ) + + # ------------------------------------------------------------------ + # run — top-level entry: pipeline init + tile scheduler loop + # + # State transitions (acquire_and_advance, wait_and_advance) live here. + # ImmutableResourceHandles are passed to sub-methods where side + # effects (commit, release) are co-located with their paired fences. + # ------------------------------------------------------------------ + + @cute.jit + def run( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + tile_sched_params: MLAStaticTileSchedulerParams, + tmem_ptr: cute.Pointer, + mma_s_consumer: PipelineConsumer, + p_mma_producer: PipelineProducer, + p_cor_producer: PipelineProducer, + softmax_smem_exchange: cute.Tensor, + mAccO: cute.Tensor, + mO: cute.Tensor, + mCL: cute.Tensor, + K: cutlass.Int32, + L: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + sP: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + tmem, + params: cute.Tensor = None, + ): + """Top-level entry for the compute warp role. + + Iterates the tile scheduler and runs softmax for each valid work + tile. State transitions (acquire/wait) happen here; commit/release + are co-located with fences inside softmax/exchange_p_cor_metadata. 
+ """ + _setmaxregister_increase(self.softmax_reg_num) + + tmem.wait_for_alloc() + tmem_ptr_resolved = tmem.retrieve_ptr(self.acc_dtype) + + tidx, _, _ = cute.arch.thread_idx() + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + bidx, _, _ = cute.arch.block_idx() + cta_m_offset = (bidx % self.cluster_shape_mnk[0]) * ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0] + ) + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=L, + tmem_ptr=tmem_ptr_resolved, + tidx=tidx, + batch_idx=blk_coord[2], + qo_idx=blk_coord[1], + cta_m_offset=cta_m_offset, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + softmax_scale_log2=softmax_scale_log2, + ) + k_tile_total = cute.ceil_div( + compute_common_params.K, self.mma_qk_tiler[1] + ) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + p_cor_handle = p_cor_producer.acquire_and_advance() + + # unmasked tiles + while k_tile_count > 1: + p_mma_handle = p_mma_producer.acquire_and_advance() + mma_s_handle = mma_s_consumer.wait_and_advance() + + ( + row_max, + row_sum, + correction_factor, + ) = self.softmax( + compute_common_params, + compute_softmax_params, + k_index, + mma_s_handle, + p_mma_handle, + p_cor_handle, + row_max, + row_sum, + correction_factor, + False, + False, + params, + ) + + p_cor_handle = p_cor_producer.acquire_and_advance() + + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # last tile (masked) + p_mma_handle = 
p_mma_producer.acquire_and_advance() + mma_s_handle = mma_s_consumer.wait_and_advance() + + if cutlass.const_expr(mAccO is not None): + ( + row_max, + row_sum, + correction_factor, + ) = self.softmax( + compute_common_params, + compute_softmax_params, + k_index, + mma_s_handle, + p_mma_handle, + p_cor_handle, + row_max, + row_sum, + correction_factor, + k_index == k_tile_total - 1, + True, + params, + ) + else: + ( + row_max, + row_sum, + correction_factor, + ) = self.softmax( + compute_common_params, + compute_softmax_params, + k_index, + mma_s_handle, + p_mma_handle, + p_cor_handle, + row_max, + row_sum, + correction_factor, + True, + True, + params, + ) + + # Trailing sync: acquire() without advance — back-pressure only, + # no data produced. acquire_and_advance() would desync the pipeline + # across persistent kernel work tiles. + p_cor_producer.acquire() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mla_correction.py b/flashinfer/cute_dsl/attention/roles/mla_correction.py new file mode 100644 index 0000000000..eba662676c --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_correction.py @@ -0,0 +1,617 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MLACorrectionRole — rescale + epilogue warp role for MLA decode. 
import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.pipeline import PipelineConsumer
from types import SimpleNamespace

from ..mla_config import MLAConfig
from ..config import AttentionFusion
from ..scheduler.mla_persistent import (
    create_mla_static_tile_scheduler,
    MLAStaticTileSchedulerParams,
)


class MLACorrectionRole:
    """Correction warp role for MLA decode kernels.

    Handles output rescaling across KV tiles and final epilogue (normalize,
    convert dtype, write O and LSE to global memory or split-KV workspace).

    Optionally integrates AttentionVariant.transform_output via the fusion
    parameter, allowing custom output normalization (e.g. AttentionWithSink).
    """

    def __init__(
        self,
        config: MLAConfig,
        fusion: AttentionFusion,
        v_dtype=None,
        o_dtype=None,
    ):
        # Tiling / staging configuration copied from the kernel config.
        self.acc_dtype = config.acc_dtype
        self.lse_dtype = config.lse_dtype
        self.mma_qk_tiler = config.mma_qk_tiler
        self.mma_pv_tiler = config.mma_pv_tiler
        self.cluster_shape_mnk = config.cluster_shape_mnk
        self.warps_in_n = config.warps_in_n
        self.num_compute_warps = config.num_compute_warps
        self.threads_per_warp = 32
        self.p_cor_stage = config.p_cor_stage
        self.mma_o_stage = config.mma_o_stage
        self.tmem_o_offset = config.tmem_o_offset
        self.correction_factor_offset = config.correction_factor_offset
        self.iterations_pv_n = config.iterations_pv_n
        self.is_var_split_kv = config.is_var_split_kv
        self.enable_pdl = config.enable_pdl
        self.per_iteration_mma_o = config.per_iteration_mma_o
        self.v_dtype = v_dtype
        self.o_dtype = o_dtype

        # Optional AttentionVariant output-transform hook.
        self.variant = fusion.variant
        self.has_output_transform = fusion.variant.has_output_transform
        self.has_params = fusion.has_params

        self.epilogue_exchange_sync_bar = None

    def set_barriers(self, epilogue_exchange_sync_bar):
        """Set named barriers owned by the kernel."""
        self.epilogue_exchange_sync_bar = epilogue_exchange_sync_bar

    @cute.jit
    def _get_k_tile_count(
        self,
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        blk_coord: cute.Coord,
    ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
        """Get k_index, k_tile_count, and local split_kv for a work tile.

        NOTE: duplicates MLAComputeRole._get_k_tile_count so both roles
        agree on work distribution; keep the two in sync.
        """
        K = cache_seqs[blk_coord[2]]
        if cutlass.const_expr(self.is_var_split_kv):
            split_kv = block_split_kvs[blk_coord[2]]

        k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1])
        k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv)
        k_index = blk_coord[3] * k_tile_per_cta
        k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index)
        return k_index, k_tile_count, split_kv

    @cute.jit
    def _make_pv_tiled_mma(self):
        """Create an independent TiledMma for PV partition shape computation.

        The correction role needs TiledMma only for partition_shape_C and
        make_fragment_C — never for actual GEMM. Creating its own instance
        avoids sharing mutable state with the MMA role (which mutates TiledMma
        via .set(ACCUMULATE, ...)).
        """
        cta_group = tcgen05.CtaGroup.TWO
        p_major_mode = tcgen05.OperandMajorMode.K
        v_major_mode = tcgen05.OperandMajorMode.MN
        return sm100_utils.make_trivial_tiled_mma(
            self.v_dtype,
            p_major_mode,
            v_major_mode,
            self.acc_dtype,
            cta_group,
            self.mma_pv_tiler[:2],
        )

    @cute.jit
    def _tmem_load_partition(
        self, common_params: SimpleNamespace, pv_tiled_mma: cute.TiledMma, iter_n: int
    ) -> tuple:
        """Create TMEM load partitions for rescale and epilogue.

        Computes the O accumulator TMEM view at tmem_o_offset, partitions it,
        and creates global memory output views for either mAccO (split-KV
        workspace) or mO (final output).

        Returns a 6-tuple:
        (tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc)
        — the tiled copy, the TMEM accumulator view, its thread partition,
        the global O partition, the coordinate partition, and a register
        fragment shaped like the output partition.
        """
        # Fix: the previous annotation claimed five TiledMma values; the
        # function actually returns one tiled copy plus five tensors.
        tOtO_shape = pv_tiled_mma.partition_shape_C(
            cute.select(self.mma_pv_tiler, mode=[0, 1])
        )
        tOtO = pv_tiled_mma.make_fragment_C(tOtO_shape)
        # Append the per-iteration N mode so one view spans all PV columns.
        tOtO_layout = cute.append(
            tOtO.layout,
            cute.make_layout(
                common_params.L // self.mma_pv_tiler[1],
                stride=self.mma_pv_tiler[1] // self.warps_in_n,
            ),
        )
        tOtO = cute.make_tensor(
            common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout
        )
        tOtO = tOtO[None, None, None, iter_n]

        tAcc = tOtO[(None, None), 0, 0]

        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype
        )
        tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc)
        tmem_load_thr_copy = tmem_load_tiled_copy.get_slice(
            common_params.tidx % (self.num_compute_warps * self.threads_per_warp)
        )

        cta_pv_tiler = (
            self.mma_pv_tiler[0] // self.cluster_shape_mnk[0],
            self.mma_pv_tiler[1],
            self.mma_pv_tiler[2],
        )
        cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1])

        gO = None
        if cutlass.const_expr(common_params.mAccO is not None):
            # Split-KV path: write partial results into the workspace slice
            # owned by this CTA's split index (blk_coord[3]).
            gO = cute.local_tile(
                common_params.mAccO[None, common_params.blk_coord[3], None, None, None],
                cta_pv_tiler_mn,
                (
                    common_params.blk_coord[0],
                    iter_n,
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
            )
            cO = cute.local_tile(
                cute.make_identity_tensor(
                    common_params.mAccO[
                        None, common_params.blk_coord[3], None, None, None
                    ].shape
                ),
                cta_pv_tiler_mn,
                (
                    common_params.blk_coord[0],
                    iter_n,
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
            )
        else:
            # Direct path: write straight to the final output tensor.
            gO = cute.local_tile(
                common_params.mO,
                cta_pv_tiler_mn,
                (
                    common_params.blk_coord[0],
                    iter_n,
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
            )
            cO = cute.local_tile(
                cute.make_identity_tensor(common_params.mO.shape),
                cta_pv_tiler_mn,
                (
                    common_params.blk_coord[0],
                    iter_n,
                    common_params.blk_coord[1],
                    common_params.blk_coord[2],
                ),
            )
        tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc)
        tTR_gO = tmem_load_thr_copy.partition_D(gO)
        tTR_cO = tmem_load_thr_copy.partition_D(cO)
        tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype)
        return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc

    @cute.jit
    def get_correction_factor(
        self,
        common_params: SimpleNamespace,
        p_cor_handle,
    ) -> tuple[
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Int32,
    ]:
        """Load correction metadata from TMEM written by compute warps.

        Releases the p_cor handle after reading, co-locating the data
        consumption with its release.

        Returns (row_sum, row_max, correction_factor, no_correction).
        """
        # NOTE(review): method body continues beyond this chunk of the diff;
        # only the visible header is reproduced here.
+        """
+        # Linearize to the correction group's local thread index.
+        tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp)
+        # Only tAcc is used here: its shape/stride seed the metadata layout below.
+        _, tAcc, _, _, _, _ = self._tmem_load_partition(
+            common_params, common_params.tiled_mma_pv, 0
+        )
+        # Metadata layout: 4 consecutive 32-bit slots per row
+        # (row_sum, row_max, correction_factor, no_correction flag),
+        # replicated over p_cor_stage pipeline stages with stage stride 4.
+        corr_layout = cute.make_layout(
+            (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage),
+            stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4),
+        )
+        tCor = cute.make_tensor(
+            common_params.tmem_ptr + self.correction_factor_offset, corr_layout
+        )
+        cCor = cute.make_identity_tensor(tCor.shape)
+        # 32-bit TMEM load atom, 4 repetitions — one per metadata slot.
+        corr_tmem_load_atom = cute.make_copy_atom(
+            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype
+        )
+        corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor)
+        corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx)
+        tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor)
+        cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor)
+        rCor = cute.make_fragment_like(
+            cCor_for_copy[None, None, None, 0], self.acc_dtype
+        )
+        # Same registers reinterpreted as Int32 so the no_correction flag is
+        # read back without a float round-trip.
+        rCor_int = cute.make_tensor(
+            cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout
+        )
+        # Pull the current pipeline stage's metadata into registers.
+        cute.copy(
+            corr_tmem_load_tiled_copy,
+            tCor_for_copy[None, None, None, p_cor_handle.index],
+            rCor,
+        )
+        row_sum = rCor[0]
+        row_max = rCor[1]
+        correction_factor = rCor[2]
+        no_correction = rCor_int[3]
+
+        # Data consumed — release the stage here, co-located with its read.
+        p_cor_handle.release()
+        return row_sum, row_max, correction_factor, no_correction
+
+    @cute.jit
+    def _rescale_one_iter(
+        self,
+        common_params: SimpleNamespace,
+        correction_factor: cutlass.Float32,
+        skip_correction: cutlass.Boolean,
+        iter_n: int,
+    ):
+        """Rescale O accumulator for a single iter_n slice.
+
+        Side-effect-only (TMEM load/store + fence). Pipeline ops stay in caller.
+        """
+        if not skip_correction:
+            tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = (
+                self._tmem_load_partition(
+                    common_params, common_params.tiled_mma_pv, iter_n
+                )
+            )
+            tmem_store_atom = cute.make_copy_atom(
+                tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype
+            )
+            tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc)
+            # TMEM -> registers, scale in registers, registers -> TMEM.
+            cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc)
+            for i in cutlass.range(
+                cute.size(tTR_rAcc), vectorize=True, unroll_full=True
+            ):
+                tTR_rAcc[i] = tTR_rAcc[i] * correction_factor
+            cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc)
+
+        # NOTE(review): fence kept outside the skip branch so the caller's
+        # paired handle release is always ordered after any TMEM store —
+        # confirm against the original indentation.
+        cute.arch.fence_view_async_tmem_store()
+
+    @cute.jit
+    def rescale(
+        self,
+        common_params: SimpleNamespace,
+        correction_factor: cutlass.Float32,
+        no_correction: cutlass.Int32,
+        mma_o_handle,
+    ):
+        """Rescale O accumulator in TMEM by correction_factor (FP16 single-handle path).
+
+        Releases the mma_o handle after fence_view_async_tmem_store,
+        co-locating the fence+release pair. Uses vote_all_sync to skip
+        rescaling when all threads agree no correction is needed.
+        """
+        # Warp-wide vote: skip only if every thread's flag says "no correction".
+        skip_correction = cute.arch.vote_all_sync(no_correction == 1)
+        for iter_n in cutlass.range_constexpr(self.iterations_pv_n):
+            self._rescale_one_iter(
+                common_params, correction_factor, skip_correction, iter_n
+            )
+        mma_o_handle.release()
+
+    @cute.jit
+    def _epilogue_one_iter(
+        self,
+        common_params: SimpleNamespace,
+        epilogue_params: SimpleNamespace,
+        row_sum: cutlass.Float32,
+        row_max: cutlass.Float32,
+        tidx: cutlass.Int32,
+        iter_n: int,
+        params: cute.Tensor = None,
+    ):
+        """Epilogue for a single iter_n slice.
+
+        Side-effect-only (TMEM load, global store, fence). Pipeline ops stay in caller.
+ """ + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition(common_params, common_params.tiled_mma_pv, iter_n) + ) + + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + if cutlass.const_expr(not self.has_output_transform): + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = ( + tTR_rAcc[i] + * epilogue_params.output_scale + * cute.arch.rcp_approx(row_sum) + ) + else: + if cutlass.const_expr(self.has_params): + self.variant.params = params + rcp_d = ( + cute.arch.rcp_approx(row_sum) if row_max != -self.acc_dtype.inf else 0.0 + ) + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=False, unroll_full=True + ): + qo_head_idx = tTR_cO[i][0] + common_params.cta_m_offset + tTR_rAcc[i] = self.variant.transform_output( + tTR_rAcc[i], + common_params.blk_coord[2], + common_params.blk_coord[1], + qo_head_idx, + row_max, + rcp_d, + epilogue_params.output_scale, + ) + + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy( + tR2G_rO_src, + tR2G_rO_dst, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + 
common_params.blk_coord[2], + ), + (1, 1, 1), + ) + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[None, common_params.blk_coord[3], None, None], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + mma_o_handle, + params: cute.Tensor = None, + ): + """Final epilogue: normalize O, convert dtype, write O and LSE (FP16 single-handle path). + + Releases the mma_o handle after fence_view_async_tmem_load, + co-locating the fence+release pair. 
+        """
+        # Linearize to the correction group's local thread index.
+        tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp)
+
+        if cutlass.const_expr(self.warps_in_n == 2):
+            # With two warps in N, each half holds a partial row_sum: publish
+            # ours to SMEM, sync, then add the partner's value 64 lanes away.
+            # Assumes a 128-thread correction group so +64 pairs the halves —
+            # TODO confirm num_compute_warps * threads_per_warp == 128 here.
+            common_params.smem_exchange[tidx] = row_sum
+            assert self.epilogue_exchange_sync_bar is not None
+            self.epilogue_exchange_sync_bar.arrive_and_wait()
+            row_sum = (
+                row_sum
+                + common_params.smem_exchange[
+                    (tidx + 64) % (self.num_compute_warps * self.threads_per_warp)
+                ]
+            )
+        for iter_n in cutlass.range_constexpr(self.iterations_pv_n):
+            self._epilogue_one_iter(
+                common_params,
+                epilogue_params,
+                row_sum,
+                row_max,
+                tidx,
+                iter_n,
+                params,
+            )
+        mma_o_handle.release()
+
+    @cute.jit
+    def run(
+        self,
+        split_kv: cutlass.Int32,
+        cache_seqs: cute.Tensor,
+        block_split_kvs: cute.Tensor,
+        tile_sched_params: MLAStaticTileSchedulerParams,
+        tmem_ptr,
+        p_cor_consumer: PipelineConsumer,
+        mma_o_consumer: PipelineConsumer,
+        compute_common_params: SimpleNamespace,
+        epilogue_params: SimpleNamespace,
+        params: cute.Tensor = None,
+    ):
+        """Tile-scheduler loop for the correction warp.
+
+        For each work tile: loads correction factors, rescales O (if not the
+        first KV tile), and runs the epilogue on the final KV tile.
+
+        State transitions (wait_and_advance) happen here; release calls are
+        co-located with fences inside get_correction_factor/rescale/epilogue.
+ """ + tidx, _, _ = cute.arch.thread_idx() + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + pv_tiled_mma = self._make_pv_tiled_mma() + common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=compute_common_params.smem_exchange, + mAccO=compute_common_params.mAccO, + mO=compute_common_params.mO, + K=cache_seqs[blk_coord[2]], + L=compute_common_params.L, + H=compute_common_params.H, + cta_m_offset=compute_common_params.cta_m_offset, + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=pv_tiled_mma, + ) + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + p_cor_handle = p_cor_consumer.wait_and_advance() + row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_handle) + ) + + if k_tile_count_init != k_tile_count: + if cutlass.const_expr(self.per_iteration_mma_o): + skip_correction = cute.arch.vote_all_sync( + no_correction == 1 + ) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + mma_o_handle = mma_o_consumer.wait_and_advance() + self._rescale_one_iter( + common_params, + correction_factor, + skip_correction, + iter_n, + ) + mma_o_handle.release() + else: + mma_o_handle = mma_o_consumer.wait_and_advance() + self.rescale( + common_params, + correction_factor, + no_correction, + mma_o_handle, + ) + + k_tile_count = k_tile_count - 1 + if k_tile_count == 0: + if cutlass.const_expr(self.per_iteration_mma_o): + tidx = common_params.tidx % ( + self.num_compute_warps * self.threads_per_warp + ) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + assert 
self.epilogue_exchange_sync_bar is not None + self.epilogue_exchange_sync_bar.arrive_and_wait() + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) + % ( + self.num_compute_warps + * self.threads_per_warp + ) + ] + ) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + mma_o_handle = mma_o_consumer.wait_and_advance() + self._epilogue_one_iter( + common_params, + epilogue_params, + row_sum, + row_max, + tidx, + iter_n, + params, + ) + mma_o_handle.release() + else: + mma_o_handle = mma_o_consumer.wait_and_advance() + self.epilogue( + common_params, + epilogue_params, + row_sum, + row_max, + mma_o_handle, + params, + ) + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() diff --git a/flashinfer/cute_dsl/attention/roles/mla_loader.py b/flashinfer/cute_dsl/attention/roles/mla_loader.py new file mode 100644 index 0000000000..f41bb36396 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_loader.py @@ -0,0 +1,513 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MLALoaderRole — TMA load orchestration for MLA decode kernels. + +Extracted from the monolithic mla_decode_fp16.py kernel. Owns: +- get_k_tile_count: compute per-CTA tile range from split-KV partitioning +- TMA copy helpers for Q, K (latent/rope), and V loads +- run(): tile-scheduler loop driving the full load warp lifetime + +All pipeline acquire/wait/commit/release/tail calls happen directly in run(), +not in sub-methods, because CuTe DSL's MLIR compiler cannot track participant +state mutations across method boundaries inside dynamic loops. 
+"""
+
+import cutlass
+import cutlass.cute as cute
+import cutlass.cute.nvgpu.cpasync as cpasync
+from cutlass.pipeline import PipelineProducer, PipelineConsumer
+from types import SimpleNamespace
+
+from ..mla_config import MLAConfig
+from ..scheduler.mla_persistent import (
+    create_mla_static_tile_scheduler,
+    MLAStaticTileSchedulerParams,
+    ceil_div,
+)
+
+
+class MLALoaderRole:
+    """Loader warp for MLA decode kernels — TMA loads Q, K, V into SMEM.
+
+    Created from MLAConfig in the kernel's __init__.
+    """
+
+    def __init__(self, config: MLAConfig):
+        # Keep the full config, plus flat copies of the fields read on the
+        # hot path (TMA partitioning and the per-k-tile load loops).
+        self.config = config
+        self.mma_qk_tiler = config.mma_qk_tiler
+        self.mma_qk_rope_tiler = config.mma_qk_rope_tiler
+        self.mma_pv_tiler = config.mma_pv_tiler
+        self.page_size = config.page_size
+        self.is_var_split_kv = config.is_var_split_kv
+        self.iterations_qk_latent = config.iterations_qk_latent
+        self.iterations_qk_rope = config.iterations_qk_rope
+        self.iterations_pv_k = config.iterations_pv_k
+        self.iterations_pv_n = config.iterations_pv_n
+
+    # =========================================================================
+    # Tile count computation
+    # =========================================================================
+
+    # NOTE(review): this method is duplicated verbatim in MLACorrectionRole
+    # (and MLAFP8LoaderKRole) — keep the three copies in sync, or hoist the
+    # split-KV tile-range computation into a shared helper.
+    @cute.jit
+    def _get_k_tile_count(
+        self,
+        split_kv: cutlass.Int32,
+        cache_seqs: cute.Tensor,
+        block_split_kvs: cute.Tensor,
+        blk_coord: cute.Coord,
+    ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
+        """Get the current k_index, k_tile_count, and local split_kv value.
+ + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + # ========================================================================= + # TMA copy helpers (pure computation — no pipeline ops) + # ========================================================================= + + @cute.jit + def _read_qk_page_indices( + self, + tile_params: SimpleNamespace, + qk_params: SimpleNamespace, + page_table_stage: cutlass.Int32, + ): + """Read QK page table indices from SMEM into registers.""" + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, + qk_params.tiled_mma_qk.thr_id.shape, + ) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + tile_params.sPT[0, page_table_stage] + if self.mma_qk_tiler[1] // self.page_size == 1 + else tile_params.sPT[ + i + tile_params.blk_coord[0] * page_per_tile, page_table_stage + ] + ) + return k_idx + + @cute.jit + def _load_q_tma(self, qk_params: SimpleNamespace, q_barrier): + """Issue TMA copies for Q latent and Q rope fragments.""" + for i in cutlass.range(self.iterations_qk_latent): + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + 
qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=q_barrier, + ) + for i in cutlass.range(self.iterations_qk_rope): + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=q_barrier, + ) + + @cute.jit + def _load_kv_latent_one_iter( + self, + qk_params: SimpleNamespace, + k_idx, + kv_barrier, + kv_stage_index: cutlass.Int32, + iteration: int, + ): + """Issue TMA copies for one K-latent iteration.""" + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, + qk_params.tiled_mma_qk.thr_id.shape, + ) + for k in cutlass.range(page_per_tile): + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, iteration, k_idx[k]], + qk_params.tKCsKC[None, k, 0, kv_stage_index], + tma_bar_ptr=kv_barrier, + ) + + @cute.jit + def _load_kv_rope_one_iter( + self, + qk_params: SimpleNamespace, + k_idx, + kv_barrier, + kv_stage_index: cutlass.Int32, + iteration: int, + ): + """Issue TMA copies for one K-rope iteration.""" + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, + qk_params.tiled_mma_qk.thr_id.shape, + ) + for k in cutlass.range(page_per_tile): + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, iteration, k_idx[k]], + qk_params.tKCsKC[None, k, 0, kv_stage_index], + tma_bar_ptr=kv_barrier, + ) + + @cute.jit + def _read_v_page_indices( + self, + tile_params: SimpleNamespace, + page_table_stage: cutlass.Int32, + ): + """Read V page table indices from SMEM into registers.""" + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range(page_per_tile): + k_idx[i] = ( + tile_params.sPT[0, page_table_stage] + if page_per_tile == 1 + else tile_params.sPT[i, page_table_stage] + ) + return k_idx + + @cute.jit + def _load_v_one_iter( + self, + v_params: SimpleNamespace, + k_idx, + kv_barrier, + kv_stage_index: cutlass.Int32, + i: int, + j: int, + ): + 
"""Issue TMA copies for one V iteration.""" + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + for k in cutlass.range(page_per_subtile): + k_idx_i = k_idx[ + k + + i // ceil_div(self.iterations_pv_k, page_per_tile) * page_per_subtile + ] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[ + None, + j, + i % ceil_div(self.iterations_pv_k, page_per_tile), + k_idx_i, + ], + v_params.tVCsVC[None, 0, k, kv_stage_index], + tma_bar_ptr=kv_barrier, + ) + + # ========================================================================= + # TMA partition setup (from load_tma lines 1621-1738) + # ========================================================================= + + @cute.jit + def _setup_tma_partitions( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + ): + """Set up TMA partitions for Q, K, V tensors and store them into params. + + This is the partition logic from the monolithic load_tma method + (lines 1621-1738), executed once per work tile before the per-k-tile + load loops. 
+ """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide(qk_params.mKR, (page_tile_size, self.mma_qk_tiler[2])) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + + # tma partition for q, k latent/rope + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + + _, tKRgKR = cpasync.tma_partition( + 
qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + # Flatten divide and partition global tensors for V TMA load + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + # ========================================================================= + # run() — tile-scheduler loop driving the full load warp lifetime + # + # All pipeline acquire/wait/commit/release/tail calls live here. + # Sub-methods only receive handles/indices and do pure TMA copies. 
+ # ========================================================================= + + @cute.jit + def run( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + load_q_producer: PipelineProducer, + load_kv_producer: PipelineProducer, + load_pt_consumer: PipelineConsumer, + tile_sched_params: MLAStaticTileSchedulerParams, + ): + """Tile-scheduler loop that orchestrates all TMA loads for the load warp. + + For each work tile produced by the tile scheduler: + 1. Compute k_index / k_tile_count via _get_k_tile_count + 2. Set up TMA partitions via _setup_tma_partitions + 3. Load first QK tile (with Q load) + 4. Loop remaining tiles: load QK + load V for previous tile + 5. Load final V tile + + After all work tiles are exhausted, calls tail() on both producer + participants. + """ + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + tile_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + mPT=common_params.mPT, + sPT=common_params.sPT, + ) + + self._setup_tma_partitions(tile_params, qk_params, v_params) + + # === First QK tile (with Q load) === + pt_handle = load_pt_consumer.wait_and_advance() + k_idx_qk = self._read_qk_page_indices( + tile_params, qk_params, pt_handle.index + ) + + q_handle = load_q_producer.acquire_and_advance() + self._load_q_tma(qk_params, q_handle.barrier) + + for i in cutlass.range(self.iterations_qk_latent): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_kv_latent_one_iter( + qk_params, + k_idx_qk, + kv_handle.barrier, + kv_handle.index, + i, 
+ ) + + for i in cutlass.range(self.iterations_qk_rope): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_kv_rope_one_iter( + qk_params, + k_idx_qk, + kv_handle.barrier, + kv_handle.index, + i, + ) + + k_index += 1 + k_tile_count -= 1 + + while k_tile_count > 0: + prev_pt_handle = pt_handle + + # === Next QK tile (no Q load) === + pt_handle = load_pt_consumer.wait_and_advance() + k_idx_qk = self._read_qk_page_indices( + tile_params, qk_params, pt_handle.index + ) + + for i in cutlass.range(self.iterations_qk_latent): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_kv_latent_one_iter( + qk_params, + k_idx_qk, + kv_handle.barrier, + kv_handle.index, + i, + ) + + for i in cutlass.range(self.iterations_qk_rope): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_kv_rope_one_iter( + qk_params, + k_idx_qk, + kv_handle.barrier, + kv_handle.index, + i, + ) + + # === V tile for previous k-tile === + k_idx_v = self._read_v_page_indices( + tile_params, prev_pt_handle.index + ) + prev_pt_handle.release() + + for i in cutlass.range(self.iterations_pv_k): + for j in cutlass.range(self.iterations_pv_n): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_v_one_iter( + v_params, + k_idx_v, + kv_handle.barrier, + kv_handle.index, + i, + j, + ) + + k_index += 1 + k_tile_count -= 1 + + # === Last V tile === + k_idx_v = self._read_v_page_indices(tile_params, pt_handle.index) + pt_handle.release() + + for i in cutlass.range(self.iterations_pv_k): + for j in cutlass.range(self.iterations_pv_n): + kv_handle = load_kv_producer.acquire_and_advance() + self._load_v_one_iter( + v_params, + k_idx_v, + kv_handle.barrier, + kv_handle.index, + i, + j, + ) + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_producer.tail() + load_kv_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mla_loader_fp8.py b/flashinfer/cute_dsl/attention/roles/mla_loader_fp8.py new file mode 100644 
index 0000000000..285a29e0c4 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_loader_fp8.py @@ -0,0 +1,461 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""FP8 MLA Loader Roles — split K and V TMA loading for FP8 MLA decode. + +FP8 replaces the unified load_kv + load_pt pipeline architecture with two +separate TMA loader warps: +- MLALoaderKRole: loads Q (latent+rope) and K (latent+rope) into SMEM. + Page table indices are read directly from global memory (no SMEM staging). +- MLALoaderVRole: loads V (c_latent_transpose) into SMEM. + Also reads page table from global memory directly. + +This eliminates the page-table pipeline and dedicated PT loader warp. + +All pipeline acquire/commit/tail calls happen directly in run(). +Sub-methods only receive handles/indices and do pure TMA copies. +""" + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.cpasync as cpasync +from cutlass.pipeline import PipelineProducer +from types import SimpleNamespace + +from ..mla_config import MLAConfig +from ..scheduler.mla_persistent import ( + create_mla_static_tile_scheduler, + MLAStaticTileSchedulerParams, + ceil_div, +) + + +class MLAFP8LoaderKRole: + """Q/K loader warp for FP8 MLA decode. + + Loads Q latent/rope (once per work tile) and K latent/rope (per k-tile) + into separate SMEM buffers. Page table indices are read directly from + global memory — no CpAsync PT pipeline needed. 
+ """ + + def __init__(self, config: MLAConfig): + self.config = config + self.mma_qk_tiler = config.mma_qk_tiler + self.mma_qk_rope_tiler = config.mma_qk_rope_tiler + self.page_size = config.page_size + self.is_var_split_kv = config.is_var_split_kv + self.iterations_qk_latent = config.iterations_qk_latent + self.iterations_qk_rope = config.iterations_qk_rope + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def _setup_tma_partitions( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + ): + """Set up TMA partitions for Q latent/rope and K latent/rope. + + K-rope goes to separate sKC_rope SMEM (unlike FP16 which shares sKC). 
+ """ + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide( + qk_params.mKR, (page_tile_size, self.mma_qk_rope_tiler[2]) + ) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + tKCsKC_rope, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC_rope, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + 
None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + qk_params.tKCsKC_rope = tKCsKC_rope + + @cute.jit + def _load_q_tma(self, qk_params: SimpleNamespace, q_barrier): + for i in cutlass.range_constexpr(self.iterations_qk_latent): + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=q_barrier, + ) + for i in cutlass.range_constexpr(self.iterations_qk_rope): + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=q_barrier, + ) + + @cute.jit + def _read_page_indices( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + ): + """Read page table indices directly from global memory.""" + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, + qk_params.tiled_mma_qk.thr_id.shape, + ) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if self.mma_qk_tiler[1] // self.page_size == 1 + else common_params.mPT[ + ( + k_index * qk_params.tiled_mma_qk.thr_id.shape + + common_params.blk_coord[0] + ) + * page_per_tile + + i + ] + ) + return k_idx + + @cute.jit + def _load_k_one_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_barrier, + k_stage_index: cutlass.Int32, + ): + """Load one k-tile of K latent and K rope into SMEM.""" + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, + qk_params.tiled_mma_qk.thr_id.shape, + ) + k_idx = self._read_page_indices(common_params, qk_params, k_index) + + for i in range(self.iterations_qk_latent): + for k in 
range(page_per_tile): + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, (i, k_stage_index)], + tma_bar_ptr=k_barrier, + ) + + for i in cutlass.range_constexpr(self.iterations_qk_rope): + for k in cutlass.range_constexpr(page_per_tile): + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, i, k_idx[k]], + qk_params.tKCsKC_rope[None, k, 0, k_stage_index], + tma_bar_ptr=k_barrier, + ) + + @cute.jit + def run( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + load_q_producer: PipelineProducer, + load_k_producer: PipelineProducer, + tile_sched_params: MLAStaticTileSchedulerParams, + ): + """Tile-scheduler loop for Q/K loader warp. + + For each work tile: + 1. Set up TMA partitions (creates fresh tile_params to avoid + mutating common_params inside dynamic if) + 2. First tile: load Q + K + 3. Remaining tiles: load K only + """ + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + tile_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + mPT=common_params.mPT, + ) + self._setup_tma_partitions(tile_params, qk_params) + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + load_q = k_tile_count_init == k_tile_count + if load_q: + q_handle = load_q_producer.acquire_and_advance() + self._load_q_tma(qk_params, q_handle.barrier) + + k_handle = load_k_producer.acquire_and_advance() + self._load_k_one_tile( + tile_params, + qk_params, + k_index, + k_handle.barrier, + k_handle.index, + ) + + k_index += 1 + k_tile_count -= 1 + + 
tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_producer.tail() + load_k_producer.tail() + + +class MLAFP8LoaderVRole: + """V loader warp for FP8 MLA decode. + + Loads V (c_latent_transpose) into separate SMEM buffer. Page table + indices are read directly from global memory. + """ + + def __init__(self, config: MLAConfig): + self.config = config + self.mma_qk_tiler = config.mma_qk_tiler + self.mma_pv_tiler = config.mma_pv_tiler + self.page_size = config.page_size + self.is_var_split_kv = config.is_var_split_kv + self.iterations_pv_k = config.iterations_pv_k + self.iterations_pv_n = config.iterations_pv_n + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def _setup_tma_partitions( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + ): + """Set up TMA partitions for V (c_latent_transpose).""" + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + + tVCsVC, tCLTgCLT = cpasync.tma_partition( + 
v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + common_params.mPT = mPT + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + @cute.jit + def _load_v_one_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + v_barrier, + v_stage_index: cutlass.Int32, + ): + """Load one k-tile of V into SMEM.""" + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if page_per_tile == 1 + else common_params.mPT[k_index * page_per_tile + i] + ) + + for j in cutlass.range_constexpr(self.iterations_pv_n): + for i in cutlass.range_constexpr(self.iterations_pv_k): + if cutlass.const_expr(page_per_tile > 1): + for k in cutlass.range_constexpr(page_per_subtile): + k_idx_i = k_idx[k + i * page_per_subtile] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, 0, k_idx_i], + v_params.tVCsVC[None, 0, k, ((j, i), v_stage_index)], + tma_bar_ptr=v_barrier, + ) + else: + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, i, k_idx[0]], + v_params.tVCsVC[None, 0, 0, ((j, i), v_stage_index)], + tma_bar_ptr=v_barrier, + ) + + @cute.jit + def run( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + load_v_producer: PipelineProducer, + tile_sched_params: MLAStaticTileSchedulerParams, + ): + """Tile-scheduler loop for V loader warp.""" + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + 
k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + tile_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + mPT=common_params.mPT, + ) + self._setup_tma_partitions(tile_params, v_params) + + while k_tile_count > 0: + v_handle = load_v_producer.acquire_and_advance() + self._load_v_one_tile( + tile_params, + v_params, + k_index, + v_handle.barrier, + v_handle.index, + ) + k_index += 1 + k_tile_count -= 1 + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_v_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mla_mma.py b/flashinfer/cute_dsl/attention/roles/mla_mma.py new file mode 100644 index 0000000000..b421142e90 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_mma.py @@ -0,0 +1,387 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MLAMmaRole — MMA warp role for MLA decode attention kernels. + +Extracted from the monolithic mla_decode_fp16.py kernel. Owns: +- Fragment creation for QK and PV GEMMs +- Per-stage GEMM helpers for QK latent/rope and PV +- run(): tile scheduler loop with interleaved QK/PV, pipeline lifecycle + +All pipeline acquire/wait/commit/release/tail calls happen directly in run(), +not in sub-methods, because CuTe DSL compiles Python to MLIR/SSA where the +JIT boundary acts as pass-by-value for DSL metadata (TiledMma fields, +PipelineState). Mutations made inside @cute.jit sub-methods create new SSA +values that are invisible to the caller. + +GEMM helpers take an explicit ``accumulate`` parameter following the FMHA +prefill pattern (roles/mma.py:gemm_pv). The ACCUMULATE flag is set inside +each helper as ``k_block != 0 or accumulate``, making it deterministic from +the parameter and inner loop index. 
Callers compute the parameter from their +own loop position — never from ``tiled_mma.get()`` after a sub-method return. +""" + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.pipeline import PipelineProducer, PipelineConsumer +from types import SimpleNamespace + +from ..mla_config import MLAConfig +from ..mainloop_spec import MLAMainloopSpec +from ..scheduler.mla_persistent import ( + create_mla_static_tile_scheduler, + MLAStaticTileSchedulerParams, +) + + +class MLAMmaRole: + """MMA warp for MLA decode — computes QK and PV GEMMs in TMEM. + + Created from MLAConfig and MLAMainloopSpec in the kernel's __init__. + Does NOT own TMEM alloc/dealloc (that stays in the kernel for coordination). + """ + + def __init__(self, config: MLAConfig, mainloop: MLAMainloopSpec): + self.mma_qk_tiler = config.mma_qk_tiler + self.mma_pv_tiler = config.mma_pv_tiler + self.rope_dim = config.rope_dim + self.latent_dim = config.latent_dim + self.warps_in_n = config.warps_in_n + self.mma_s_stage = config.mma_s_stage + self.tmem_o_offset = config.tmem_o_offset + self.iterations_qk_latent = config.iterations_qk_latent + self.iterations_qk_rope = config.iterations_qk_rope + self.iterations_pv_k = config.iterations_pv_k + self.iterations_pv_n = config.iterations_pv_n + self.enable_pdl = config.enable_pdl + self.is_var_split_kv = config.is_var_split_kv + + # ------------------------------------------------------------------ + # Tile count + # ------------------------------------------------------------------ + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get k_index, k_tile_count, and local split_kv for an MLA work tile. 
+ + :param split_kv: Split_kv value + :param cache_seqs: Cache sequence lengths tensor + :param block_split_kvs: Per-block split_kv values tensor + :param blk_coord: Block coordinate + :return: k_index, k_tile_count, split_kv + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + # ------------------------------------------------------------------ + # GEMM helpers — stateless w.r.t. caller + # + # Each helper takes an explicit ``accumulate`` bool that controls + # whether the first k-block overwrites (False) or accumulates (True). + # Subsequent k-blocks always accumulate. The caller computes the + # flag from its own loop position; the helper never communicates + # state back via TiledMma mutations (they would be invisible to the + # caller due to SSA pass-by-value at the @cute.jit boundary). + # + # Inner k-block loops use ``cutlass.range()`` (dynamic scf.for), + # NOT ``cutlass.range_constexpr()`` (compile-time unroll). + # range_constexpr unrolls tiled_mma.set() calls into the enclosing + # scope, producing SSA values that leak across dynamic while-loop + # yields. range() keeps the .set() inside an scf.for scope where + # SSA carry-through is handled correctly. 
+ # ------------------------------------------------------------------ + + @cute.jit + def _gemm_qk_latent_one_stage( + self, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + s_stage_index: cutlass.Int32, + kv_stage_index: cutlass.Int32, + q_stage: int, + accumulate: bool, + ): + """Compute one QK-latent stage: inner k-block GEMM loop.""" + tStS = qk_params.tStS_staged[None, None, None, s_stage_index] + for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])): + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kv_stage_index], + tStS, + ) + + @cute.jit + def _gemm_qk_rope_one_stage( + self, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + s_stage_index: cutlass.Int32, + kv_stage_index: cutlass.Int32, + q_stage: int, + accumulate: bool, + ): + """Compute one QK-rope stage: inner k-block GEMM loop.""" + tStS = qk_params.tStS_staged[None, None, None, s_stage_index] + for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]): + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kv_stage_index], + tStS, + ) + + @cute.jit + def _gemm_pv_one_stage( + self, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + p_stage_index: cutlass.Int32, + kv_stage_index: cutlass.Int32, + p_stage: int, + acc_stage: int, + accumulate: bool, + ): + """Compute one PV stage: inner k-block GEMM loop.""" + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range(pv_params.tOrP.shape[2]): + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_stage_index), + ], + pv_params.tOrVC[None, None, k_block, 
kv_stage_index], + tOtO, + ) + + # ------------------------------------------------------------------ + # Orchestration loop — tile scheduler + interleaved QK/PV + # + # All pipeline acquire/wait/commit/release/tail calls live here. + # Sub-methods only receive stage indices and do pure GEMM computation. + # ------------------------------------------------------------------ + + @cute.jit + def run( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer: PipelineConsumer, + load_kv_consumer: PipelineConsumer, + mma_s_producer: PipelineProducer, + p_mma_consumer: PipelineConsumer, + mma_o_producer: PipelineProducer, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + tile_sched_params: MLAStaticTileSchedulerParams, + sQ: cute.Tensor, + sQ_rope: cute.Tensor, + sKC: cute.Tensor, + sP: cute.Tensor, + sVC: cute.Tensor, + tmem_ptr: cute.Tensor, + is_leader_cta: cutlass.Boolean, + L: cutlass.Int32, + ): + """MMA warp orchestration loop for MLA decode. + + Creates MMA fragments, iterates over work tiles via the tile scheduler, + and runs interleaved QK/PV GEMMs with pipeline synchronization. + + Does NOT own TMEM alloc/dealloc — that stays in the kernel for + coordination with other warps. 
+ """ + # Create MMA fragments + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(sKC) + tOrP = tiled_mma_pv.make_fragment_A(sP) + tOrVC = tiled_mma_pv.make_fragment_B(sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # Tile scheduler + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + # Reset PV accumulate for each new work tile + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_qk_params = SimpleNamespace( + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + tSrQ=tSrQ, + tSrQ_rope=tSrQ_rope, + tSrKC=tSrKC, + tStS_staged=tStS_staged, + ) + mma_pv_params = SimpleNamespace( + sP=sP, + sVC=sVC, + tOrP=tOrP, + tOrVC=tOrVC, + tOtO_staged=tOtO_staged, + ) + + if is_leader_cta: + q_handle = load_q_consumer.wait_and_advance() + + # === First QK tile === + s_handle = mma_s_producer.acquire_and_advance() + for q_stage in range(self.iterations_qk_latent): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_qk_latent_one_stage( + mma_qk_params, + 
tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=(q_stage > 0), + ) + kv_handle.release() + for q_stage in range(self.iterations_qk_rope): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_qk_rope_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=True, + ) + kv_handle.release() + s_handle.commit() + k_tile_count -= 1 + + # === Interleaved QK + PV for remaining tiles === + while k_tile_count > 0: + # QK + s_handle = mma_s_producer.acquire_and_advance() + for q_stage in range(self.iterations_qk_latent): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_qk_latent_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=(q_stage > 0), + ) + kv_handle.release() + for q_stage in range(self.iterations_qk_rope): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_qk_rope_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=True, + ) + kv_handle.release() + s_handle.commit() + + # PV — pv_acc is read at run() level (safe), tracks + # whether any PV block has already initialized TMEM O. 
+ o_handle = mma_o_producer.acquire_and_advance() + p_handle = p_mma_consumer.wait_and_advance() + pv_acc = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + for p_stage in range(self.iterations_pv_k): + for acc_stage in range(self.iterations_pv_n): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_pv_one_stage( + mma_pv_params, + tiled_mma_pv, + p_handle.index, + kv_handle.index, + p_stage, + acc_stage, + accumulate=(pv_acc or p_stage > 0), + ) + kv_handle.release() + p_handle.release() + o_handle.commit() + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + + k_tile_count -= 1 + + q_handle.release() + + # === Final PV tile === + o_handle = mma_o_producer.acquire_and_advance() + p_handle = p_mma_consumer.wait_and_advance() + pv_acc = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + for p_stage in range(self.iterations_pv_k): + for acc_stage in range(self.iterations_pv_n): + kv_handle = load_kv_consumer.wait_and_advance() + self._gemm_pv_one_stage( + mma_pv_params, + tiled_mma_pv, + p_handle.index, + kv_handle.index, + p_stage, + acc_stage, + accumulate=(pv_acc or p_stage > 0), + ) + kv_handle.release() + p_handle.release() + o_handle.commit() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Pipeline producer tails + mma_s_producer.tail() + mma_o_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mla_mma_fp8.py b/flashinfer/cute_dsl/attention/roles/mla_mma_fp8.py new file mode 100644 index 0000000000..3284ed4c03 --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_mma_fp8.py @@ -0,0 +1,366 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MLAMmaFP8Role — MMA warp role for FP8 MLA decode attention kernels. + +FP8 differs from FP16 in three structural ways: + +1. QK GEMM: single load_k wait covers all latent+rope stages, single release. + K-rope uses separate tSrKC_rope fragments from sKC_rope SMEM. 
+ Fragment indexing: tSrQ[..., (q_stage, 0)], tSrKC[..., (q_stage, kc_stage)]. + +2. PV GEMM: loop order is ``for acc_stage -> mma_o_acquire -> for p_stage``. + One load_v wait covers all p_stage iterations; mma_o produced per acc_stage. + V fragment indexing: tOrVC[..., ((acc_stage, p_stage), vc_stage)]. + +3. mma_o pipeline: 2 stages (vs 1 for FP16), with acquire/commit per acc_stage. + +GEMM helpers use explicit ``accumulate`` parameter and ``cutlass.range()`` +(not ``range_constexpr``) for inner k-block loops, matching the FP16 pattern +in mla_mma.py. ``range()`` generates ``scf.for`` which keeps ``.set()`` SSA +values properly scoped; ``range_constexpr`` unrolls at compile time and leaks +SSA values across dynamic loop boundaries. +""" + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.pipeline import PipelineProducer, PipelineConsumer +from types import SimpleNamespace + +from ..mla_config import MLAConfig +from ..mainloop_spec import MLAMainloopSpec +from ..scheduler.mla_persistent import ( + create_mla_static_tile_scheduler, + MLAStaticTileSchedulerParams, +) + + +class MLAMmaFP8Role: + """MMA warp for FP8 MLA decode — computes QK and PV GEMMs in TMEM. + + Does NOT own TMEM alloc/dealloc (that stays in the kernel for coordination). 
+ """ + + def __init__(self, config: MLAConfig, mainloop: MLAMainloopSpec): + self.mma_qk_tiler = config.mma_qk_tiler + self.mma_qk_rope_tiler = config.mma_qk_rope_tiler + self.mma_pv_tiler = config.mma_pv_tiler + self.rope_dim = config.rope_dim + self.latent_dim = config.latent_dim + self.warps_in_n = config.warps_in_n + self.mma_s_stage = config.mma_s_stage + self.tmem_o_offset = config.tmem_o_offset + self.iterations_qk_latent = config.iterations_qk_latent + self.iterations_qk_rope = config.iterations_qk_rope + self.iterations_pv_k = config.iterations_pv_k + self.iterations_pv_n = config.iterations_pv_n + self.enable_pdl = config.enable_pdl + self.is_var_split_kv = config.is_var_split_kv + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + # ------------------------------------------------------------------ + # GEMM helpers — stateless w.r.t. caller + # + # Each helper takes an explicit ``accumulate`` bool that controls + # whether the first k-block overwrites (False) or accumulates (True). + # Subsequent k-blocks always accumulate. The caller computes the + # flag from its own loop position; the helper never communicates + # state back via TiledMma mutations (they would be invisible to the + # caller due to SSA pass-by-value at the @cute.jit boundary). + # + # Inner k-block loops use ``cutlass.range()`` (dynamic scf.for), + # NOT ``cutlass.range_constexpr()`` (compile-time unroll). 
+ # range_constexpr unrolls tiled_mma.set() calls into the enclosing + # scope, producing SSA values that leak across dynamic while-loop + # yields. range() keeps the .set() inside an scf.for scope where + # SSA carry-through is handled correctly. + # ------------------------------------------------------------------ + + @cute.jit + def _gemm_qk_latent_one_stage( + self, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + s_stage_index: cutlass.Int32, + kv_stage_index: cutlass.Int32, + q_stage: int, + accumulate: bool, + ): + """Compute one QK-latent stage: inner k-block GEMM loop.""" + tStS = qk_params.tStS_staged[None, None, None, s_stage_index] + for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])): + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, (q_stage, 0)], + qk_params.tSrKC[None, None, k_block, (q_stage, kv_stage_index)], + tStS, + ) + + @cute.jit + def _gemm_qk_rope_one_stage( + self, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + s_stage_index: cutlass.Int32, + kv_stage_index: cutlass.Int32, + q_stage: int, + accumulate: bool, + ): + """Compute one QK-rope stage using separate tSrKC_rope fragments.""" + tStS = qk_params.tStS_staged[None, None, None, s_stage_index] + for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]): + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC_rope[None, None, k_block, kv_stage_index], + tStS, + ) + + @cute.jit + def _gemm_pv_one_stage( + self, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + p_stage_index: cutlass.Int32, + vc_stage_index: cutlass.Int32, + p_stage: int, + acc_stage: int, + accumulate: bool, + ): + """Compute one PV stage: inner k-block GEMM loop.""" + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for 
k_block in cutlass.range(pv_params.tOrP.shape[2]): + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, k_block != 0 or accumulate) + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_stage_index), + ], + pv_params.tOrVC[ + None, None, k_block, ((acc_stage, p_stage), vc_stage_index) + ], + tOtO, + ) + + # ------------------------------------------------------------------ + # Orchestration loop + # ------------------------------------------------------------------ + + @cute.jit + def run( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer: PipelineConsumer, + load_k_consumer: PipelineConsumer, + load_v_consumer: PipelineConsumer, + mma_s_producer: PipelineProducer, + p_mma_consumer: PipelineConsumer, + mma_o_producer: PipelineProducer, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + tile_sched_params: MLAStaticTileSchedulerParams, + sQ: cute.Tensor, + sQ_rope: cute.Tensor, + sKC: cute.Tensor, + sKC_rope: cute.Tensor, + sP: cute.Tensor, + sVC: cute.Tensor, + tmem_ptr: cute.Tensor, + is_leader_cta: cutlass.Boolean, + L: cutlass.Int32, + ): + """MMA warp orchestration for FP8 MLA decode.""" + tSrQ = tiled_mma_qk.make_fragment_A(sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(sKC) + tSrKC_rope = tiled_mma_qk.make_fragment_B(sKC_rope) + tOrP = tiled_mma_pv.make_fragment_A(sP) + tOrVC = tiled_mma_pv.make_fragment_B(sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + L // 
self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Track PV accumulate state as a plain bool instead of + # tiled_mma_pv.set/get to avoid carrying TiledMma SSA values + # across the dynamic while-loop yield. + pv_accumulated = False + + while work_tile.is_valid_tile: + pv_accumulated = False + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self._get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_qk_params = SimpleNamespace( + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + sKC_rope=sKC_rope, + tSrQ=tSrQ, + tSrQ_rope=tSrQ_rope, + tSrKC=tSrKC, + tSrKC_rope=tSrKC_rope, + tStS_staged=tStS_staged, + ) + mma_pv_params = SimpleNamespace( + sP=sP, + sVC=sVC, + tOrP=tOrP, + tOrVC=tOrVC, + tOtO_staged=tOtO_staged, + ) + + if is_leader_cta: + # === First QK tile (with Q wait) === + q_handle = load_q_consumer.wait_and_advance() + + s_handle = mma_s_producer.acquire_and_advance() + kv_handle = load_k_consumer.wait_and_advance() + for q_stage in range(self.iterations_qk_latent): + self._gemm_qk_latent_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=(q_stage > 0), + ) + for q_stage in range(self.iterations_qk_rope): + self._gemm_qk_rope_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=True, + ) + kv_handle.release() + s_handle.commit() + k_tile_count -= 1 + + # === Interleaved QK + PV for remaining tiles === + while k_tile_count > 0: + # QK + s_handle = mma_s_producer.acquire_and_advance() + kv_handle = load_k_consumer.wait_and_advance() + for q_stage in range(self.iterations_qk_latent): + 
self._gemm_qk_latent_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=(q_stage > 0), + ) + for q_stage in range(self.iterations_qk_rope): + self._gemm_qk_rope_one_stage( + mma_qk_params, + tiled_mma_qk, + s_handle.index, + kv_handle.index, + q_stage, + accumulate=True, + ) + kv_handle.release() + s_handle.commit() + + # PV + p_handle = p_mma_consumer.wait_and_advance() + v_handle = load_v_consumer.wait_and_advance() + for acc_stage in range(self.iterations_pv_n): + o_handle = mma_o_producer.acquire_and_advance() + for p_stage in range(self.iterations_pv_k): + self._gemm_pv_one_stage( + mma_pv_params, + tiled_mma_pv, + p_handle.index, + v_handle.index, + p_stage, + acc_stage, + accumulate=(pv_accumulated or p_stage > 0), + ) + o_handle.commit() + v_handle.release() + p_handle.release() + pv_accumulated = True + + k_tile_count -= 1 + + q_handle.release() + + # === Final PV tile === + p_handle = p_mma_consumer.wait_and_advance() + v_handle = load_v_consumer.wait_and_advance() + for acc_stage in range(self.iterations_pv_n): + o_handle = mma_o_producer.acquire_and_advance() + for p_stage in range(self.iterations_pv_k): + self._gemm_pv_one_stage( + mma_pv_params, + tiled_mma_pv, + p_handle.index, + v_handle.index, + p_stage, + acc_stage, + accumulate=(pv_accumulated or p_stage > 0), + ) + o_handle.commit() + v_handle.release() + p_handle.release() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_producer.tail() + mma_o_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mla_pt_loader.py b/flashinfer/cute_dsl/attention/roles/mla_pt_loader.py new file mode 100644 index 0000000000..641150dcbf --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mla_pt_loader.py @@ -0,0 +1,109 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +"""MLAPageTableLoaderRole — page-table producer warp for MLA decode. + +Owns the tile-scheduler loop for the page-table warp, issuing async copies +of page indices from global memory into SMEM for each k-tile. +""" + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.cpasync as cpasync +from cutlass.pipeline import PipelineProducer + +from ..mla_config import MLAConfig +from ..scheduler.mla_persistent import ( + create_mla_static_tile_scheduler, + MLAStaticTileSchedulerParams, + ceil_div, +) + + +class MLAPageTableLoaderRole: + def __init__(self, config: MLAConfig): + self.mma_qk_tiler = config.mma_qk_tiler + self.page_size = config.page_size + self.is_var_split_kv = config.is_var_split_kv + self.load_pt_stage = config.load_pt_stage + self.threads_per_warp = 32 + + @cute.jit + def _get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def run( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + load_pt_producer: PipelineProducer, + mPT: cute.Tensor, + sPT: cute.Tensor, + tile_sched_params: MLAStaticTileSchedulerParams, + ): + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx % self.threads_per_warp + page_per_tile = self.mma_qk_tiler[1] // 
self.page_size + elem_per_thread = ceil_div(page_per_tile, self.threads_per_warp) + + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + cutlass.Int32, + num_bits_per_copy=cutlass.Int32.width, + ) + + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, _ = self._get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + + mPT_seq = mPT[None, blk_coord[2]] + mPT_for_copy = cute.flat_divide(mPT_seq, (1,)) + sPT_for_copy = cute.flat_divide(sPT, (1,)) + + while k_tile_count > 0: + handle = load_pt_producer.acquire_and_advance() + + for i in range(elem_per_thread): + idx = i * self.threads_per_warp + tidx + if cute.elem_less( + k_index * page_per_tile + idx, mPT_seq.shape[0] + ) and cute.elem_less(idx, page_per_tile): + cute.copy( + atom_async_copy, + mPT_for_copy[None, k_index * page_per_tile + idx], + sPT_for_copy[None, idx, handle.index], + ) + else: + sPT_for_copy[None, idx, handle.index].fill(0) + + handle.commit() + k_index += 1 + k_tile_count -= 1 + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_pt_producer.tail() diff --git a/flashinfer/cute_dsl/attention/roles/mma.py b/flashinfer/cute_dsl/attention/roles/mma.py new file mode 100644 index 0000000000..9da4f1dfaa --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/mma.py @@ -0,0 +1,286 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""MmaOps — QK/PV GEMM primitives and orchestration for attention kernels. 
+ +Reusable primitives (pipeline-unaware, for composing new kernel variants): +- gemm_qk(): single QK GEMM with kphase unrolling +- gemm_pv(): single PV GEMM with configurable accumulation +- alloc_tmem(): TMEM allocation with barrier sync +- dealloc_tmem(): TMEM deallocation after barrier wait + +Orchestration (prefill-specific, uses raw CuTe ops for JIT compatibility): +- run(): double-buffered interleaved QK/PV with S0/S1 and O0/O1 +""" + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.cute.typing import Int32, Float32 + +from cutlass.pipeline import PipelineProducer, PipelineConsumer + +from ..config import AttentionConfig +from ..fusion.mask import get_trip_count +from ..scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, +) + + +class MmaRole: + """MMA warp for attention kernels — computes QK and PV GEMMs. + + Created from AttentionConfig in the kernel's __init__. + """ + + def __init__( + self, + config: AttentionConfig, + tmem_alloc_cols, + tmem_alloc_sync_bar_id, + threads_per_warp, + has_logits_transform: bool = False, + ): + self.cta_tiler = config.cta_tiler + self.mask_type = config.mask_type + self.window_left = config.window_left + self.tmem_alloc_cols = tmem_alloc_cols + self.tmem_alloc_sync_bar_id = tmem_alloc_sync_bar_id + self.threads_per_warp = threads_per_warp + self.has_logits_transform = has_logits_transform + + # ========================================================================= + # Reusable primitives — no pipeline awareness, for composing new kernels + # + # All primitives below are SAFE to call from run() and other @cute.jit + # methods. They are void (no return values) and only use compile-time + # indexing (unrolled kphase loops), avoiding the CuTe DSL JIT + # limitations with runtime tensor views and return values. 
+ # ========================================================================= + + @cute.jit + def gemm_qk( + self, + tiled_mma: cute.TiledMma, + tStS: cute.Tensor, + tSrQ_slice: cute.Tensor, + tSrK_slice: cute.Tensor, + ): + """Single QK GEMM: S += Q * K^T with kphase unrolling. + + Always starts a fresh accumulation (first kphase non-accumulate). + """ + num_kphases = cute.size(tSrQ_slice, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + coord = (None, None, kphase_idx) + tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm(tiled_mma, tStS, tSrQ_slice[coord], tSrK_slice[coord], tStS) + + @cute.jit + def gemm_pv( + self, + tiled_mma: cute.TiledMma, + tOtO: cute.Tensor, + tOrP: cute.Tensor, + tOrV_slice: cute.Tensor, + accumulate: bool, + ): + """Single PV GEMM: O += P * V with kphase unrolling. + + Args: + accumulate: If False, first kphase starts fresh (non-accumulate), + rest accumulate. If True, all kphases accumulate. + """ + num_kphases = cute.size(tOrP, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + coord = (None, None, kphase_idx) + tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0 or accumulate) + cute.gemm(tiled_mma, tOtO, tOrP[coord], tOrV_slice[coord], tOtO) + + @cute.jit + def alloc_tmem(self, storage: cute.Tensor): + """Allocate TMEM buffer and synchronize.""" + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.barrier( + barrier_id=self.tmem_alloc_sync_bar_id, + number_of_threads=self.threads_per_warp, + ) + + @cute.jit + def dealloc_tmem(self, storage: cute.Tensor, tmem_dealloc_mbar_ptr: Int32): + """Wait for all warps, then deallocate TMEM buffer.""" + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + 
ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # ========================================================================= + # Prefill orchestration — uses primitives for GEMMs and TMEM lifecycle + # ========================================================================= + + @cute.jit + def run( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tStS0: cute.Tensor, + tStS1: cute.Tensor, + tOtO0: cute.Tensor, + tOtO1: cute.Tensor, + tSrQ: cute.Tensor, + tSrK: cute.Tensor, + tOrP0: cute.Tensor, + tOrP1: cute.Tensor, + tOrV: cute.Tensor, + seqlen_q_global: Int32, + seqlen_k_global: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + load_q_consumer: PipelineConsumer, + load_kv_consumer: PipelineConsumer, + mma_s0_producer: PipelineProducer, + mma_s1_producer: PipelineProducer, + mma_corr_producer: PipelineProducer | None, + tile_sched_params: FmhaStaticTileSchedulerParams, + storage: cute.Tensor, + tmem_dealloc_mbar_ptr: Int32, + ): + """MMA warp orchestration loop (prefill-specific). + + Double-buffered interleaved QK/PV GEMMs with pipeline synchronization. + + For has_logits_transform variants, mma_corr_producer is None. PV GEMM + results piggyback on subsequent QK tcgen05.commit() calls to mma_s0/s1, + which makes all prior TMEM writes visible to softmax warps. 
+ """ + # Alloc tmem buffer + self.alloc_tmem(storage) + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_q_ = seqlen_q_global + continue_cond = False + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q_ = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q_, + ) + ) + + if not continue_cond: + seqlen_k = seqlen_k_global + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + # GEMM_QK00 (Q0 * K0 -> S0) + q0_handle_consumer = load_q_consumer.wait_and_advance() + tSrQ0 = tSrQ[None, None, None, q0_handle_consumer.index] + k_handle_consumer = load_kv_consumer.wait_and_advance() + tSrK0 = tSrK[None, None, None, k_handle_consumer.index] + s0_handle_producer = mma_s0_producer.acquire_and_advance() + self.gemm_qk(qk_tiled_mma, tStS0, tSrQ0, tSrK0) + s0_handle_producer.commit() + + # GEMM_QK10 (Q1 * K0 -> S1) + q1_handle_consumer = load_q_consumer.wait_and_advance() + tSrQ1 = tSrQ[None, None, None, q1_handle_consumer.index] + s1_handle_producer = mma_s1_producer.acquire_and_advance() + self.gemm_qk(qk_tiled_mma, tStS1, tSrQ1, tSrK0) + s1_handle_producer.commit() + k_handle_consumer.release() + + # GEMM_PV00 (P0 * V0 -> O0_partial) + v_handle_consumer = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle_consumer.index] + if cutlass.const_expr(not self.has_logits_transform): + o0_handle_producer = mma_corr_producer.acquire_and_advance() + s0_handle_producer = mma_s0_producer.acquire_and_advance() + self.gemm_pv(pv_tiled_mma, tOtO0, tOrP0, 
tOrVi, False) + if cutlass.const_expr(not self.has_logits_transform): + o0_handle_producer.commit() + + seqlen_kv_loop_steps = ( + get_trip_count( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k, + seqlen_q_, + ) + - 1 + ) + + pv_whether_acc = False + for _i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # GEMM_QK0i + k_handle_consumer = load_kv_consumer.wait_and_advance() + tSrKi = tSrK[None, None, None, k_handle_consumer.index] + self.gemm_qk(qk_tiled_mma, tStS0, tSrQ0, tSrKi) + s0_handle_producer.commit() + + # GEMM_PV1(i-1) + if cutlass.const_expr(not self.has_logits_transform): + o1_handle_producer = mma_corr_producer.acquire_and_advance() + s1_handle_producer = mma_s1_producer.acquire_and_advance() + self.gemm_pv(pv_tiled_mma, tOtO1, tOrP1, tOrVi, pv_whether_acc) + pv_whether_acc = True + if cutlass.const_expr(not self.has_logits_transform): + o1_handle_producer.commit() + v_handle_consumer.release() + + # GEMM_QK1i + self.gemm_qk(qk_tiled_mma, tStS1, tSrQ1, tSrKi) + s1_handle_producer.commit() + k_handle_consumer.release() + + # GEMM_PV0i + v_handle_consumer = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle_consumer.index] + if cutlass.const_expr(not self.has_logits_transform): + o0_handle_producer = mma_corr_producer.acquire_and_advance() + s0_handle_producer = mma_s0_producer.acquire_and_advance() + self.gemm_pv(pv_tiled_mma, tOtO0, tOrP0, tOrVi, True) + if cutlass.const_expr(not self.has_logits_transform): + o0_handle_producer.commit() + + # release Q0 & Q1 + q0_handle_consumer.release() + q1_handle_consumer.release() + + # GEMM_PV1(end) + if cutlass.const_expr(not self.has_logits_transform): + o1_handle = mma_corr_producer.acquire_and_advance() + s1_handle_producer = mma_s1_producer.acquire_and_advance() + self.gemm_pv(pv_tiled_mma, tOtO1, tOrP1, tOrVi, pv_whether_acc) + if cutlass.const_expr(not self.has_logits_transform): + o1_handle.commit() + v_handle_consumer.release() + 
+ s0_handle_producer.commit() + s1_handle_producer.commit() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # dealloc tmem buffer + self.dealloc_tmem(storage, tmem_dealloc_mbar_ptr) diff --git a/flashinfer/cute_dsl/attention/roles/softmax.py b/flashinfer/cute_dsl/attention/roles/softmax.py new file mode 100644 index 0000000000..5dd20ec4aa --- /dev/null +++ b/flashinfer/cute_dsl/attention/roles/softmax.py @@ -0,0 +1,724 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""SoftmaxRole — online softmax computation for attention kernels. + +Handles: +- Row-max tracking and exp2 computation +- Row-sum accumulation +- KV-dimension masking (causal, sliding window, residual) +- Logits transform hooks via AttentionFusion +- Pipeline synchronization between MMA and correction stages + +Extracted from BlackwellFusedMultiHeadAttentionForward.softmax / softmax_step. +""" + +from typing import Optional, Tuple, Type + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.typing import Int32, Float32 + +from cutlass.pipeline import PipelineProducer, PipelineConsumer + +from .softmax_math import exp2_scale +from ..config import AttentionConfig, AttentionFusion +from ..tmem_layout import TmemLayout +from ..fusion.mask import ( + apply_mask, + get_unmasked_trip_count, + get_masked_trip_count, + get_kv_start_block_idx, +) +from ..scheduler.persistent import ( + FmhaStaticTileScheduler, + FmhaStaticTileSchedulerParams, + create_fmha_static_tile_scheduler, +) + + +class SoftmaxRole: + """Online softmax warp group for attention kernels. + + Created from AttentionConfig + AttentionFusion + TmemLayout in the kernel's __init__. + Tensor-type attributes (q_dtype, o_dtype) are set later via set_dtypes() in __call__. 
+ """ + + def __init__( + self, + config: AttentionConfig, + fusion: AttentionFusion, + tmem: TmemLayout, + softmax0_warp_ids, + softmax1_warp_ids, + threads_per_warp, + ): + # From config + self.qk_acc_dtype = config.qk_acc_dtype + self.qk_mma_tiler = config.qk_mma_tiler + self.pv_mma_tiler = config.pv_mma_tiler + self.pv_acc_dtype = config.pv_acc_dtype + self.cta_tiler = config.cta_tiler + self.mask_type = config.mask_type + self.window_left = config.window_left + self.num_repeat_kv_heads = config.num_repeat_kv_heads + + # From TMEM layout + self.tmem_vec0_offset = tmem.vec0_offset + self.tmem_vec1_offset = tmem.vec1_offset + self.tmem_p0_offset = tmem.p0_offset + self.tmem_p1_offset = tmem.p1_offset + + # From fusion variant + self.variant = fusion.variant + self.has_score_mod = fusion.variant.has_score_mod + self.has_logits_transform = fusion.variant.has_logits_transform + self.has_vectorized_logits_transform = ( + fusion.variant.has_vectorized_logits_transform + ) + self.has_statistics_update = fusion.variant.has_statistics_update + self.has_output_transform = fusion.variant.has_output_transform + self.has_params = fusion.has_params + + # Warp config + self.softmax0_warp_ids = softmax0_warp_ids + self.softmax1_warp_ids = softmax1_warp_ids + self.threads_per_warp = threads_per_warp + + # Set later via set_dtypes() / set_call_attrs() + self.q_dtype: Optional[Type[cutlass.Numeric]] = None + self.o_dtype: Optional[Type[cutlass.Numeric]] = None + self.o_layout = None + self.epi_tile = None + + def set_dtypes(self, q_dtype, o_dtype): + """Set tensor-type attributes known only at call time.""" + self.q_dtype = q_dtype + self.o_dtype = o_dtype + + def set_call_attrs(self, o_layout, epi_tile): + """Set epilog attributes for transform path (replaces CorrectionRole).""" + self.o_layout = o_layout + self.epi_tile = epi_tile + + @cute.jit + def step( + self, + stage: int, + need_apply_mask: bool, + iter_args: tuple, + value_args: tuple, + pipeline_args: tuple, + atom_args: 
tuple, + tensor_args: tuple, + params: cute.Tensor | None, + ) -> Tuple[ + Float32, + Float32, + PipelineProducer.ImmutableResourceHandle, + PipelineConsumer, + PipelineProducer, + PipelineConsumer, + PipelineProducer, + ]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + """ + assert self.q_dtype is not None + assert self.o_dtype is not None + cS, row_max, row_sum, vec_i_handle, batch_coord, head_coord = iter_args + qo_head_idx = head_coord + kv_head_idx = qo_head_idx // self.num_repeat_kv_heads + kv_tile_idx = cS[0][1] // self.qk_mma_tiler[1] + + seqlen_q, seqlen_k, scale_softmax_log2 = value_args + ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = pipeline_args + ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) = atom_args + ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) = tensor_args + + if cutlass.const_expr(self.has_statistics_update): + self.variant.params = params + row_max, row_sum = self.variant.update_statistics( + kv_tile_idx, + qo_head_idx, + row_max, + row_sum, + scale_softmax_log2, + ) + + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + tScS = 
qk_thr_mma.partition_C(cS) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScS_P_layout = cute.composition( + tScS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) + + # Wait for Si + si_handle = mma_si_consumer.wait_and_advance() + tTMEM_LOADrS = cute.make_fragment(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + if need_apply_mask: + apply_mask( + self.mask_type, + self.window_left, + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_k, + seqlen_k - seqlen_q, + ) + + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile)) + + if cutlass.const_expr(self.has_score_mod): + if cutlass.const_expr(self.has_params): + self.variant.params = params + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + qo_idx, kv_idx = tTMEM_LOADcS_frg[k, j] + tTMEM_LOADrS_frg[k, j] = self.variant.score_mod( + tTMEM_LOADrS_frg[k, j], + batch_coord, + qo_idx, + kv_idx, + qo_head_idx, + kv_head_idx, + ) + # Re-apply masking: score_mod may map -inf to a finite value + # (e.g. SoftCapping: cap*tanh(-inf/cap) = -cap). Restore -inf + # for out-of-bounds positions so they get zero softmax weight. 
+ if need_apply_mask: + apply_mask( + self.mask_type, + self.window_left, + tTMEM_LOADrS, + tTMEM_LOADcS, + seqlen_k, + seqlen_k - seqlen_q, + ) + + if cutlass.const_expr(not self.has_logits_transform): + old_row_max = row_max + row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + row_max_safe = row_max + + if row_max == -cutlass.Float32.inf: + row_max_safe = 0.0 + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + + tTMEM_STORE_VECrS[0] = old_row_max + tTMEM_STORE_VECrS[1] = row_max_safe + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + vec_i_handle.commit() + + tTMEM_STORErS_x4 = cute.make_fragment(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout, + ) + + scale = scale_softmax_log2 + + if cutlass.const_expr(not self.has_logits_transform): + if cutlass.const_expr(stage == 0): + sequence_producer_handle = s0_s1_sequence_producer.acquire_and_advance() + else: + sequence_consumer_handle = s0_s1_sequence_consumer.wait_and_advance() + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile) + ) + ### the softmax computation part ### e^(xi*scale - mi*scale) + if cutlass.const_expr(self.has_vectorized_logits_transform): + if cutlass.const_expr(self.has_params and not self.has_score_mod): + self.variant.params = params + for j in range(frg_cnt): + self.variant.transform_logits_vec(tTMEM_LOADrS_frg[None, j]) + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + elif cutlass.const_expr(self.has_logits_transform): + if cutlass.const_expr(self.has_params and not self.has_score_mod): + self.variant.params = params + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = 
self.variant.transform_logits( + tTMEM_LOADrS_frg[k, j], + ) + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + else: + for j in range(frg_cnt): + exp2_scale(tTMEM_LOADrS_frg[None, j], scale, row_max_safe) + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + if cutlass.const_expr(not self.has_logits_transform): + if cutlass.const_expr(stage == 0): + sequence_producer_handle.commit() + else: + sequence_consumer_handle.release() + cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + # Notify tensor core warp that softmax(S->P) is ready + si_handle.release() + + if cutlass.const_expr(not self.has_logits_transform): + vec_i_handle = si_corr_producer.acquire_and_advance() + ### di = di-1 * (e^(mi-1 - mi) * scale) + sum e^(xi*scale - mi*scale) + acc_scale_ = scale * (old_row_max - row_max_safe) + # * 0.5 compensates for initializing both packed elements with row_sum below + acc_scale = cute.arch.exp2(acc_scale_) * 0.5 + row_sum *= acc_scale + # 4-way unrolled reduction for ILP: 4 independent accumulator chains + # run in parallel, then tree-reduce. local_row_sum_0 is seeded with + # (row_sum, row_sum) so the old running sum folds into the reduction. 
+ local_row_sum_0 = (row_sum, row_sum) + local_row_sum_1 = (0.0, 0.0) + local_row_sum_2 = (0.0, 0.0) + local_row_sum_3 = (0.0, 0.0) + + reduction_unroll = 4 + frg_tile_r = cute.size(tTMEM_LOADrS) // reduction_unroll + tTMEM_LOADrS_frg_r = cute.logical_divide( + tTMEM_LOADrS, cute.make_layout(frg_tile_r) + ) + + for j in cutlass.range_constexpr( + 0, cute.size(tTMEM_LOADrS_frg_r, mode=[0]), 2 + ): + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, + (tTMEM_LOADrS_frg_r[j, 0], tTMEM_LOADrS_frg_r[j + 1, 0]), + ) + local_row_sum_1 = cute.arch.add_packed_f32x2( + local_row_sum_1, + (tTMEM_LOADrS_frg_r[j, 1], tTMEM_LOADrS_frg_r[j + 1, 1]), + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, + (tTMEM_LOADrS_frg_r[j, 2], tTMEM_LOADrS_frg_r[j + 1, 2]), + ) + local_row_sum_3 = cute.arch.add_packed_f32x2( + local_row_sum_3, + (tTMEM_LOADrS_frg_r[j, 3], tTMEM_LOADrS_frg_r[j + 1, 3]), + ) + + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, local_row_sum_1 + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, local_row_sum_3 + ) + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, local_row_sum_2 + ) + row_sum = local_row_sum_0[0] + local_row_sum_0[1] + + return ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + + @cute.jit + def softmax_epilog( + self, + stage: int, + pv_thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + sO: cute.Tensor, + ): + """Final O scaling and SMEM write (transform path only). + + Mirrors CorrectionRole.epilog() but runs inside the softmax warpgroup + when there is no correction warp (has_logits_transform=True). 
+ """ + assert self.o_dtype is not None + assert self.epi_tile is not None + pv_tiled_mma_shape = (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + cO_custom = cute.make_identity_tensor(pv_tiled_mma_shape) + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = pv_thr_mma.partition_C(sO) + tOcO = pv_thr_mma.partition_C(cO) + tOcO_custom = pv_thr_mma.partition_C(cO_custom) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + _ = cute.logical_divide(tOcO_custom, cute.make_layout((128, corr_tile_size))) + + tidx, _, _ = cute.arch.thread_idx() + num_warps = ( + len(self.softmax0_warp_ids) if stage == 0 else len(self.softmax1_warp_ids) + ) + thread_idx = tidx % (self.threads_per_warp * num_warps) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils.get_tmem_load_op( + self.pv_mma_tiler, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_copy_atom, tOtO_i[(None, None), 0] + ) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADoO = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] + tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] + tTMrO = cute.make_fragment( + tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype + ) + cute.copy(tiled_tmem_load, 
tTMEM_LOADtO_i, tTMrO) + for j in range(0, cute.size(tTMrO), 2): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_fragment(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # For both softmax0 and softmax1 warp group + @cute.jit + def run( + self, + stage: int, + seqlen_q: Int32, + seqlen_k: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + qk_thr_mma: cute.core.ThrMma, + pv_thr_mma: cute.core.ThrMma | None, + tStS: cute.Tensor, + tStSi: cute.Tensor, + tOtO: cute.Tensor | None, + sO: cute.Tensor | None, + params: cute.Tensor | None, + mma_si_consumer: PipelineConsumer, + si_corr_producer: PipelineProducer | None, + si_epi_producer: PipelineProducer | None, + s0_s1_sequence_consumer: PipelineConsumer, + s0_s1_sequence_producer: PipelineProducer, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + Handles softmax for either stage 0 or stage 1 (first or second half of + the Q tile). Loops over KV tiles, calling step() for each, coordinating + pipeline synchronization. 
+ """ + assert self.o_dtype is not None + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % ( + self.threads_per_warp + * ( + len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width + tScS = qk_thr_mma.partition_C(cS_base) + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + tmem_vec_offset = self.tmem_vec0_offset if stage == 0 else self.tmem_vec1_offset + tStS_vec = cute.make_tensor(tStS.iterator + tmem_vec_offset, tStS_vec_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tStS_P_layout = cute.composition( + tStS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tmem_p_offset = self.tmem_p0_offset if stage == 0 else self.tmem_p1_offset + tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi) + thread_idx = tidx % ( + self.threads_per_warp + * ( + len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_tmem_load.partition_S(tStSi) + tmem_store_vec_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec) + thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(thread_idx) + tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) 
+ tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_q_ = seqlen_q + seqlen_k_ = seqlen_k + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q_ = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q_, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + row_max = -Float32.inf + row_sum = 0.0 + value_args = (seqlen_q_, seqlen_k_, scale_softmax_log2) + atom_args = ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) + tensor_args = ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) + + kv_start_offset = ( + get_kv_start_block_idx( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k_, + seqlen_q_, + ) + * self.qk_mma_tiler[1] + ) + logical_offset = ( + curr_block_coord[0] * self.cta_tiler[0] + + stage * self.qk_mma_tiler[0], + kv_start_offset, + ) + cS = cute.domain_offset(logical_offset, cS_base) + vec_i_handle = None + if cutlass.const_expr(not self.has_logits_transform): + vec_i_handle = si_corr_producer.acquire_and_advance() + unmask_count = get_unmasked_trip_count( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k_, + seqlen_q_, + ) + batch_coord = 
curr_block_coord[2][1] + head_coord = curr_block_coord[2][0] + for i in cutlass.range(0, unmask_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = ( + cS_iter, + row_max, + row_sum, + vec_i_handle, + batch_coord, + head_coord, + ) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.step( + stage, + False, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + params, + ) + + mask_count = get_masked_trip_count( + self.mask_type, + self.window_left, + curr_block_coord, + self.cta_tiler, + seqlen_k_, + seqlen_q_, + ) + + for i in cutlass.range( + unmask_count, unmask_count + mask_count, 1, unroll=1 + ): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = ( + cS_iter, + row_max, + row_sum, + vec_i_handle, + batch_coord, + head_coord, + ) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.step( + stage, + True, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + params, + ) + si_handle = mma_si_consumer.wait_and_advance() + if cutlass.const_expr(not self.has_logits_transform): + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = row_sum + tTMEM_STORE_VECrS[1] = row_max + cute.copy( + tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS + ) + cute.arch.fence_view_async_tmem_store() + vec_i_handle.commit() + si_corr_producer.acquire() + si_handle.release() + else: + epi_handle = si_epi_producer.acquire_and_advance() + self.softmax_epilog( + stage, + pv_thr_mma, + tOtO, + 
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Shared softmax math primitives for attention kernels.

Used by FMHA (SoftmaxRole) to avoid
duplicating the core exp2-scale and packed row-sum reduction logic.
"""

import cutlass
import cutlass.cute as cute


@cute.jit
def exp2_scale(scores, scale_log2, row_max):
    """Apply e^((x - row_max) * scale_log2) to all paired elements in-place."""
    # (x - row_max) * scale_log2 == x * scale_log2 + (0 - row_max) * scale_log2,
    # so the row-max subtraction folds into the FMA addend computed once here.
    minus_max_scale = (0.0 - row_max) * scale_log2
    # Elements are processed two at a time with packed f32x2 ops;
    # assumes cute.size(scores) is even -- TODO confirm at call sites.
    for i in cutlass.range_constexpr(0, cute.size(scores), 2):
        scores[i], scores[i + 1] = cute.arch.fma_packed_f32x2(
            (scores[i], scores[i + 1]),
            (scale_log2, scale_log2),
            (minus_max_scale, minus_max_scale),
        )
        # exp2 of the scaled-and-shifted score (base-2 exponential; the
        # caller supplies a log2-domain scale).
        scores[i] = cute.arch.exp2(scores[i])
        scores[i + 1] = cute.arch.exp2(scores[i + 1])


@cute.jit
def packed_row_sum(scores) -> tuple:
    """Reduce all elements of a 1D register fragment via packed f32x2 adds.

    Returns (sum_even_indices, sum_odd_indices) tuple; caller typically
    does ``vec[0] + vec[1]`` to get the scalar total.
    """
    # Two independent accumulator lanes keep the packed add's two halves
    # separate; they are only combined by the caller.
    row_sum_vec = (0.0, 0.0)
    for i in cutlass.range_constexpr(0, cute.size(scores), 2):
        row_sum_vec = cute.arch.add_packed_f32x2(
            row_sum_vec, (scores[i], scores[i + 1])
        )
    return row_sum_vec
Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. +Re-exports the tile scheduler classes and factory functions for use by the +modular MLA decode kernel and its roles. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Also provides host-side utility functions for split-KV computation. +""" +from __future__ import annotations import cutlass import cutlass.cute as cute +LOG2_E = 1.4426950408889634074 +MAX_SPLITS = 256 + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + class MLAStaticTileSchedulerParams: def __init__( self, @@ -46,18 +38,6 @@ def __init__( loc=None, ip=None, ): - """The static tile scheduler parameters prepared for MLA static tile scheduler. 
- - :param is_persistent: Whether to use persistent kernel mode - :type is_persistent: bool - :param problem_shape_b: The shape of the problem - :type problem_shape_b: cute.Int32 - :param problem_shape_s: The shape of the problem in sequence length Q dimension - :type problem_shape_s: cute.Int32 - :param cluster_shape_mnk: The shape of the cluster - :type cluster_shape_mnk: cute.Shape - :param split_kv: The scalar factor for split KV - """ self.is_persistent = is_persistent self.problem_shape_b = problem_shape_b self.problem_shape_s = problem_shape_s @@ -166,22 +146,6 @@ def __init__( loc=None, ip=None, ): - """The static tile scheduler for MLA split kv kernel. - Based on `is_persistent`, it provides 2 modes for use: - - Persistent mode: Launch fixed blocks and reschedule the data blocks. - - Non-persistent mode: Launch dynamic blocks and exit when the current work is done. - - :param params: The static tile scheduler parameters - :type params: MLAStaticTileSchedulerParams - :param current_work_linear_idx: The linear index of the current work - :type current_work_linear_idx: cutlass.Int32 - :param blk_coord: The coordinate of the current work - :type blk_coord: cute.Coord - :param grid_shape: The shape of the grid - :type grid_shape: cute.Shape - :param is_valid: Whether the current work is valid - :type is_valid: bool - """ self.params = params self.blk_coord = blk_coord self.grid_shape = grid_shape @@ -198,7 +162,6 @@ def __init__( ip=ip, ) self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip) - # Used for persistent scheduling self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) else: self.is_valid = is_valid @@ -213,7 +176,6 @@ def get_grid_shape( loc=None, ip=None, ) -> cute.Shape: - # called by host grid_shape = ( params.cluster_shape_mnk[0], params.problem_shape_b * params.problem_shape_s, @@ -295,10 +257,35 @@ def create_mla_static_tile_scheduler( return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) -LOG2_E 
# ---------------------------------------------------------------------------
# Host-side utilities
# ---------------------------------------------------------------------------


def mla_get_split_kv(
    B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int
) -> int:
    """Get split_kv value for MLA kernel (host-side).

    Starts from one split per KV tile, caps by per-batch occupancy, then
    rounds so every wave over the KV tiles is evenly loaded; the result
    is clamped to 32.
    """
    tile_n = mma_qk_tiler_mn[1]
    # Ceiling divisions written as -(-a // b) for positive ints.
    max_splits = -(-K // tile_n)
    # Occupancy heuristic: blocks available per batch entry (S*2 CTAs each).
    blocks_per_batch = max(1, max_active_blocks // B // (S * 2))
    split_heur = min(max_splits, blocks_per_batch)
    k_waves = -(-max_splits // split_heur)
    split_wave_aware = -(-max_splits // k_waves)
    return min(split_wave_aware, 32)


def mla_get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int:
    """Simplified split_kv for MLA (host-side, no K dependency)."""
    per_batch = max_active_blocks // B // (S * 2)
    return min(max(1, per_batch), 32)


def mla_get_workspace_size(
    H: int, S: int, D: int, B: int, split_kv: int, acc_dtype_width: int
) -> int:
    """Get workspace size in bytes for split-KV MLA decode."""
    if split_kv == 1:
        # No partial results to stage when KV is not split.
        return 0
    # One D-wide accumulator plus one extra slot per (batch, head, query,
    # split), stored at acc_dtype_width bits; convert bits -> bytes.
    bits = B * H * S * split_kv * (D + 1) * acc_dtype_width
    return bits // 8
+""" + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass.cute.typing import Int32, Boolean + + +class FmhaStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.is_persistent, self.problem_shape_mbh]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.is_persistent, self.problem_shape_mbh], + self._values_pos, + strict=True, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + +def create_fmha_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_mbh: cute.Shape, +) -> FmhaStaticTileSchedulerParams: + return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) + + +class FmhaStaticTileScheduler: + def __init__( + self, + params: FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout( + params.problem_shape_mbh, loc=loc, ip=ip + ) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + 
params: FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + if params.is_persistent: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + cutlass.min( + sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) + ), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], blk_coord[2]), + ) + + return utils.WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self._params) + values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self._blk_coord)) + values.extend(cutlass.extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + 
new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) + return FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_fmha_static_tile_scheduler( + params: FmhaStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> FmhaStaticTileScheduler: + return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) diff --git a/flashinfer/cute_dsl/attention/tmem_layout.py b/flashinfer/cute_dsl/attention/tmem_layout.py new file mode 100644 index 0000000000..eb2460f91b --- /dev/null +++ b/flashinfer/cute_dsl/attention/tmem_layout.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""TmemLayout — computed TMEM allocation plan. + +Derives TMEM offsets from the AttentionConfig's tile shape instead of using +hardcoded magic numbers. 
@dataclass(frozen=True)
class TmemLayout:
    """TMEM offset map for attention kernel score/output/P buffers."""

    s0_offset: int
    s1_offset: int
    o0_offset: int
    o1_offset: int
    p0_offset: int
    p1_offset: int
    vec0_offset: int
    vec1_offset: int
    alloc_cols: int

    @staticmethod
    def from_config(config: AttentionConfig) -> TmemLayout:
        """Derive offsets from the config's MMA tile M extent.

        S0/S1/O0/O1 occupy consecutive tile_m-wide column ranges; P0/P1
        alias inside the S regions at tile_m//4; the vec buffers share
        the start of S0 and S1.
        """
        m = config.mma_tiler[0]
        quarter = m // 4
        return TmemLayout(
            s0_offset=0,
            s1_offset=m,
            o0_offset=m * 2,
            o1_offset=m * 3,
            p0_offset=quarter,
            p1_offset=m + quarter,
            vec0_offset=0,
            vec1_offset=m,
            # SM100 TMEM provides 512 columns total.
            alloc_cols=512,
        )
@dataclass(frozen=True)
class WarpSchedule:
    """Defines warp role assignment and register budgets for attention kernels.

    Each field maps directly to C++ CUTLASS's KernelSchedule:
    - Warp ID ranges for each role
    - Register allocation per role (controls spill/occupancy tradeoff)
    - Barrier IDs for CTA sync and TMEM allocation
    """

    softmax0_warp_ids: Tuple[int, ...] = (0, 1, 2, 3)
    softmax1_warp_ids: Tuple[int, ...] = (4, 5, 6, 7)
    correction_warp_ids: Tuple[int, ...] = (8, 9, 10, 11)
    mma_warp_id: int = 12
    load_warp_id: int = 13
    epilogue_warp_id: int = 14
    empty_warp_id: int = 15

    num_regs_softmax: int = 192
    num_regs_correction: int = 96
    num_regs_other: int = 32
    num_regs_empty: int = 24

    threads_per_warp: int = 32
    cta_sync_bar_id: int = 0
    tmem_alloc_sync_bar_id: int = 1

    @property
    def softmax1_upper_warp_id(self) -> int:
        """Upper bound warp ID for softmax1 dispatch (exclusive)."""
        # When there are no correction warps the MMA warp is the next role.
        if not self.correction_warp_ids:
            return self.mma_warp_id
        return self.correction_warp_ids[0]

    @property
    def all_warp_ids(self) -> Tuple[int, ...]:
        """All warp IDs in role order: softmax0, softmax1, correction, singles."""
        singles = (
            self.mma_warp_id,
            self.load_warp_id,
            self.epilogue_warp_id,
            self.empty_warp_id,
        )
        return (
            self.softmax0_warp_ids
            + self.softmax1_warp_ids
            + self.correction_warp_ids
            + singles
        )

    @property
    def num_warps(self) -> int:
        return len(self.all_warp_ids)

    @property
    def threads_per_cta(self) -> int:
        return self.num_warps * self.threads_per_warp

    @property
    def num_warps_per_warpgroup(self) -> int:
        return 4

    @property
    def softmax_warpgroup_count(self) -> int:
        softmax_warps = len(self.softmax0_warp_ids) + len(self.softmax1_warp_ids)
        return softmax_warps // self.num_warps_per_warpgroup

    @property
    def tmem_dealloc_arrive_count(self) -> int:
        """Number of threads that must arrive at the TMEM dealloc barrier."""
        arriving_warps = (
            len(self.softmax0_warp_ids)
            + len(self.softmax1_warp_ids)
            + len(self.correction_warp_ids)
        )
        return arriving_warps * self.threads_per_warp


PREFILL_SCHEDULE = WarpSchedule()

# Schedule for has_logits_transform variants (e.g. sigmoid attention):
# correction warps removed, freeing 4 warps and their registers.
# Softmax warps take over the epilog (TMEM→scale→SMEM) after their KV loop.
PREFILL_TRANSFORM_SCHEDULE = WarpSchedule(
    softmax0_warp_ids=(0, 1, 2, 3),
    softmax1_warp_ids=(4, 5, 6, 7),
    correction_warp_ids=(),
    mma_warp_id=8,
    load_warp_id=9,
    epilogue_warp_id=10,
    empty_warp_id=11,
    num_regs_softmax=192,
    num_regs_correction=96,
    num_regs_other=32,
    num_regs_empty=24,
)
+""" + +import functools +from typing import Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from flashinfer.api_logging import flashinfer_api +from flashinfer.utils import device_support_pdl +from flashinfer.cute_dsl.utils import ( + get_max_active_clusters, + get_num_sm, + torch_to_cutlass_dtype, +) + +from ..config import AttentionFusion +from ..fusion.variant import AttentionVariant, StandardAttention +from ..mla_decode import BlackwellMultiLatentAttentionForward +from ..mla_decode_fp8 import BlackwellMultiLatentAttentionForwardFP8 +from ..mla_config import MLAConfig + + +# --------------------------------------------------------------------------- +# Cached helpers (deterministic for the same args ⇒ safe to @functools.cache) +# --------------------------------------------------------------------------- + + +@functools.cache +def _get_split_kv_and_workspace_size( + B: int, + q_len: int, + H: int, + kv_lora_rank: int, + max_active_blocks: int, +) -> Tuple[int, int]: + """Cache split_kv and workspace_size since they are deterministic for the same params.""" + split_kv = BlackwellMultiLatentAttentionForward.get_split_kv_simplified( + B, q_len, max_active_blocks + ) + workspace_size = BlackwellMultiLatentAttentionForward.get_workspace_size( + H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32 + ) + return split_kv, workspace_size + + +@functools.cache +def _check_can_implement( + torch_dtype: torch.dtype, + torch_out_dtype: torch.dtype, + page_size: int, + num_heads: int, + seq_len_q: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, +) -> None: + """Check if the kernel supports the given configuration (cached).""" + mma_qk_tiler_mn = (128, 128) + mma_pv_tiler_mn = (128, 256) + + is_fp8 = torch_dtype == torch.float8_e4m3fn + KernelClass = ( + BlackwellMultiLatentAttentionForwardFP8 + if is_fp8 + else 
BlackwellMultiLatentAttentionForward + ) + cutlass_in_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) + if not KernelClass.can_implement( + 1, # B (runtime, use placeholder) + seq_len_q, + 1, # K (runtime, use placeholder) + num_heads, + kv_lora_rank, + qk_rope_head_dim, + cutlass_in_dtype, + cutlass_out_dtype, + cutlass.Float32, + cutlass.Float32, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + 1, # split_kv (runtime, use 1 to pass the H<128 check) + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise ValueError( + f"cute_dsl_mla_decode: unsupported configuration " + f"(q_len={seq_len_q}, num_heads={num_heads}, page_size={page_size}, " + f"in_dtype={torch_dtype}, out_dtype={torch_out_dtype})" + ) + + +def _make_mla_fake_tensors( + cutlass_dtype, + cutlass_out_dtype, + is_workspace_size_zero: bool, + is_var_split_kv: bool, +): + """Create fake tensors for MLA kernel compilation (shared by all paths).""" + sym_heads = cute.sym_int() + sym_latent = cute.sym_int(divisibility=16) + sym_seq_q = cute.sym_int() + sym_rope = cute.sym_int(divisibility=16) + sym_batch = cute.sym_int() + sym_kv_batch = cute.sym_int() + sym_seq_kv = cute.sym_int() + sym_page_count = cute.sym_int() + sym_workspace_size = cute.sym_int() + + q_latent_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_batch, sym_seq_q, sym_heads, sym_latent), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + q_rope_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_batch, sym_seq_q, sym_heads, sym_rope), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + c_latent_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_kv_batch, sym_seq_kv, sym_latent), + stride=(cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + c_rope_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_kv_batch, sym_seq_kv, sym_rope), + 
def _make_mla_fake_tensors(
    cutlass_dtype,
    cutlass_out_dtype,
    is_workspace_size_zero: bool,
    is_var_split_kv: bool,
):
    """Create fake tensors for MLA kernel compilation (shared by all paths).

    Returns a 10-tuple:
    (q_latent, q_rope, c_latent, c_rope, page_table, o, lse,
     workspace-or-None, cache_seqs, block_split_kvs-or-None).
    The order must match the positional arguments passed to cute.compile.
    """
    # Symbolic extents; latent/rope dims are constrained to multiples of 16
    # for aligned vectorized access.
    sym_heads = cute.sym_int()
    sym_latent = cute.sym_int(divisibility=16)
    sym_seq_q = cute.sym_int()
    sym_rope = cute.sym_int(divisibility=16)
    sym_batch = cute.sym_int()
    sym_kv_batch = cute.sym_int()
    sym_seq_kv = cute.sym_int()
    sym_page_count = cute.sym_int()
    sym_workspace_size = cute.sym_int()

    # Q tensors keep symbolic (non-compact) strides so strided views of a
    # fused QKV buffer are accepted; innermost dim is contiguous.
    q_latent_fake = cute.runtime.make_fake_tensor(
        cutlass_dtype,
        (sym_batch, sym_seq_q, sym_heads, sym_latent),
        stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1),
        assumed_align=16,
    )
    q_rope_fake = cute.runtime.make_fake_tensor(
        cutlass_dtype,
        (sym_batch, sym_seq_q, sym_heads, sym_rope),
        stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1),
        assumed_align=16,
    )
    c_latent_fake = cute.runtime.make_fake_tensor(
        cutlass_dtype,
        (sym_kv_batch, sym_seq_kv, sym_latent),
        stride=(cute.sym_int(), cute.sym_int(), 1),
        assumed_align=16,
    )
    c_rope_fake = cute.runtime.make_fake_tensor(
        cutlass_dtype,
        (sym_kv_batch, sym_seq_kv, sym_rope),
        stride=(cute.sym_int(), cute.sym_int(), 1),
        assumed_align=16,
    )
    page_table_fake = cute.runtime.make_fake_compact_tensor(
        cutlass.Int32,
        (sym_batch, sym_page_count),
        stride_order=(1, 0),
        assumed_align=16,
    )
    # Output and LSE must be compact (row-major per stride_order).
    o_fake = cute.runtime.make_fake_compact_tensor(
        cutlass_out_dtype,
        (sym_batch, sym_seq_q, sym_heads, sym_latent),
        stride_order=(3, 2, 1, 0),
        assumed_align=16,
    )
    lse_fake = cute.runtime.make_fake_compact_tensor(
        cutlass.Float32,
        (sym_batch, sym_seq_q, sym_heads),
        stride_order=(2, 1, 0),
        assumed_align=16,
    )
    if is_workspace_size_zero:
        # split_kv == 1 path needs no scratch buffer.
        workspace_fake = None
    else:
        workspace_fake = cute.runtime.make_fake_compact_tensor(
            cutlass.Int8,
            (sym_workspace_size,),
            assumed_align=32,
        )
    cache_seqs_fake = cute.runtime.make_fake_compact_tensor(
        cutlass.Int32,
        (sym_batch,),
        assumed_align=16,
    )
    if is_var_split_kv:
        # Per-batch split factors, only present for variable split-KV.
        block_split_kvs_fake = cute.runtime.make_fake_compact_tensor(
            cutlass.Int32,
            (sym_batch,),
            assumed_align=16,
        )
    else:
        block_split_kvs_fake = None

    return (
        q_latent_fake,
        q_rope_fake,
        c_latent_fake,
        c_rope_fake,
        page_table_fake,
        o_fake,
        lse_fake,
        workspace_fake,
        cache_seqs_fake,
        block_split_kvs_fake,
    )


def _make_mla_config(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    page_size: int,
    skip_correction_threshold: float,
    is_persistent: bool,
    is_var_seq: bool,
    is_var_split_kv: bool,
    enable_pdl: bool,
    is_fp8: bool,
) -> MLAConfig:
    """Create an MLAConfig with standard tiler settings."""
    cluster_shape_mnk = (2, 1, 1)
    return MLAConfig(
        latent_dim=kv_lora_rank,
        rope_dim=qk_rope_head_dim,
        acc_dtype=cutlass.Float32,
        lse_dtype=cutlass.Float32,
        # Tiler shapes mirror the ones checked in _check_can_implement.
        mma_qk_tiler_mn=(128, 128),
        mma_pv_tiler_mn=(128, 256),
        max_active_clusters=get_max_active_clusters(
            cluster_shape_mnk[0] * cluster_shape_mnk[1]
        ),
        page_size=page_size,
        skip_correction_threshold=skip_correction_threshold,
        is_persistent=is_persistent,
        is_var_seq=is_var_seq,
        is_var_split_kv=is_var_split_kv,
        enable_pdl=enable_pdl,
        is_fp8=is_fp8,
        # FP8 double-buffers the output MMA stage.
        mma_o_stage=2 if is_fp8 else 1,
    )
@functools.cache
def _compile_mla_kernel(
    torch_dtype: torch.dtype,
    torch_out_dtype: torch.dtype,
    page_size: int,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    is_persistent: bool,
    is_var_seq: bool,
    is_var_split_kv: bool,
    skip_correction_threshold: float = 0.0,
    is_workspace_size_zero: bool = False,
    enable_pdl: bool = False,
    variant: Optional[AttentionVariant] = None,
    params_shape: Optional[tuple] = None,
) -> Callable:
    """Compile and cache an MLA decode kernel (standard or variant).

    Uses ``@functools.cache`` so repeated calls with the same arguments
    return the previously compiled kernel in microseconds rather than
    recompiling (~3 s). For standard attention pass ``variant=None``
    (the default); for custom variants pass the variant instance (hashable
    by identity).

    ``AttentionFusion`` is constructed *inside* this function so it never
    appears in the cache key (it is unhashable).
    """
    if variant is None:
        variant = StandardAttention()
    fusion = AttentionFusion(variant=variant)

    cutlass_dtype = torch_to_cutlass_dtype(torch_dtype)
    cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype)

    is_fp8 = torch_dtype == torch.float8_e4m3fn
    config = _make_mla_config(
        kv_lora_rank,
        qk_rope_head_dim,
        page_size,
        skip_correction_threshold,
        is_persistent,
        is_var_seq,
        is_var_split_kv,
        enable_pdl,
        is_fp8,
    )

    kernel_obj = (
        BlackwellMultiLatentAttentionForwardFP8(config, fusion=fusion)
        if is_fp8
        else BlackwellMultiLatentAttentionForward(config, fusion=fusion)
    )

    fakes = _make_mla_fake_tensors(
        cutlass_dtype,
        cutlass_out_dtype,
        is_workspace_size_zero,
        is_var_split_kv,
    )
    (
        q_latent_fake,
        q_rope_fake,
        c_latent_fake,
        c_rope_fake,
        page_table_fake,
        o_fake,
        lse_fake,
        workspace_fake,
        cache_seqs_fake,
        block_split_kvs_fake,
    ) = fakes

    # Optional variant parameter tensor (e.g. per-head constants),
    # always Float32 and compact in row-major order.
    params_fake = None
    if params_shape is not None:
        ndim = len(params_shape)
        stride_order = tuple(range(ndim - 1, -1, -1))
        params_fake = cute.runtime.make_fake_compact_tensor(
            cutlass.Float32,
            params_shape,
            stride_order=stride_order,
            assumed_align=16,
        )

    stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)

    # Scalar arguments are placeholders: real split_kv / scales are
    # supplied at launch time; only shapes/dtypes matter for compilation.
    compiled_kernel = cute.compile(
        kernel_obj,
        q_latent_fake,
        q_rope_fake,
        c_latent_fake,
        c_rope_fake,
        page_table_fake,
        o_fake,
        lse_fake,
        workspace_fake,
        Int32(1),  # split_kv placeholder
        cache_seqs_fake,
        block_split_kvs_fake,
        Float32(1.0),  # softmax_scale placeholder
        Float32(1.0),  # output_scale placeholder
        params_fake,
        stream_fake,
        options="--enable-tvm-ffi --opt-level 2",
    )

    return compiled_kernel
"""PyTorch-facing wrapper for the modular MLA decode kernel. + + Usage:: + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=512, qk_rope_head_dim=64, num_heads=128, + page_size=64, q_dtype=torch.bfloat16, + ) + out = wrapper.run(query, kv_cache, block_tables, seq_lens, max_seq_len, + softmax_scale=0.125) + """ + + @flashinfer_api + def __init__(self, workspace_buffer: torch.Tensor) -> None: + assert workspace_buffer.dtype == torch.int8, ( + f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" + ) + self._workspace_buffer = workspace_buffer + self._device = workspace_buffer.device + self._compiled_kernel: Optional[Callable] = None + + @flashinfer_api + def plan( + self, + kv_lora_rank: int = 512, + qk_rope_head_dim: int = 64, + num_heads: int = 128, + page_size: int = 1, + q_dtype: torch.dtype = torch.bfloat16, + out_dtype: Optional[torch.dtype] = None, + is_var_seq: bool = True, + enable_pdl: Optional[bool] = None, + variant: Optional[AttentionVariant] = None, + ) -> None: + """Compile (or retrieve cached) MLA decode kernel for the given config. + + Parameters + ---------- + kv_lora_rank : int + Latent dimension (e.g. 512). + qk_rope_head_dim : int + RoPE dimension (e.g. 64). + num_heads : int + Number of attention heads (typically 128 for DeepSeek-V3). + page_size : int + KV cache page size. + q_dtype : torch.dtype + Query/KV data type (float16 or bfloat16). + out_dtype : Optional[torch.dtype] + Output data type. Defaults to same as q_dtype. + is_var_seq : bool + Whether sequence lengths vary across the batch. + enable_pdl : Optional[bool] + Whether to enable Programmatic Dependent Launch. Auto-detects if None. + variant : Optional[AttentionVariant] + Attention variant (ALiBi, SoftCapping, AttentionWithSink, etc.). + None uses standard softmax attention. 
+ """ + self._kv_lora_rank = kv_lora_rank + self._qk_rope_head_dim = qk_rope_head_dim + self._num_heads = num_heads + self._page_size = page_size + self._q_dtype = q_dtype + if out_dtype is not None: + self._o_dtype = out_dtype + elif q_dtype == torch.float8_e4m3fn: + self._o_dtype = torch.bfloat16 + else: + self._o_dtype = q_dtype + self._is_var_seq = is_var_seq + self._is_persistent = not is_var_seq + self._is_var_split_kv = False + self._skip_correction_threshold = 0.0 + + self._enable_pdl = ( + device_support_pdl(self._device) if enable_pdl is None else enable_pdl + ) + + if variant is None: + variant = StandardAttention() + self._variant = variant + + if self._variant.has_logits_transform: + raise ValueError( + "MLA decode does not support logits_transform. " + "Use score_mod, update_statistics, or transform_output instead." + ) + + self._has_params = self._variant.extra_params is not None + if self._has_params: + ep = self._variant.extra_params.to(torch.float32).to(self._device) + if not ep.is_contiguous(): + raise ValueError( + f"AttentionVariant.extra_params must be contiguous, " + f"got strides {ep.stride()} for shape {ep.shape}. " + f"Call .contiguous() before returning from extra_params." 
+ ) + self._params_torch = ep + else: + self._params_torch = None + + _check_can_implement( + torch_dtype=self._q_dtype, + torch_out_dtype=self._o_dtype, + page_size=self._page_size, + num_heads=self._num_heads, + seq_len_q=1, + kv_lora_rank=self._kv_lora_rank, + qk_rope_head_dim=self._qk_rope_head_dim, + is_persistent=self._is_persistent, + is_var_seq=self._is_var_seq, + is_var_split_kv=self._is_var_split_kv, + ) + + self._cache_variant = ( + self._variant if not isinstance(self._variant, StandardAttention) else None + ) + self._params_shape = ( + tuple(self._params_torch.shape) if self._has_params else None + ) + + self._compiled_kernel = _compile_mla_kernel( + torch_dtype=self._q_dtype, + torch_out_dtype=self._o_dtype, + page_size=self._page_size, + kv_lora_rank=self._kv_lora_rank, + qk_rope_head_dim=self._qk_rope_head_dim, + is_persistent=self._is_persistent, + is_var_seq=self._is_var_seq, + is_var_split_kv=self._is_var_split_kv, + skip_correction_threshold=self._skip_correction_threshold, + is_workspace_size_zero=False, + enable_pdl=self._enable_pdl, + variant=self._cache_variant, + params_shape=self._params_shape, + ) + + def _validate_run_inputs( + self, + q: torch.Tensor, + kv_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + out: Optional[torch.Tensor], + ) -> None: + """Check that run() inputs are consistent with the plan() configuration.""" + expected_D = self._kv_lora_rank + self._qk_rope_head_dim + if q.shape[-1] != expected_D: + raise ValueError( + f"q.shape[-1]={q.shape[-1]} does not match the planned " + f"kv_lora_rank + qk_rope_head_dim = {expected_D}" + ) + if q.dtype != self._q_dtype: + raise ValueError( + f"q.dtype={q.dtype} does not match the planned q_dtype={self._q_dtype}" + ) + if kv_cache.dtype != self._q_dtype: + raise ValueError( + f"kv_cache.dtype={kv_cache.dtype} does not match the planned " + f"q_dtype={self._q_dtype}" + ) + if out is not None and out.dtype != self._o_dtype: + raise ValueError( + 
f"out.dtype={out.dtype} does not match the planned " + f"out_dtype={self._o_dtype}" + ) + + @flashinfer_api + def run( + self, + q: torch.Tensor, + kv_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + softmax_scale: float, + output_scale: float = 1.0, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run the MLA decode kernel. + + Parameters + ---------- + q : torch.Tensor + [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim. + kv_cache : torch.Tensor + [num_pages, page_size, D_total] (3D) or [num_pages, 1, page_size, D_total] (4D). + block_tables : torch.Tensor + [B, max_pages] page table indices. + seq_lens : torch.Tensor + [B] per-request KV sequence lengths. + max_seq_len : int + Maximum sequence length across the batch. + softmax_scale : float + Scale factor for QK^T before softmax. + output_scale : float + Scale factor applied to the output. + out : Optional[torch.Tensor] + Pre-allocated output [B, q_len, H, kv_lora_rank]. + + Returns + ------- + torch.Tensor + Output tensor [B, q_len, H, kv_lora_rank]. 
+ """ + if self._compiled_kernel is None: + raise RuntimeError("Call plan() before run().") + + self._validate_run_inputs(q, kv_cache, block_tables, seq_lens, out) + + B, q_len, H, D_qk = q.shape + + # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] + if kv_cache.dim() == 4: + if kv_cache.shape[1] != 1: + raise ValueError( + f"Expected 4D kv_cache shape [num_pages, 1, page_size, D], " + f"got {tuple(kv_cache.shape)}" + ) + kv_cache = kv_cache.squeeze(1) + elif kv_cache.dim() != 3: + raise ValueError(f"kv_cache must be 3D or 4D, got ndim={kv_cache.dim()}") + + # Split query into latent and rope components + q_latent_k = q[..., : self._kv_lora_rank] + q_rope_k = q[..., self._kv_lora_rank :] + + # KV cache slices + c_latent_k = kv_cache[:, :, : self._kv_lora_rank] + c_rope_k = kv_cache[:, :, self._kv_lora_rank :] + + page_table_k = block_tables + + # Compute split_kv and workspace size + max_active_blocks = get_num_sm(q.device) + split_kv, workspace_size = _get_split_kv_and_workspace_size( + B, q_len, H, self._kv_lora_rank, max_active_blocks + ) + + if H < 128 and split_kv != 1: + raise ValueError( + f"num_heads={H} < 128 requires split_kv==1, got split_kv={split_kv}" + ) + + # Prepare workspace + is_workspace_size_zero = workspace_size == 0 + if is_workspace_size_zero: + workspace_bytes = None + else: + if self._workspace_buffer.numel() < workspace_size: + raise ValueError( + f"workspace_buffer too small: {self._workspace_buffer.numel()} bytes, " + f"need {workspace_size} bytes" + ) + workspace_bytes = self._workspace_buffer[:workspace_size] + + # Re-compile if workspace-zero-ness changed from what was planned + compiled_kernel = self._compiled_kernel + if is_workspace_size_zero: + compiled_kernel = _compile_mla_kernel( + torch_dtype=self._q_dtype, + torch_out_dtype=self._o_dtype, + page_size=self._page_size, + kv_lora_rank=self._kv_lora_rank, + qk_rope_head_dim=self._qk_rope_head_dim, + is_persistent=self._is_persistent, + 
is_var_seq=self._is_var_seq, + is_var_split_kv=self._is_var_split_kv, + skip_correction_threshold=self._skip_correction_threshold, + is_workspace_size_zero=True, + enable_pdl=self._enable_pdl, + variant=self._cache_variant, + params_shape=self._params_shape, + ) + + # Output buffer + if out is None: + out = torch.empty( + (B, q_len, H, self._kv_lora_rank), + dtype=self._o_dtype, + device=q.device, + ) + o_k = out + + # LSE buffer + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=q.device) + + # cache_seqs: per-batch sequence lengths + cache_seqs = ( + seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) + ) + + block_split_kvs = None + + compiled_kernel( + q_latent_k, + q_rope_k, + c_latent_k, + c_rope_k, + page_table_k, + o_k, + lse_k, + workspace_bytes, + Int32(split_kv), + cache_seqs, + block_split_kvs, + Float32(softmax_scale), + Float32(output_scale), + self._params_torch if self._has_params else None, + ) + + return out + + +# --------------------------------------------------------------------------- +# Standalone function — drop-in replacement for the original integration layer +# --------------------------------------------------------------------------- + + +def cute_dsl_mla_decode( + query: torch.Tensor, + kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + softmax_scale: float, + output_scale: float = 1.0, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + is_var_seq: bool = True, + enable_pdl: Optional[bool] = None, +) -> torch.Tensor: + """CuTe DSL MLA decode kernel for Blackwell SM100 (modular variant). 
+ + Parameters + ---------- + query : torch.Tensor + [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim + kv_cache : torch.Tensor + [num_pages, page_size, D_ckv + D_kpe] (3D) or [num_pages, 1, page_size, D_ckv + D_kpe] (4D) + workspace_buffer : torch.Tensor + Pre-allocated workspace buffer (int8). Required size depends on batch size + and split_kv (auto-computed from B, q_len, and number of SMs): + + - Formula: ``B * H * q_len * split_kv * (kv_lora_rank + 1) * 4`` bytes + (0 when split_kv == 1, which happens when B >= num_SMs / 2) + - Typical max: ~18 MB on a 148-SM GPU (e.g. B=4..8, H=128, D=512) + - Safe default: 128 MB covers all realistic configurations + kv_lora_rank : int + Latent dimension (e.g. 512). + qk_rope_head_dim : int + RoPE dimension (e.g. 64). + block_tables : torch.Tensor + [B, max_pages] — page table indices. + seq_lens : torch.Tensor + [B] — per-request KV sequence lengths. + max_seq_len : int + Maximum sequence length across the batch. + softmax_scale : float + Scale factor for QK^T before softmax. + output_scale : float + Scale factor applied to the output. + out : Optional[torch.Tensor] + Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. + out_dtype : Optional[torch.dtype] + Output data type. If None, defaults to same as input dtype. + is_var_seq : bool + Whether the sequence length is variable across the batch. + enable_pdl : Optional[bool], default=None + Whether to enable Programmatic Dependent Launch (PDL). + If None, auto-detects based on device capability. + + Returns + ------- + torch.Tensor + Output tensor [B, q_len, H, kv_lora_rank]. 
+ """ + supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + assert query.dtype in supported_dtypes, ( + f"cute_dsl_mla_decode only supports {supported_dtypes}, got {query.dtype}" + ) + assert kv_cache.dtype == query.dtype, ( + f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" + ) + B, q_len, H, D_qk = query.shape + assert D_qk == kv_lora_rank + qk_rope_head_dim + + q_dtype = query.dtype + if out is not None: + o_dtype = out.dtype + elif out_dtype is not None: + o_dtype = out_dtype + elif q_dtype == torch.float8_e4m3fn: + o_dtype = torch.bfloat16 + else: + o_dtype = q_dtype + + # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] + if kv_cache.dim() == 4: + if kv_cache.shape[1] != 1: + raise ValueError( + f"Expected 4D kv_cache shape [num_pages, 1, page_size, D], " + f"got {tuple(kv_cache.shape)}" + ) + kv_cache = kv_cache.squeeze(1) + elif kv_cache.dim() != 3: + raise ValueError(f"kv_cache must be 3D or 4D, got ndim={kv_cache.dim()}") + page_size = kv_cache.shape[1] + + # Split query into latent and rope components + q_latent_k = query[..., :kv_lora_rank] + q_rope_k = query[..., kv_lora_rank:] + + # KV cache slices + c_latent_k = kv_cache[:, :, :kv_lora_rank] + c_rope_k = kv_cache[:, :, kv_lora_rank:] + + page_table_k = block_tables + + # Runtime validation + if max_seq_len <= 0: + raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") + if H < 128 and H != 1: + raise ValueError( + f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}" + ) + + # Cached split_kv and workspace_size computation + max_active_blocks = get_num_sm(query.device) + split_kv, workspace_size = _get_split_kv_and_workspace_size( + B, q_len, H, kv_lora_rank, max_active_blocks + ) + + if H < 128 and split_kv != 1: + raise ValueError( + f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, " + f"got split_kv={split_kv}" + ) + + # Prepare workspace + assert workspace_buffer.dtype == 
torch.int8, ( + f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" + ) + assert workspace_buffer.numel() >= workspace_size, ( + f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " + f"need {workspace_size} bytes" + ) + is_workspace_size_zero = workspace_size == 0 + if is_workspace_size_zero: + workspace_bytes = None + else: + workspace_bytes = workspace_buffer[:workspace_size] + + # Output buffer + if out is not None: + o_k = out + else: + o_k = torch.empty( + (B, q_len, H, kv_lora_rank), dtype=o_dtype, device=query.device + ) + + # LSE buffer + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) + + # cache_seqs: per-batch sequence lengths + cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) + + is_var_split_kv = False + block_split_kvs = None + skip_correction_threshold = 0.0 + + is_persistent = not is_var_seq + + # Validate configuration (cached, negligible overhead after first call) + _check_can_implement( + torch_dtype=q_dtype, + torch_out_dtype=o_dtype, + page_size=page_size, + num_heads=H, + seq_len_q=q_len, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + ) + + enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + + # Get compiled kernel (cached after first compile) + compiled_kernel = _compile_mla_kernel( + torch_dtype=q_dtype, + torch_out_dtype=o_dtype, + page_size=page_size, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + skip_correction_threshold=skip_correction_threshold, + is_workspace_size_zero=is_workspace_size_zero, + enable_pdl=enable_pdl, + ) + + # Call the kernel + compiled_kernel( + q_latent_k, + q_rope_k, + c_latent_k, + c_rope_k, + page_table_k, + o_k, + lse_k, + workspace_bytes, + Int32(split_kv), + 
cache_seqs, + block_split_kvs, + Float32(softmax_scale), + Float32(output_scale), + None, # params_in (no variant in standalone function) + ) + + if out is not None: + return out + + return o_k diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py new file mode 100644 index 0000000000..58a24abe69 --- /dev/null +++ b/flashinfer/cute_dsl/attention/wrappers/batch_prefill.py @@ -0,0 +1,423 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""BatchPrefillCuteDSLWrapper — PyTorch-facing API for batch prefill attention. + +Constructs AttentionConfig + AttentionFusion from user-facing parameters, +creates the kernel, compiles it via TVM-FFI, and provides the run() interface. +Compilation is memoized via @functools.cache with symbolic tensor dimensions, +so kernels are compiled once per (dtype, heads, head_dim, mask, variant) combo +and reused across batches of any size. +""" + +import functools +import math +from typing import Optional + +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Int32 + +from flashinfer.api_logging import flashinfer_api + +from ..config import AttentionConfig, AttentionFusion +from ..fusion.mask import MaskType +from ..fusion.variant import AttentionVariant, StandardAttention +from ..prefill import BlackwellFusedMultiHeadAttentionForward + + +@functools.cache +def _get_compiled_prefill_kernel( + in_dtype, + out_dtype, + num_qo_heads, + num_kv_heads, + head_dim, + mask_type, + window_left, + is_persistent, + variant, + params_shape, +): + """Compile and cache the prefill kernel. + + Uses symbolic dimensions for sequence lengths and batch size so the same + compiled kernel can be reused across different batch shapes. Pass + ``variant=None`` for standard attention (always cache-hits); pass the + actual variant instance for custom variants (hashable by identity). 
+ + ``AttentionFusion`` is constructed *inside* this function so it never + appears in the cache key (it is unhashable). + """ + if variant is None: + variant = StandardAttention() + fusion = AttentionFusion(variant=variant) + h_r = num_qo_heads // num_kv_heads + + config = AttentionConfig( + qk_acc_dtype=cutlass.Float32, + pv_acc_dtype=cutlass.Float32, + mma_tiler=(128, 128, head_dim), + is_persistent=is_persistent, + mask_type=mask_type, + num_repeat_kv_heads=h_r, + window_left=window_left, + ) + _dtype_width_map = { + cutlass.Float16: 16, + cutlass.BFloat16: 16, + cutlass.Float8E4M3FN: 8, + } + config.can_implement(dtype_width=_dtype_width_map[in_dtype]) + fmha = BlackwellFusedMultiHeadAttentionForward(config, fusion) + + sym_s_q = cute.sym_int() + sym_s_k = cute.sym_int() + sym_batch_p1 = cute.sym_int() + + q_fake = cute.runtime.make_fake_compact_tensor( + in_dtype, + (sym_s_q, num_qo_heads, head_dim), + stride_order=(2, 1, 0), + assumed_align=16, + ) + k_fake = cute.runtime.make_fake_compact_tensor( + in_dtype, + (sym_s_k, num_kv_heads, head_dim), + stride_order=(2, 1, 0), + assumed_align=16, + ) + v_fake = cute.runtime.make_fake_compact_tensor( + in_dtype, + (sym_s_k, num_kv_heads, head_dim), + stride_order=(2, 1, 0), + assumed_align=16, + ) + o_fake = cute.runtime.make_fake_compact_tensor( + out_dtype, + (sym_s_q, num_qo_heads, head_dim), + stride_order=(2, 1, 0), + assumed_align=16, + ) + cum_seqlen_q_fake = cute.runtime.make_fake_compact_tensor( + Int32, + (sym_batch_p1,), + assumed_align=16, + ) + cum_seqlen_k_fake = cute.runtime.make_fake_compact_tensor( + Int32, + (sym_batch_p1,), + assumed_align=16, + ) + + params_fake = None + if params_shape is not None: + ndim = len(params_shape) + stride_order = tuple(range(ndim - 1, -1, -1)) + params_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, + params_shape, + stride_order=stride_order, + assumed_align=16, + ) + + problem_size = (1, 1, 1, num_qo_heads, num_kv_heads, head_dim) + stream_fake = 
cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + return cute.compile( + fmha, + q_fake, + k_fake, + v_fake, + o_fake, + problem_size, + cum_seqlen_q_fake, + 1, + cum_seqlen_k_fake, + 1, + 0.0, + 1.0, + params_fake, + stream_fake, + options="--enable-tvm-ffi --opt-level 2", + ) + + +class BatchPrefillCuteDSLWrapper: + @flashinfer_api + def __init__( + self, + float_workspace_buffer: torch.Tensor, + use_cuda_graph: bool = False, + ) -> None: + # Named float_workspace_buffer for compatibility with the parent + # BatchPrefillWithRaggedKVCacheWrapper API. Callers typically pass + # torch.uint8; the CuTe DSL kernel does not use this buffer. + self._float_workspace_buffer = float_workspace_buffer + self.device = float_workspace_buffer.device + + self._use_cuda_graph = use_cuda_graph + + self._in_dtype = None + self._out_dtype = None + self._compiled_fmha = None + + @flashinfer_api + def plan( + self, + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=None, + causal=True, + sm_scale=1.0, + q_data_type=torch.float16, + kv_data_type=torch.float16, + window_left: int = -1, + variant: AttentionVariant | None = None, + ) -> None: + """Compile the FMHA prefill kernel for the given configuration. + + Parameters + ---------- + qo_indptr : torch.Tensor + Cumulative query sequence lengths, shape [batch_size + 1]. + kv_indptr : torch.Tensor + Cumulative KV sequence lengths, shape [batch_size + 1]. + num_qo_heads : int + Number of query/output heads. + num_kv_heads : int + Number of key/value heads (must divide num_qo_heads). + head_dim_qk : int + Head dimension for queries and keys. + head_dim_vo : Optional[int] + Head dimension for values and output. Must equal head_dim_qk if set. + causal : bool + Whether to apply causal masking. + sm_scale : float + Softmax scale factor (typically 1/sqrt(head_dim)). + q_data_type : torch.dtype + Data type for queries (float16, bfloat16, or float8_e4m3fn). 
+ kv_data_type : torch.dtype + Data type for keys/values. + window_left : int + Sliding window size. -1 disables sliding window. + variant : Optional[AttentionVariant] + Attention variant (ALiBi, RPE, Sigmoid, etc.). None uses standard softmax. + """ + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + self._batch_size = qo_indptr.shape[0] - 1 + self._num_qo_heads = num_qo_heads + self._num_kv_heads = num_kv_heads + assert num_qo_heads % num_kv_heads == 0, ( + "num_qo_heads must be divisible by num_kv_heads" + ) + self._head_dim = head_dim_qk + assert head_dim_vo is None or head_dim_vo == head_dim_qk, ( + "head_dim_vo must be None or equal to head_dim_qk" + ) + self._causal = causal + self._sm_scale = sm_scale + self._device = qo_indptr.device + self._is_persistent = True + + if variant is None: + variant = StandardAttention() + self._variant = variant + + self._q_data_type = q_data_type + + # Map torch dtype → cutlass dtype + _dtype_map = { + torch.float16: (cutlass.Float16, cutlass.Float16), + torch.bfloat16: (cutlass.BFloat16, cutlass.BFloat16), + torch.float8_e4m3fn: (cutlass.Float8E4M3FN, cutlass.Float16), + } + if q_data_type not in _dtype_map: + raise ValueError(f"Unsupported input data type: {q_data_type}") + self._in_dtype, self._out_dtype = _dtype_map[q_data_type] + + # Sequence lengths from indptr + s_q = qo_indptr[1:] - qo_indptr[:-1] + s_k = kv_indptr[1:] - kv_indptr[:-1] + s_q_all = int(qo_indptr[-1].item()) + s_k_all = int(kv_indptr[-1].item()) + max_s_q = int(torch.max(s_q).item()) + max_s_k = int(torch.max(s_k).item()) + + # Store for runtime + self._qo_indptr = qo_indptr.to(torch.int32) + self._kv_indptr = kv_indptr.to(torch.int32) + self._s_q_all = s_q_all + self._s_k_all = s_k_all + self._o_padding = max_s_q + + self._has_params = self._variant.extra_params is not None + if self._has_params: + ep = self._variant.extra_params.to(torch.float32).to(self._device) + if not ep.is_contiguous(): + raise 
ValueError( + f"AttentionVariant.extra_params must be contiguous, " + f"got strides {ep.stride()} for shape {ep.shape}. " + f"Call .contiguous() before returning from extra_params." + ) + self._params_torch = ep + + mma_tiler_n = 128 + + # Determine mask type + self._mask_type = MaskType.NO_MASK + if self._causal: + self._mask_type = MaskType.CAUSAL_MASK + elif window_left > 0: + self._mask_type = MaskType.SLIDING_WINDOW_MASK + else: + if torch.any(s_k % mma_tiler_n != 0).item(): + self._mask_type = MaskType.RESIDUAL_MASK + + self._problem_size = ( + self._batch_size, + max_s_q, + max_s_k, + self._num_qo_heads, + self._num_kv_heads, + self._head_dim, + ) + + log2_e = math.log2(math.exp(1.0)) + self._scale_softmax_log2 = self._sm_scale * log2_e + self._scale_output = 1.0 + + cache_variant = ( + self._variant if not isinstance(self._variant, StandardAttention) else None + ) + params_shape = tuple(self._params_torch.shape) if self._has_params else None + + self._compiled_fmha = _get_compiled_prefill_kernel( + self._in_dtype, + self._out_dtype, + num_qo_heads, + num_kv_heads, + self._head_dim, + self._mask_type, + window_left, + self._is_persistent, + cache_variant, + params_shape, + ) + + # Pre-allocate padded output scratch buffer. The kernel uses a + # negative pointer offset into the output tensor for TMA varlen + # addressing (see prefill.py __call__, "markus's trick"), so the + # buffer needs max_s_q extra rows in front. Allocating once here + # avoids per-run() allocation overhead across all layers. 
+ _torch_out_dtype_map = { + torch.float16: torch.float16, + torch.bfloat16: torch.bfloat16, + torch.float8_e4m3fn: torch.float16, + } + torch_out_dtype = _torch_out_dtype_map[q_data_type] + self._o_scratch = torch.empty( + (self._o_padding + s_q_all, num_qo_heads, self._head_dim), + dtype=torch_out_dtype, + device=self._device, + ) + self._o_scratch_view = self._o_scratch[self._o_padding :] + + def _validate_run_inputs( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + ) -> None: + """Check that run() inputs are consistent with the plan() configuration.""" + for name, tensor in [("q", q), ("k", k), ("v", v)]: + if tensor.dtype != self._q_data_type: + raise ValueError( + f"{name}.dtype={tensor.dtype} does not match the planned " + f"q_data_type={self._q_data_type}" + ) + if tensor.device != self._device: + raise ValueError( + f"{name}.device={tensor.device} does not match the planned " + f"device={self._device}" + ) + if q.shape[-1] != self._head_dim: + raise ValueError( + f"q.shape[-1]={q.shape[-1]} does not match the planned " + f"head_dim={self._head_dim}" + ) + if q.shape[-2] != self._num_qo_heads: + raise ValueError( + f"q.shape[-2]={q.shape[-2]} does not match the planned " + f"num_qo_heads={self._num_qo_heads}" + ) + if k.shape[-2] != self._num_kv_heads: + raise ValueError( + f"k.shape[-2]={k.shape[-2]} does not match the planned " + f"num_kv_heads={self._num_kv_heads}" + ) + if out is not None: + if out.device != self._device: + raise ValueError( + f"out.device={out.device} does not match the planned " + f"device={self._device}" + ) + + @flashinfer_api + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r"""Run the prefill attention computation. + + Parameters + ---------- + q : torch.Tensor + The query tensor with shape [total_q_len, num_heads, head_dim]. 
+ k : torch.Tensor + The key tensor with shape [total_kv_len, num_heads, head_dim]. + v : torch.Tensor + The value tensor with shape [total_kv_len, num_heads, head_dim]. + out : Optional[torch.Tensor], optional + The output tensor. If None, a new tensor will be created. + + Returns + ------- + torch.Tensor + The output tensor with shape [total_q_len, num_heads, head_dim]. + """ + if self._compiled_fmha is None: + raise RuntimeError("Plan the prefill attention computation first!") + + self._validate_run_inputs(q, k, v, out) + + self._compiled_fmha( + q, + k, + v, + self._o_scratch_view, + self._problem_size, + self._qo_indptr, + self._s_q_all, + self._kv_indptr, + self._s_k_all, + self._scale_softmax_log2, + self._scale_output, + self._params_torch if self._has_params else None, + ) + + if out is not None: + out.copy_(self._o_scratch_view) + return out + return self._o_scratch_view.clone() diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index f2238a739c..4e8bdd7212 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -807,7 +807,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise RuntimeError( f"cute-dsl backend (MLA decode kernel) requires SM100+, got SM{cc[0]}{cc[1]}" ) - from .cute_dsl import cute_dsl_mla_decode + from flashinfer.cute_dsl.attention import cute_dsl_mla_decode if isinstance(bmm1_scale, torch.Tensor): raise ValueError( diff --git a/flashinfer/mla/cute_dsl/__init__.py b/flashinfer/mla/cute_dsl/__init__.py deleted file mode 100644 index 24572e9913..0000000000 --- a/flashinfer/mla/cute_dsl/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2026 by FlashInfer team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -CuTe DSL MLA Decode Kernels for Blackwell SM100. -""" - -from flashinfer.cute_dsl.utils import is_cute_dsl_available - -if is_cute_dsl_available(): - from .mla_decode import cute_dsl_mla_decode - -__all__ = [ - "is_cute_dsl_available", -] - -if is_cute_dsl_available(): - __all__ += [ - "cute_dsl_mla_decode", - ] diff --git a/flashinfer/mla/cute_dsl/mla_decode.py b/flashinfer/mla/cute_dsl/mla_decode.py deleted file mode 100644 index 1887e4e25c..0000000000 --- a/flashinfer/mla/cute_dsl/mla_decode.py +++ /dev/null @@ -1,504 +0,0 @@ -# Copyright (c) 2026 by FlashInfer team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -CuTe DSL MLA Decode Kernel Integration -======================================= - -Wraps NVIDIA's CuTe DSL MLA decode kernels (FP16/BF16/FP8) for Blackwell SM100 -and exposes them via a PyTorch API compatible with FlashInfer's MLA backend. 
-""" - -import functools -from typing import Callable, Optional, Tuple - -import cutlass -import cutlass.cute as cute -import torch -from cutlass import Float32, Int32 - -from ...utils import device_support_pdl - -from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 -from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 -from flashinfer.cute_dsl.utils import ( - get_max_active_clusters, - get_num_sm, - torch_to_cutlass_dtype, -) - - -@functools.cache -def _get_split_kv_and_workspace_size( - B: int, - q_len: int, - H: int, - kv_lora_rank: int, - max_active_blocks: int, -) -> Tuple[int, int]: - """Cache split_kv and workspace_size since they are deterministic for the same params.""" - split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( - B, q_len, max_active_blocks - ) - workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( - H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32 - ) - return split_kv, workspace_size - - -@functools.cache -def _check_can_implement( - torch_dtype: torch.dtype, - torch_out_dtype: torch.dtype, - page_size: int, - num_heads: int, - seq_len_q: int, - kv_lora_rank: int, - qk_rope_head_dim: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, -) -> None: - """Check if the kernel supports the given configuration (cached).""" - mma_qk_tiler_mn = (128, 128) - mma_pv_tiler_mn = (128, 256) - - is_fp8 = torch_dtype == torch.float8_e4m3fn - KernelClass = ( - BlackwellMultiHeadLatentAttentionForwardFP8 - if is_fp8 - else BlackwellMultiHeadLatentAttentionForwardFP16 - ) - cutlass_in_dtype = torch_to_cutlass_dtype(torch_dtype) - cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) - if not KernelClass.can_implement( - 1, # B (runtime, use placeholder) - seq_len_q, - 1, # K (runtime, use placeholder) - num_heads, - kv_lora_rank, - qk_rope_head_dim, - cutlass_in_dtype, - cutlass_out_dtype, - cutlass.Float32, - cutlass.Float32, - 
mma_qk_tiler_mn, - mma_pv_tiler_mn, - 1, # split_kv (runtime, use 1 to pass the H<128 check) - is_persistent, - is_var_seq, - is_var_split_kv, - page_size, - ): - raise ValueError( - f"cute_dsl_mla_decode: unsupported configuration " - f"(q_len={seq_len_q}, num_heads={num_heads}, page_size={page_size}, " - f"in_dtype={torch_dtype}, out_dtype={torch_out_dtype})" - ) - - -@functools.cache -def _get_compiled_mla_kernel( - torch_dtype: torch.dtype, - torch_out_dtype: torch.dtype, - page_size: int, - kv_lora_rank: int, - qk_rope_head_dim: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - skip_correction_threshold: float = 0.0, - is_workspace_size_zero: bool = False, - enable_pdl: bool = False, -) -> Callable: - """Compile and cache an MLA decode kernel. - - Returns a callable that accepts (q_latent, q_rope, c_latent, c_rope, - page_table, o, lse, workspace, split_kv_scalar, cache_seqs, - block_split_kvs, softmax_scale_scalar, output_scale_scalar). - - All scalar arguments must be pre-wrapped as Int32/Float32. - """ - # Tile sizes for Blackwell mma. - # (128, 128) for QK and (128, 256) for PV. 
- mma_qk_tiler_mn = (128, 128) - mma_pv_tiler_mn = (128, 256) - # 2 CTAs along M (num_heads) - cluster_shape_mnk = (2, 1, 1) - - is_fp8 = torch_dtype == torch.float8_e4m3fn - KernelClass = ( - BlackwellMultiHeadLatentAttentionForwardFP8 - if is_fp8 - else BlackwellMultiHeadLatentAttentionForwardFP16 - ) - cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) - cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) - - kernel_obj = KernelClass( - acc_dtype=cutlass.Float32, - lse_dtype=cutlass.Float32, - mma_qk_tiler_mn=mma_qk_tiler_mn, - mma_pv_tiler_mn=mma_pv_tiler_mn, - max_active_clusters=get_max_active_clusters( - cluster_shape_mnk[0] * cluster_shape_mnk[1] - ), - page_size=page_size, - skip_correction_threshold=skip_correction_threshold, - is_persistent=is_persistent, - is_var_seq=is_var_seq, - is_var_split_kv=is_var_split_kv, - enable_pdl=enable_pdl, - ) - - # All dimensions as sym_int — this matches the original kernel's use of - # mark_compact_shape_dynamic, which makes ALL shapes dynamic CuTe Integers. - # Static Python ints would cause cute.assume() to fail with AttributeError - # inside initialize_workspace() since it expects DSL Integer types. - sym_heads = cute.sym_int() - sym_latent = cute.sym_int(divisibility=16) - sym_seq_q = cute.sym_int() - sym_rope = cute.sym_int(divisibility=16) - sym_batch = cute.sym_int() # query/output batch dimension - sym_kv_batch = cute.sym_int() # KV cache batch dim (flat pool, =1 in paged mode) - sym_seq_kv = cute.sym_int() - sym_page_count = cute.sym_int() - sym_workspace_size = cute.sym_int() - - # q_latent, q_rope, c_latent, c_rope are slices of contiguous tensors on - # the last dim (e.g. query[..., :kv_lora_rank]), so they are NOT contiguous: - # stride[-2] = D_qk (original full last dim), not the sliced shape. - # Use make_fake_tensor with fully dynamic strides so the compiled kernel - # reads actual strides from the runtime tensor. Last-dim stride is always 1. 
- - # q_latent: [batch_size, seq_len_q, num_heads, latent_dim] — non-contiguous slice - q_latent_fake = cute.runtime.make_fake_tensor( - cutlass_dtype, - (sym_batch, sym_seq_q, sym_heads, sym_latent), - stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), - assumed_align=16, - ) - # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — non-contiguous slice - q_rope_fake = cute.runtime.make_fake_tensor( - cutlass_dtype, - (sym_batch, sym_seq_q, sym_heads, sym_rope), - stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), - assumed_align=16, - ) - # c_latent: [kv_batch, seq_len_k, latent_dim] — non-contiguous slice - # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat - # pool so kv_batch=num_pages at runtime, while query batch can be any value. - c_latent_fake = cute.runtime.make_fake_tensor( - cutlass_dtype, - (sym_kv_batch, sym_seq_kv, sym_latent), - stride=(cute.sym_int(), cute.sym_int(), 1), - assumed_align=16, - ) - # c_rope: [kv_batch, seq_len_k, rope_dim] — non-contiguous slice - c_rope_fake = cute.runtime.make_fake_tensor( - cutlass_dtype, - (sym_kv_batch, sym_seq_kv, sym_rope), - stride=(cute.sym_int(), cute.sym_int(), 1), - assumed_align=16, - ) - # page_table: [batch_size, page_count] — contiguous - page_table_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (sym_batch, sym_page_count), - stride_order=(1, 0), - assumed_align=16, - ) - # o: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous - o_fake = cute.runtime.make_fake_compact_tensor( - cutlass_out_dtype, - (sym_batch, sym_seq_q, sym_heads, sym_latent), - stride_order=(3, 2, 1, 0), - assumed_align=16, - ) - # lse: [batch_size, seq_len_q, num_heads] — contiguous - lse_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Float32, - (sym_batch, sym_seq_q, sym_heads), - stride_order=(2, 1, 0), - assumed_align=16, - ) - if is_workspace_size_zero: - workspace_fake = None - else: - # workspace: 1-D int8 buffer. 
32-byte alignment because workspace stores - # fp32 partial sums internally, requiring stricter alignment than tensors. - workspace_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int8, - (sym_workspace_size,), - assumed_align=32, - ) - # cache_seqs: [batch_size] — int32 - cache_seqs_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (sym_batch,), - assumed_align=16, - ) - # block_split_kvs: [batch_size] — int32 (only needed for is_var_split_kv=True) - if is_var_split_kv: - block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (sym_batch,), - assumed_align=16, - ) - else: - block_split_kvs_fake = None - - stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - - compiled_kernel = cute.compile( - kernel_obj, - q_latent_fake, - q_rope_fake, - c_latent_fake, - c_rope_fake, - page_table_fake, - o_fake, - lse_fake, - workspace_fake, - Int32(1), # split_kv placeholder - cache_seqs_fake, - block_split_kvs_fake, - Float32(1.0), # softmax_scale placeholder - Float32(1.0), # output_scale placeholder - stream_fake, - options="--enable-tvm-ffi --opt-level 2", - ) - - return compiled_kernel - - -# TODO: query[..., :kv_lora_rank], do we need to remove such kind of slice and move the logic to call routine in the kernel file. -def cute_dsl_mla_decode( - query: torch.Tensor, - kv_cache: torch.Tensor, - workspace_buffer: torch.Tensor, - kv_lora_rank: int, - qk_rope_head_dim: int, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - max_seq_len: int, - softmax_scale: float, - output_scale: float = 1.0, - out: Optional[torch.Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - is_var_seq: bool = True, - enable_pdl: Optional[bool] = None, -) -> torch.Tensor: - """CuTe DSL MLA decode kernel for Blackwell SM100. 
- - Parameters - ---------- - query : torch.Tensor - [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim - kv_cache : torch.Tensor - [num_pages, page_size, D_ckv + D_kpe] (3D) or [num_pages, 1, page_size, D_ckv + D_kpe] (4D) - workspace_buffer : torch.Tensor - Pre-allocated workspace buffer (uint8). Required size depends on batch size - and split_kv (auto-computed from B, q_len, and number of SMs): - - - Formula: ``B * H * q_len * split_kv * (kv_lora_rank + 1) * 4`` bytes - (0 when split_kv == 1, which happens when B >= num_SMs / 2) - - Typical max: ~18 MB on a 148-SM GPU (e.g. B=4..8, H=128, D=512) - - Safe default: 128 MB covers all realistic configurations - kv_lora_rank : int - Latent dimension (e.g. 512). - qk_rope_head_dim : int - RoPE dimension (e.g. 64). - block_tables : torch.Tensor - [B, max_pages] — page table indices. - seq_lens : torch.Tensor - [B] — per-request KV sequence lengths. - max_seq_len : int - Maximum sequence length across the batch. - softmax_scale : float - Scale factor for QK^T before softmax. - output_scale : float - Scale factor applied to the output. - out : Optional[torch.Tensor] - Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. - out_dtype : Optional[torch.dtype] - Output data type. If None, defaults to torch.bfloat16 (matching trtllm-gen backend). - Supported values: torch.bfloat16, torch.float8_e4m3fn (FP8 input only), - torch.float16, torch.bfloat16 (FP16/BF16 input). - is_var_seq : bool - Whether the sequence length is variable. - If True, the sequence length is variable. - Otherwise,the sequence length is fixed for all the requests in the batch. - enable_pdl : Optional[bool], default=None - Whether to enable Programmatic Dependent Launch (PDL). - If None, auto-detects based on device capability. - - Returns - ------- - torch.Tensor - Output tensor [B, q_len, H, kv_lora_rank]. 
- """ - supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} - assert query.dtype in supported_dtypes, ( - f"cute_dsl_mla_decode only supports {supported_dtypes}, got {query.dtype}" - ) - assert kv_cache.dtype == query.dtype, ( - f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" - ) - B, q_len, H, D_qk = query.shape - assert D_qk == kv_lora_rank + qk_rope_head_dim - - q_dtype = query.dtype - # Resolve output dtype: for FP8 input, default to bfloat16 (matching trtllm-gen backend); - # for FP16/BF16 input, default to same as input. Allow override via out_dtype or out tensor. - if out is not None: - o_dtype = out.dtype - elif out_dtype is not None: - o_dtype = out_dtype - elif q_dtype == torch.float8_e4m3fn: - o_dtype = torch.bfloat16 - else: - o_dtype = q_dtype - - # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] - if kv_cache.dim() == 4: - kv_cache = kv_cache.squeeze(1) - page_size = kv_cache.shape[1] - - # Split query into latent and rope components — keep contiguous [B, q_len, H, D]. - # The kernel's __call__ reinterprets to [H, D, q_len, B] via zero-cost make_tensor. - q_latent_k = query[..., :kv_lora_rank] - q_rope_k = query[..., kv_lora_rank:] - - # KV cache slices — keep contiguous [num_pages, page_size, D]. - # The kernel reinterprets to [page_size, D, num_pages] internally. - c_latent_k = kv_cache[:, :, :kv_lora_rank] - c_rope_k = kv_cache[:, :, kv_lora_rank:] - - # Page table: [B, max_pages]: passed directly, kernel reinterprets. - page_table_k = block_tables - - # Runtime validation (int comparisons only, negligible overhead) - if max_seq_len <= 0: - raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") - # H=128: standard DeepSeek-V3 MLA config; H=1: used by split-kv reduction path. - # Values 2..127 are not supported by the kernel's tile config. 
- if H < 128 and H != 1: - raise ValueError( - f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}" - ) - - # Cached split_kv and workspace_size computation - max_active_blocks = get_num_sm(query.device) - split_kv, workspace_size = _get_split_kv_and_workspace_size( - B, q_len, H, kv_lora_rank, max_active_blocks - ) - - if H < 128 and split_kv != 1: - raise ValueError( - f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, " - f"got split_kv={split_kv}" - ) - - # Prepare workspace: slice of contiguous 1D buffer is already contiguous - assert workspace_buffer.dtype == torch.int8, ( - f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" - ) - assert workspace_buffer.numel() >= workspace_size, ( - f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " - f"need {workspace_size} bytes" - ) - is_workspace_size_zero = workspace_size == 0 - if is_workspace_size_zero: - workspace_bytes = None - else: - workspace_bytes = workspace_buffer[:workspace_size] - # Output buffer: contiguous [B, q_len, H, D]. - # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. - if out is not None: - o_k = out - else: - o_k = torch.empty( - (B, q_len, H, kv_lora_rank), dtype=o_dtype, device=query.device - ) - - # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. - lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) - - # cache_seqs: per-batch sequence lengths (skip .to() if already int32) - cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) - - is_var_split_kv = False - block_split_kvs = None - skip_correction_threshold = 0.0 - - # for fix-length, set is_persistent to True; otherwise, set to False. 
- is_persistent = not is_var_seq - - # Validate configuration (cached, negligible overhead after first call) - _check_can_implement( - torch_dtype=q_dtype, - torch_out_dtype=o_dtype, - page_size=page_size, - num_heads=H, - seq_len_q=q_len, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - is_persistent=is_persistent, - is_var_seq=is_var_seq, - is_var_split_kv=is_var_split_kv, - ) - - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - - # Get compiled kernel (cached after first compile) - # Note: when is_workspace_size_zero is True, workspace_bytes is None and it will launch one kernel without workspace. - # Otherwise, workspace_bytes is not None and it will launch two kernels. - compiled_kernel = _get_compiled_mla_kernel( - torch_dtype=q_dtype, - torch_out_dtype=o_dtype, - page_size=page_size, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - is_persistent=is_persistent, - is_var_seq=is_var_seq, - is_var_split_kv=is_var_split_kv, - skip_correction_threshold=skip_correction_threshold, - is_workspace_size_zero=is_workspace_size_zero, - enable_pdl=enable_pdl, - ) - - # Call the kernel - compiled_kernel( - q_latent_k, - q_rope_k, - c_latent_k, - c_rope_k, - page_table_k, - o_k, - lse_k, - workspace_bytes, - Int32(split_kv), - cache_seqs, - block_split_kvs, - Float32(softmax_scale), - Float32(output_scale), - ) - - # If out was provided, kernel already wrote into it — return directly. - if out is not None: - return out - - # o_k is [B, q_len, H, D] — return as-is to match trtllm-gen output shape. - return o_k diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py deleted file mode 100644 index df18a414fe..0000000000 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ /dev/null @@ -1,4250 +0,0 @@ -# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -import math -from typing import Type, Tuple, Optional -from types import SimpleNamespace - -import torch -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute -import cutlass.cute.testing as testing -import cutlass.cute.nvgpu.tcgen05 as tcgen05 - -# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. 
-# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; -# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. -_setmaxregister_decrease = getattr( - cute.arch, - "setmaxregister_decrease", - getattr(cute.arch, "warpgroup_reg_dealloc", None), -) -_setmaxregister_increase = getattr( - cute.arch, - "setmaxregister_increase", - getattr(cute.arch, "warpgroup_reg_alloc", None), -) - -# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; -# older versions don't have it, so we provide a fallback implementation. -_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} - - -# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. -def _get_max_tmem_alloc_cols(compute_capability: str) -> int: - if hasattr(cute.arch, "get_max_tmem_alloc_cols"): - return cute.arch.get_max_tmem_alloc_cols(compute_capability) - if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: - raise ValueError(f"Unsupported compute capability: {compute_capability}") - return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] - - -from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode -import cutlass.cute.nvgpu.cpasync as cpasync -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait -import cutlass.torch as cutlass_torch -import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cute.runtime import from_dlpack -from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL - - -from .mla_helpers import ( - ceil_div, - MAX_SPLITS, - LOG2_E, - MLAStaticTileScheduler, - MLAStaticTileSchedulerParams, - create_mla_static_tile_scheduler, - create_mla_static_tile_scheduler_params, -) - -""" -A Multi-Head Latent Attention (MLA) example with FP16 data type for the NVIDIA Blackwell SM100 architecture using CUTE DSL - -This example demonstrates an implementation of inference of multi-head latent 
attention using a TMA + Blackwell -SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T -matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. -The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting -functionality to minimize latency when processing long KV sequences. - -The kernel implements key optimizations including: -- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) -- Pipeline stages between different warps for overlapping computation and memory access -- Support for different precision data types -- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing - -To run this example: - -.. code-block:: bash - - python examples/blackwell/mla_fp16.py \ - --batch_size 4 --latent_dim 512 --rope_dim 64 \ - --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ - --in_dtype Float16 --out_dtype Float16 \ - --acc_dtype Float32 --lse_dtype Float32 \ - --is_var_seq --is_var_split_kv \ - --is_persistent - -The above example runs Multi-Head Latent Attention (MLA) with the following configuration: -- Batch size: 4 -- Sequence length of Q: 1 -- Sequence length of K: 1024 -- Latent dimension: 512 -- RoPE dimension: 64 -- Number of heads: 128 -- Data types: Float16 (input), Float16 (output), Float32 (accumulation and LSE) - -It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences -and variable split KV processing with persistent scheduling. - -To collect performance with NCU profiler: - -.. 
code-block:: bash - - ncu python examples/blackwell/mla_fp16.py \ - --batch_size 4 --latent_dim 512 --rope_dim 64 \ - --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ - --in_dtype Float16 --out_dtype Float16 \ - --acc_dtype Float32 --lse_dtype Float32 \ - --is_var_seq --is_var_split_kv \ - --is_persistent --warmup_iterations 3 \ - --iterations 10 --skip_ref_check - -Constraints for this example: -* Data type requirements: - - Input/output: Float16 - - Accumulation and LSE: Float32 -* Fixed architecture parameters: - - Number of attention heads: 128 - - Latent dimension: 512 - - RoPE dimension: 64 -* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) -* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) -* Query sequence length must be 1-4 -* Only supports 2-CTA instructions -* Variable sequence length requires page table storage enabled -""" - - -class BlackwellMultiHeadLatentAttentionForwardFP16: - def __init__( - self, - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - max_active_clusters: int, - page_size: int, - skip_correction_threshold: float, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - enable_pdl: bool, - ): - """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. 
- - :param acc_dtype: Data type for accumulation S and O - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: Data type for output LSE - :type lse_dtype: Type[cutlass.Numeric] - :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S - :type mma_s_tiler: Tuple[int, int] - :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P - :type mma_p_tiler: Tuple[int, int] - :param max_active_clusters: Maximum number of active clusters - :type max_active_clusters: int - :param page_size: The page size of the page table - :type page_size: int - :param skip_correction_threshold: Threshold to skip correction - :type skip_correction_threshold: float - :param is_persistent: Whether to use persistent kernel mode - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split KV - :type is_var_split_kv: bool - :param enable_pdl: Whether to use PDL - :type enable_pdl: bool - """ - - self.latent_dim = 512 - self.rope_dim = 64 - self.acc_dtype = acc_dtype - self.lse_dtype = lse_dtype - self.mma_qk_tiler_mn = mma_qk_tiler_mn - self.mma_pv_tiler_mn = mma_pv_tiler_mn - self.max_active_clusters = max_active_clusters - self.skip_correction_threshold = skip_correction_threshold - self.is_persistent = is_persistent - self.page_size = page_size - self.is_var_seq = is_var_seq - self.is_var_split_kv = is_var_split_kv - self.enable_pdl = enable_pdl - self.cluster_shape_mnk = (2, 1, 1) - self.use_2cta_instrs = True - # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), - # while warps 2-3 handle accumulation for second half [n/2, n) - self.warps_in_n = 2 - self.num_compute_warps = 4 - self.threads_per_warp = 32 - mma_qk_tiler_k = self.rope_dim - self.mma_qk_tiler = ( - self.mma_qk_tiler_mn[0], - self.mma_qk_tiler_mn[1], - mma_qk_tiler_k, - ) - self.mma_qk_rope_tiler = ( - self.mma_qk_tiler_mn[0], - 
self.mma_qk_tiler_mn[1], - self.rope_dim, - ) - self.mma_pv_tiler = ( - self.mma_pv_tiler_mn[0], - self.mma_pv_tiler_mn[1], - self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], - ) - self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] - self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2] - self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope - self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] - self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] - - # Set specialized warp ids - self.compute_warp_ids = (0, 1, 2, 3) - self.correction_warp_ids = (4, 5, 6, 7) - self.mma_warp_id = 8 - - self.load_tma_warp_id = 9 - self.load_pt_warp_id = 10 - self.empty_warp_ids = (11,) - self.threads_per_cta = self.threads_per_warp * len( - ( - self.mma_warp_id, - self.load_tma_warp_id, - self.load_pt_warp_id, - *self.compute_warp_ids, - *self.correction_warp_ids, - *self.empty_warp_ids, - ) - ) - - # register settings - self.softmax_reg_num = 192 - self.correction_reg_num = 208 - self.other_reg_num = 96 - # Named barriers - self.tmem_ptr_sync_bar = pipeline.NamedBarrier( - barrier_id=1, - num_threads=( - self.threads_per_warp - + self.threads_per_warp * self.num_compute_warps * 2 - ), - ) - self.softmax_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) - ) - self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) - ) - - def _setup_attributes(self): - """Set up configurations and parameters for the MLA kernel operation. 
- - This method initializes and configures various attributes required for the - execution of the multi-head latent attention kernel, mainly about the pipeline stages: - - - Sets up staging parameters for Q, K, V inputs and accumulator data - - Configures pipeline stages for softmax, correction, and epilogue operations - """ - - self.load_q_stage = 1 - self.load_kv_stage = 15 - self.mma_s_stage = 2 - self.p_mma_stage = 2 - self.p_cor_stage = 2 - self.mma_o_stage = 1 - self.load_pt_stage = 4 - - self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n - self.correction_factor_offset = ( - self.tmem_o_offset + self.latent_dim // self.warps_in_n - ) - - @cute.jit - def __call__( - self, - q_latent: cute.Tensor, - q_rope: cute.Tensor, - c_latent: cute.Tensor, - c_rope: cute.Tensor, - page_table: cute.Tensor, - o: cute.Tensor, - lse: cute.Tensor, - workspace: cute.Tensor, - split_kv: cutlass.Int32, - cache_seqs: Optional[cute.Tensor], - block_split_kvs: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - output_scale: cutlass.Float32, - stream: cuda.CUstream, - ): - """Execute the Multi-Head Latent Attention operation on the provided tensors. - - The method handles: - 1. Initialization of workspace for temporary split KV buffers - 2. Validation of tensor data types - 3. Initialization of hardware-specific parameters and memory layouts - 4. Configuration of TMA (Tensor Memory Access) operations - 5. Grid and work scheduling computation - 6. 
Kernel launch(split KV kernel and reduction kernel) with appropriate parameters - - :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) - :type q_latent: cute.Tensor - :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) - :type q_rope: cute.Tensor - :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) - :type c_latent: cute.Tensor - :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) - :type c_rope: cute.Tensor - :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) - :type page_table: cute.Tensor - :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) - :type o: cute.Tensor - :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) - :type lse: cute.Tensor - :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse - :type workspace: cute.Tensor - :param split_kv: The scalar factor for split KV - :type split_kv: cutlass.Int32 - :param cache_seqs: The cache sequences tensor with shape [batch_size] - :type cache_seqs: cute.Tensor - :param block_split_kvs: The block split KV tensor with shape [batch_size] - :type block_split_kvs: cute.Tensor - :param softmax_scale: The scale factor for softmax - :type softmax_scale: cutlass.Float32 - :param output_scale: The scale factor for the output - :type output_scale: cutlass.Float32 - :param stream: The CUDA stream to execute the kernel on - :type stream: cuda.CUstream - - :raises TypeError: If tensor data types don't match or aren't supported - """ - - # setup static attributes before smem/grid/tma computation - self.q_dtype = q_latent.element_type - self.k_dtype = c_latent.element_type - self.v_dtype = c_latent.element_type - self.o_dtype = o.element_type - - # check type consistency - if cutlass.const_expr( - 
self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype - ): - raise TypeError( - f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" - ) - - # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] - # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) - def _reinterpret_4d(t): - return cute.make_tensor( - t.iterator, - cute.make_layout( - (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), - stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), - ), - ) - - q_latent = _reinterpret_4d(q_latent) - q_rope = _reinterpret_4d(q_rope) - o = _reinterpret_4d(o) - - # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] - # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) - def _reinterpret_3d_kv(t): - return cute.make_tensor( - t.iterator, - cute.make_layout( - (t.shape[1], t.shape[2], t.shape[0]), - stride=(t.stride[1], t.stride[2], t.stride[0]), - ), - ) - - c_latent = _reinterpret_3d_kv(c_latent) - c_rope = _reinterpret_3d_kv(c_rope) - - # Reinterpret contiguous [B, page_count] as [page_count, B] - page_table = cute.make_tensor( - page_table.iterator, - cute.make_layout( - (page_table.shape[1], page_table.shape[0]), - stride=(page_table.stride[1], page_table.stride[0]), - ), - ) - - # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] - # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) - lse = cute.make_tensor( - lse.iterator, - cute.make_layout( - (lse.shape[2], lse.shape[1], lse.shape[0]), - stride=(lse.stride[2], lse.stride[1], lse.stride[0]), - ), - ) - - acc_o, acc_lse = self.initialize_workspace( - q_latent.shape[0], - q_latent.shape[1], - q_latent.shape[2], - q_latent.shape[3], - split_kv, - self.acc_dtype, - workspace, - ) - - c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) - c_latent_transpose = cute.make_tensor( - c_latent.iterator, c_latent_tranpose_layout - ) - - self.q_major_mode = tcgen05.OperandMajorMode.K - self.k_major_mode = 
tcgen05.OperandMajorMode.K - self.v_major_mode = tcgen05.OperandMajorMode.MN - - self._setup_attributes() - - cta_group = tcgen05.CtaGroup.TWO - # the intermediate tensor p is from smem & k-major - p_major_mode = tcgen05.OperandMajorMode.K - qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( - self.q_dtype, - self.q_major_mode, - self.k_major_mode, - self.acc_dtype, - cta_group, - self.mma_qk_tiler[:2], - ) - pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( - self.v_dtype, - p_major_mode, - self.v_major_mode, - self.acc_dtype, - cta_group, - self.mma_pv_tiler[:2], - ) - - cta_layout_vmnk = cute.tiled_divide( - cute.make_layout(self.cluster_shape_mnk), - (qk_tiled_mma.thr_id.shape,), - ) - - self.epi_tile = self.mma_pv_tiler[:2] - - q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( - qk_tiled_mma, - self.mma_qk_tiler, - self.q_dtype, - (self.iterations_qk_latent * self.load_q_stage), - ) - q_latent_smem_layout_staged = cute.logical_divide( - q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) - ) - q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( - qk_tiled_mma, - self.mma_qk_rope_tiler, - self.q_dtype, - self.load_q_stage, - ) - - # rope reuse the same smem layout as latent - kc_smem_layout_staged = sm100_utils.make_smem_layout_b( - qk_tiled_mma, - self.mma_qk_tiler, - self.k_dtype, - self.load_kv_stage, - ) - kc_page_tile_size = min( - self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape - ) - - kc_smem_layout_for_tma = sm100_utils.make_smem_layout( - OperandMajorMode.K, - (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), - self.k_dtype, - self.load_kv_stage, - ) - kc_smem_layout_for_tma = cute.tiled_divide( - kc_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) - ) - - p_smem_layout_staged = sm100_utils.make_smem_layout_a( - pv_tiled_mma, - self.mma_pv_tiler, - self.q_dtype, - (self.iterations_pv_k * self.p_mma_stage), - ) - p_smem_layout_staged = 
cute.logical_divide( - p_smem_layout_staged, (None, None, None, self.iterations_pv_k) - ) - - vc_smem_layout_staged = sm100_utils.make_smem_layout_b( - pv_tiled_mma, - self.mma_pv_tiler, - self.v_dtype, - self.load_kv_stage, - ) - vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) - vc_smem_layout_for_tma = sm100_utils.make_smem_layout( - OperandMajorMode.MN, - (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), - self.v_dtype, - self.load_kv_stage, - ) - vc_smem_layout_for_tma = cute.tiled_divide( - vc_smem_layout_for_tma, - ( - pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, - vc_page_tile_size, - ), - ) - # TMA load for Q latent and rope - tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) - - q_latent_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - q_latent, - q_latent_smem_layout, - self.mma_qk_tiler, - qk_tiled_mma, - cta_layout_vmnk.shape, - ) - q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - q_rope, - q_rope_smem_layout, - self.mma_qk_rope_tiler, - qk_tiled_mma, - cta_layout_vmnk.shape, - ) - # TMA load for c latent and k rope - kc_smem_layout = cute.select(kc_smem_layout_for_tma, mode=[0]) - tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( - tma_load_op, - c_latent, - kc_smem_layout, - (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), - qk_tiled_mma, - is_k_load=True, - ) - tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( - tma_load_op, - c_rope, - kc_smem_layout, - (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), - qk_tiled_mma, - is_k_load=True, - ) - # TMA load for c latent transpose - vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) - tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( - 
self.make_paged_tiled_tma_atom( - tma_load_op, - c_latent_transpose, - vc_smem_layout, - (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), - pv_tiled_mma, - is_k_load=False, - ) - ) - - q_latent_copy_size = ( - cute.size_in_bytes(self.q_dtype, q_latent_smem_layout) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_latent - ) - q_rope_copy_size = ( - cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_rope - ) - q_copy_size = q_latent_copy_size + q_rope_copy_size - kc_copy_size = cute.size_in_bytes( - self.k_dtype, cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) - ) * cute.size(qk_tiled_mma.thr_id.shape) - vc_copy_size = cute.size_in_bytes( - self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) - ) * cute.size(pv_tiled_mma.thr_id.shape) - assert kc_copy_size == vc_copy_size, ( - "kc_copy_size and vc_copy_size must be the same" - ) - - self.tma_copy_q_bytes = q_copy_size - self.tma_copy_kc_bytes = kc_copy_size - - tile_sched_params, grid = self._compute_grid( - o, - split_kv, - self.cluster_shape_mnk, - self.max_active_clusters, - self.is_persistent, - ) - - @cute.struct - class SplitKVKernelSharedStorage: - # Pipeline barriers - load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] - load_kv_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.load_kv_stage * 2 - ] - mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] - p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] - p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] - mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] - load_pt_mbar_ptr: cute.struct.MemRange[ - cutlass.Int64, self.load_pt_stage * 2 - ] - # Tmem dealloc cluster barrier - tmem_dealloc_mbar_ptr: cutlass.Int64 - - # Tmem holding buffer - tmem_holding_buf: cutlass.Int32 - # Smem tensors - softmax_smem_exchange: cute.struct.MemRange[ - 
self.acc_dtype, self.num_compute_warps * self.threads_per_warp - ] - epilogue_smem_exchange: cute.struct.MemRange[ - self.acc_dtype, self.num_compute_warps * self.threads_per_warp - ] - smem_q_latent: cute.struct.Align[ - cute.struct.MemRange[ - self.q_dtype, cute.cosize(q_latent_smem_layout_staged) - ], - 1024, - ] - smem_q_rope: cute.struct.Align[ - cute.struct.MemRange[ - self.q_dtype, cute.cosize(q_rope_smem_layout_staged) - ], - 1024, - ] - smem_kc: cute.struct.Align[ - cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)], - 1024, - ] - smem_p: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], - 1024, - ] - smem_page_table: cute.struct.MemRange[ - cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1] // 2 - ] - - softmax_scale_log2 = softmax_scale * LOG2_E - self.split_kv_kernel( - qk_tiled_mma, - pv_tiled_mma, - tma_atom_q_latent, - tma_tensor_q_latent, - tma_atom_q_rope, - tma_tensor_q_rope, - tma_atom_c_latent, - tma_tensor_c_latent, - tma_atom_c_rope, - tma_tensor_c_rope, - tma_atom_c_latent_transpose, - tma_tensor_c_latent_transpose, - page_table, - o, - lse, - acc_o, - acc_lse, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale_log2, - output_scale, - q_latent_smem_layout_staged, - q_rope_smem_layout_staged, - kc_smem_layout_staged, - p_smem_layout_staged, - vc_smem_layout_staged, - kc_smem_layout_for_tma, - vc_smem_layout_for_tma, - cta_layout_vmnk, - tile_sched_params, - SplitKVKernelSharedStorage, - ).launch( - grid=grid, - block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] - stream=stream, - min_blocks_per_mp=1, - use_pdl=self.enable_pdl, - ) - if cutlass.const_expr(acc_o is not None): - self.reduction_kernel( - o, - lse, - acc_o, - acc_lse, - split_kv, - cache_seqs, - block_split_kvs, - ).launch( - grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), - 
block=[self.threads_per_warp * self.num_compute_warps, 1, 1], - smem=MAX_SPLITS * self.acc_dtype.width // 8, - stream=stream, - min_blocks_per_mp=1, - use_pdl=self.enable_pdl, - ) - - @cute.jit - def make_paged_tiled_tma_atom( - self, - tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, - gmem: cute.Tensor, - smem_layout: cute.Layout, - mma_tiler, - tiled_mma: cute.TiledMma, - is_k_load: bool, - ): - ident = cute.make_identity_layout(gmem.shape) - g_tile = cute.composition(ident, mma_tiler) - cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape - cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) - cta_v_map = cute.select(cta_v_map, mode=[0, 2]) - page_tile_size = ( - min(self.page_size, cta_mn) - if is_k_load - else min(self.page_size, mma_tiler[1]) - ) - cta_v_map = cute.zipped_divide( - cta_v_map, - (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), - ) - cta_v_map = cute.select(cta_v_map, mode=[0]) - from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir - - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem.value, - smem_layout.value, - cta_v_map, - tma_load_op._to_ir(), - num_multicast=1, - ) - return cute.CopyAtom( - tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) - ), res[1] - - @cute.kernel - def split_kv_kernel( - self, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - tma_atom_q_latent: Optional[cute.CopyAtom], - mQL: cute.Tensor, - tma_atom_q_rope: Optional[cute.CopyAtom], - mQR: cute.Tensor, - tma_atom_c_latent: Optional[cute.CopyAtom], - mCL: cute.Tensor, - tma_atom_c_rope: Optional[cute.CopyAtom], - mKR: cute.Tensor, - tma_atom_c_latent_transpose: Optional[cute.CopyAtom], - mCLT: cute.Tensor, - mPT: cute.Tensor, - mO: Optional[cute.Tensor], - mLSE: Optional[cute.Tensor], - mAccO: Optional[cute.Tensor], - mAccLSE: Optional[cute.Tensor], - split_kv: cutlass.Int32, - cache_seqs: cute.Tensor, - block_split_kvs: cute.Tensor, - softmax_scale_log2: cutlass.Float32, - output_scale: 
cutlass.Float32, - q_latent_smem_layout_staged: cute.ComposedLayout, - q_rope_smem_layout_staged: cute.ComposedLayout, - kc_smem_layout_staged: cute.ComposedLayout, - p_smem_layout_staged: cute.ComposedLayout, - vc_smem_layout_staged: cute.ComposedLayout, - kc_smem_layout_for_tma: cute.ComposedLayout, - vc_smem_layout_for_tma: cute.ComposedLayout, - cta_layout_vmnk: cute.Layout, - tile_sched_params: MLAStaticTileSchedulerParams, - SharedStorage: cutlass.Constexpr, - ): - """The device split_kv kernel implementation of the Multi-Head Latent Attention. - - This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: - 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA - 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) - 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results - to global memory - - The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. - When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results - that will later be combined by a reduction kernel. - - The kernel implements a complex pipeline with overlapping computation and memory operations, - using tensor memory access (TMA) for efficient data loading, warp specialization for different - computation phases. 
- - :param tiled_mma_qk: Tiled MMA for Q*K^T - :type tiled_mma_qk: cute.TiledMma - :param tiled_mma_pv: Tiled MMA for P*V - :type tiled_mma_pv: cute.TiledMma - :param tma_atom_q_latent: TMA copy atom for query latent tensor - :type tma_atom_q_latent: cute.CopyAtom - :param mQL: query latent tensor - :type mQL: cute.Tensor - :param tma_atom_q_rope: TMA copy atom for query rope tensor - :type tma_atom_q_rope: cute.CopyAtom - :param mKR: Compressed rope tensor - :type mKR: cute.Tensor - :param tma_atom_c_latent: TMA copy atom for c latent tensor - :type tma_atom_c_latent: cute.CopyAtom - :param mCL: Compressed latent tensor - :type mCL: cute.Tensor - :param tma_atom_c_rope: TMA copy atom for c rope tensor - :type tma_atom_c_rope: cute.CopyAtom - :param mCLT: Compressed latent transpose tensor - :type mCLT: cute.Tensor - :param mPT: Page table tensor - :type mPT: cute.Tensor - :param mO: Output tensor - :type mO: cute.Tensor - :param mLSE: Log-sum-exp tensor - :type mLSE: cute.Tensor - :param mAccO: Intermediate accumulator output tensor - :type mAccO: cute.Tensor - :param mAccLSE: Intermediate accumulator log-sum-exp tensor - :type mAccLSE: cute.Tensor - :param split_kv: The split_kv parameter - :type split_kv: cutlass.Int32 - :param cache_seqs: The variable sequence length tensor - :type cache_seqs: cute.Tensor - :param block_split_kvs: The per-block split_kv values tensor - :type block_split_kvs: cute.Tensor - :param softmax_scale_log2: The log2 scale factor for softmax - :type softmax_scale_log2: cutlass.Float32 - :param output_scale: The scale factor for the output - :type output_scale: cutlass.Float32 - :param q_latent_smem_layout_staged: Shared memory layout for query latent tensor - :type q_latent_smem_layout_staged: cute.ComposedLayout - :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor - :type q_rope_smem_layout_staged: cute.ComposedLayout - :param kc_smem_layout_staged: Shared memory layout for key/value latent/rope tensor - :type 
kc_smem_layout_staged: cute.ComposedLayout - :param p_smem_layout_staged: Shared memory layout for probability matrix - :type p_smem_layout_staged: cute.ComposedLayout - :param vc_smem_layout_staged: Shared memory layout for value tensor - :type vc_smem_layout_staged: cute.ComposedLayout - :param kc_smem_layout_for_tma: Shared memory layout for key/value latent tensor for TMA - :type kc_smem_layout_for_tma: cute.ComposedLayout - :param vc_smem_layout_for_tma: Shared memory layout for value tensor for TMA - :type vc_smem_layout_for_tma: cute.ComposedLayout - :param cta_layout_vmnk: Layout for compute threads - :type cta_layout_vmnk: cute.Layout - :param tile_sched_params: Scheduling parameters for work distribution - :type tile_sched_params: MLAStaticTileSchedulerParams - :param SharedStorage: Shared storage for the kernel - :type SharedStorage: cutlass.Constexpr - """ - - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) - is_leader_cta = mma_tile_coord_v == 0 - - # Prefetch tma descriptor - if warp_idx == self.mma_warp_id: - cpasync.prefetch_descriptor(tma_atom_q_latent) - cpasync.prefetch_descriptor(tma_atom_q_rope) - cpasync.prefetch_descriptor(tma_atom_c_latent) - cpasync.prefetch_descriptor(tma_atom_c_rope) - cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) - - # Alloc - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - # Tensor memory dealloc barrier init - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=self.tmem_ptr_sync_bar, - allocator_warp_id=self.mma_warp_id, - is_two_cta=self.use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, - ) - - load_q_pipeline = self.make_and_init_load_qkv_pipeline( - storage.load_q_mbar_ptr.data_ptr(), - cta_layout_vmnk, - self.load_q_stage, - self.tma_copy_q_bytes, - ) - load_kv_pipeline 
= self.make_and_init_load_qkv_pipeline( - storage.load_kv_mbar_ptr.data_ptr(), - cta_layout_vmnk, - self.load_kv_stage, - self.tma_copy_kc_bytes, - ) - mma_s_pipeline = self.make_and_init_mma_s_pipeline( - storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - p_mma_pipeline = self.make_and_init_p_mma_pipeline( - storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - p_cor_pipeline = self.make_and_init_p_cor_pipeline( - storage.p_cor_mbar_ptr.data_ptr() - ) - mma_o_pipeline = self.make_and_init_mma_o_pipeline( - storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - load_pt_pipeline = self.make_and_init_load_pt_pipeline( - storage.load_pt_mbar_ptr.data_ptr() - ) - - # Cluster arrive after barrier init - pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) - - # Generate smem tensor Q/KC/VC/exchange - # (MMA, MMA_H, MMA_R, PIPE) - sQ = storage.smem_q_latent.get_tensor( - q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner - ) - sQ_rope = storage.smem_q_rope.get_tensor( - q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner - ) - # (MMA, MMA_K, MMA_R, PIPE) - sKC = storage.smem_kc.get_tensor( - kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner - ) - sKC_for_tma = storage.smem_kc.get_tensor( - kc_smem_layout_for_tma.outer, - swizzle=kc_smem_layout_for_tma.inner, - ) - # (MMA, MMA_D, MMA_K, PIPE) - # reuse smem - sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner) - sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer) - sVC_for_tma = cute.make_tensor(sVC_ptr, vc_smem_layout_for_tma.outer) - # (MMA, MMA_H, MMA_K) - sP = storage.smem_p.get_tensor( - p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner - ) - sPT = storage.smem_page_table.get_tensor( - cute.make_layout((self.mma_qk_tiler[1] // 2, self.load_pt_stage)) - ) - # (compute_threads,) - softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( - 
cute.make_layout(self.num_compute_warps * self.threads_per_warp) - ) - epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( - cute.make_layout(self.num_compute_warps * self.threads_per_warp) - ) - - # - # Cluster wait before tensor memory alloc - # - pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) - - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_wait() - - # /////////////////////////////////////////////////////////////////////////////// - # Load warps, including page table and data tensors - # /////////////////////////////////////////////////////////////////////////////// - - if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: - _setmaxregister_decrease(self.other_reg_num) - if warp_idx == self.load_pt_warp_id: - _setmaxregister_decrease(self.other_reg_num) - load_pt_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_pt_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, - cache_seqs, - block_split_kvs, - blk_coord, - ) - if k_tile_count > 0: - load_pt_common_params = SimpleNamespace( - blk_coord=blk_coord, - load_pt_pipeline=load_pt_pipeline, - mPT=mPT, - sPT=sPT, - tidx=tidx, - page_size=mCL.shape[0], - ) - load_pt_producer_state = self.load_page_table( - load_pt_common_params, - k_index, - k_tile_count, - load_pt_producer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - load_pt_pipeline.producer_tail(load_pt_producer_state) - if warp_idx == self.load_tma_warp_id: - _setmaxregister_decrease(self.other_reg_num) - load_q_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_q_stage - ) - 
load_kv_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_kv_stage - ) - load_pt_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_pt_stage - ) - load_pt_release_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_pt_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, - cache_seqs, - block_split_kvs, - blk_coord, - ) - if k_tile_count > 0: - # Construct fixed common/tma_qk/tma_pv params for load_tma - tma_common_params = SimpleNamespace( - blk_coord=blk_coord, - local_split_kv=local_split_kv, - load_q_pipeline=load_q_pipeline, - load_kv_pipeline=load_kv_pipeline, - mPT=mPT, - sPT=sPT, - load_pt_pipeline=load_pt_pipeline, - ) - tma_qk_params = SimpleNamespace( - tiled_mma_qk=tiled_mma_qk, - tma_atom_q_latent=tma_atom_q_latent, - tma_atom_q_rope=tma_atom_q_rope, - tma_atom_c_latent=tma_atom_c_latent, - tma_atom_c_rope=tma_atom_c_rope, - mQL=mQL, - mQR=mQR, - mCL=mCL, - mKR=mKR, - sQ=sQ, - sQ_rope=sQ_rope, - sKC=sKC_for_tma, - ) - tma_pv_params = SimpleNamespace( - tiled_mma_pv=tiled_mma_pv, - tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, - mCL=mCL, - mKR=mKR, - mCLT=mCLT, - sVC=sVC_for_tma, - ) - # Load tma - ( - load_q_producer_state, - load_kv_producer_state, - load_pt_consumer_state, - load_pt_release_state, - ) = self.load_tma( - tma_common_params, - tma_qk_params, - tma_pv_params, - k_index, - k_tile_count, - load_q_producer_state, - load_kv_producer_state, - load_pt_consumer_state, - load_pt_release_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - load_q_pipeline.producer_tail(load_q_producer_state) - 
load_kv_pipeline.producer_tail(load_kv_producer_state) - - # /////////////////////////////////////////////////////////////////////////////// - # MMA warp - # /////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.mma_warp_id: - _setmaxregister_decrease(self.other_reg_num) - # Alloc tensor memory buffer - tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - load_q_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_q_stage - ) - load_kv_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_kv_stage - ) - mma_s_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_s_stage - ) - p_mma_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.p_mma_stage - ) - mma_o_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_o_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count > 0: - mma_common_params = SimpleNamespace( - blk_coord=blk_coord, - local_split_kv=local_split_kv, - load_q_pipeline=load_q_pipeline, - load_kv_pipeline=load_kv_pipeline, - tmem_ptr=tmem_ptr, - is_leader_cta=is_leader_cta, - L=mCL.shape[1], - ) - mma_qk_params = SimpleNamespace( - mma_s_pipeline=mma_s_pipeline, - sQ=sQ, - sQ_rope=sQ_rope, - sKC=sKC, - ) - mma_pv_params = SimpleNamespace( - p_mma_pipeline=p_mma_pipeline, - mma_o_pipeline=mma_o_pipeline, - sP=sP, - sVC=sVC, - ) - ( - tiled_mma_qk, - tiled_mma_pv, - load_q_consumer_state, - load_kv_consumer_state, - 
mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) = self.mma( - mma_common_params, - mma_qk_params, - mma_pv_params, - k_tile_count, - tiled_mma_qk, - tiled_mma_pv, - load_q_consumer_state, - load_kv_consumer_state, - mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - mma_s_pipeline.producer_tail(mma_s_producer_state) - mma_o_pipeline.producer_tail(mma_o_producer_state) - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_launch_dependents() - - # /////////////////////////////////////////////////////////////////////////////// - # Compute warp - # /////////////////////////////////////////////////////////////////////////////// - if ( - warp_idx >= self.compute_warp_ids[0] - and warp_idx <= self.compute_warp_ids[-1] - ): - _setmaxregister_increase(self.softmax_reg_num) - mma_s_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_s_stage - ) - p_mma_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.p_mma_stage - ) - p_cor_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.p_cor_stage - ) - mma_o_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_o_stage - ) - # sync with mma warp before retrieving tmem ptr - tmem.wait_for_alloc() - - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count > 0: - compute_common_params = SimpleNamespace( - 
blk_coord=blk_coord, - split_kv=split_kv, - local_split_kv=local_split_kv, - smem_exchange=softmax_smem_exchange, - mAccO=mAccO, - mO=mO, - K=cache_seqs[blk_coord[2]], - L=mCL.shape[1], - tmem_ptr=tmem_ptr, - tidx=tidx, - p_cor_pipeline=p_cor_pipeline, - ) - compute_softmax_params = SimpleNamespace( - tiled_mma_qk=tiled_mma_qk, - sP=sP, - mma_s_pipeline=mma_s_pipeline, - p_mma_pipeline=p_mma_pipeline, - softmax_scale_log2=softmax_scale_log2, - ) - mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( - self.compute( - compute_common_params, - compute_softmax_params, - k_index=k_index, - k_tile_count=k_tile_count, - mma_s_consumer_state=mma_s_consumer_state, - p_mma_producer_state=p_mma_producer_state, - p_cor_producer_state=p_cor_producer_state, - ) - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - p_cor_pipeline.producer_tail(p_cor_producer_state) - - # /////////////////////////////////////////////////////////////////////////////// - # Correction warp - # /////////////////////////////////////////////////////////////////////////////// - if ( - warp_idx >= self.correction_warp_ids[0] - and warp_idx <= self.correction_warp_ids[-1] - ): - _setmaxregister_increase(self.correction_reg_num) - p_cor_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.p_cor_stage - ) - mma_o_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_o_stage - ) - # sync with mma warp before retrieving tmem ptr - tmem.wait_for_alloc() - - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count > 0: - 
compute_common_params = SimpleNamespace( - blk_coord=blk_coord, - split_kv=split_kv, - local_split_kv=local_split_kv, - smem_exchange=epilogue_smem_exchange, - mAccO=mAccO, - mO=mO, - K=cache_seqs[blk_coord[2]], - L=mCL.shape[1], - H=mQL.shape[0], - tmem_ptr=tmem_ptr, - tidx=tidx, - tiled_mma_pv=tiled_mma_pv, - p_cor_pipeline=p_cor_pipeline, - mma_o_pipeline=mma_o_pipeline, - ) - compute_epilogue_params = SimpleNamespace( - output_scale=output_scale, - softmax_scale_log2=softmax_scale_log2, - mAccLSE=mAccLSE, - mLSE=mLSE, - ) - p_cor_consumer_state, mma_o_consumer_state = self.correction( - compute_common_params, - compute_epilogue_params, - k_tile_count=k_tile_count, - p_cor_consumer_state=p_cor_consumer_state, - mma_o_consumer_state=mma_o_consumer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - return - - @cute.kernel - def reduction_kernel( - self, - mO: cute.Tensor, - mLSE: cute.Tensor, - mAccO: cute.Tensor, - mAccLSE: cute.Tensor, - split_kv: cutlass.Int32, - cache_seqs: cute.Tensor, - block_split_kvs: cute.Tensor, - ): - """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results - from multiple split_kv blocks into final outputs. 
- - :param mO: Output tensor for storing final results - :type mO: cute.Tensor - :param mLSE: Log-sum-exp tensor for storing final LSE values - :type mLSE: cute.Tensor - :param mAccO: Accumulated output tensor from split_kv blocks - :type mAccO: cute.Tensor - :param mAccLSE: Accumulated LSE tensor from split_kv blocks - :type mAccLSE: cute.Tensor - :param split_kv: Number of split_kv blocks - :type split_kv: cutlass.Int32 - :param cache_seqs: Cache sequence lengths tensor - :type cache_seqs: cute.Tensor - :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) - :type block_split_kvs: cute.Tensor - """ - bidx, bidy, bidz = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - blk_coord = (bidx, bidy, bidz) - local_split_kv = ( - block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv - ) - k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) - k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) - local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) - - # Alloc shared memory - smem = utils.SmemAllocator() - storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) - lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) - smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) - - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_wait() - gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if warp_idx == 0: - # calculate the global lse and exp ^ (local_lse - global_lse) - lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) - - local_lse = cute.make_rmem_tensor( - cute.make_layout(lse_per_thread), self.lse_dtype - ) - lse_max = -self.lse_dtype.inf - # find the max lse - for i in cutlass.range_constexpr(lse_per_thread): - split_kv_idx = tidx + i * self.threads_per_warp - local_lse[i] = ( - gLSE[split_kv_idx] - if 
cute.elem_less(split_kv_idx, local_split_kv) - else -self.lse_dtype.inf - ) - # reduce the local lse - lse_max = cute.arch.fmax(lse_max, local_lse[i]) - lse_max = cute.arch.warp_reduction_max(lse_max) - lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 - # calculate sum_lse - sum_lse = 0.0 - for i in cutlass.range_constexpr(lse_per_thread): - sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) - sum_lse = cute.arch.warp_reduction_sum(sum_lse) - # calculate the global_lse - global_lse = ( - lse_max + cute.math.log2(sum_lse, fastmath=True) - if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 - else self.lse_dtype.inf - ) - if tidx == 0: - mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse - # store the scale to shared memory - for i in cutlass.range_constexpr(lse_per_thread): - split_kv_idx = tidx + i * self.threads_per_warp - if cute.elem_less(split_kv_idx, local_split_kv): - smem_lse_scale[split_kv_idx] = cute.math.exp2( - local_lse[i] - global_lse, fastmath=True - ) - - pipeline.sync(barrier_id=4) - - elements_per_thread = cute.ceil_div( - self.latent_dim, self.threads_per_warp * self.num_compute_warps - ) - gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] - rAccO = cute.make_rmem_tensor( - cute.make_layout(elements_per_thread), self.acc_dtype - ) - rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) - rAccO.fill(0.0) - for i in range(local_split_kv): - for j in cutlass.range_constexpr(elements_per_thread): - element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps - rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] - rO.store(rAccO.load().to(self.o_dtype)) - for j in cutlass.range_constexpr(elements_per_thread): - element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps - mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_launch_dependents() - 
        return

    @staticmethod
    def get_split_kv(
        B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int
    ) -> int:
        """Get the proper split_kv value for the MLA kernel based on parameters.

        :param B: Batch size
        :type B: int
        :param S: Sequence length
        :type S: int
        :param K: Key/value cache sequence length
        :type K: int
        :param mma_qk_tiler_mn: MLA tiling parameters
        :type mma_qk_tiler_mn: tuple
        :param max_active_blocks: Maximum number of active blocks
        :type max_active_blocks: int
        :return: Split_kv value
        :rtype: int
        """
        # Upper bound: one split per KV tile.
        max_splits = ceil_div(K, mma_qk_tiler_mn[1])
        # Heuristic: spare CTAs available per batch entry (2 CTAs per S unit).
        blocks_per_batch = max(1, max_active_blocks // B // (S * 2))
        split_heur = min(max_splits, blocks_per_batch)
        # {$nv-internal-release begin}
        # TODO: figure out the error of make_tile with dynamic int_tuple
        # {$nv-internal-release end}
        # Round so the splits divide the KV tiles into whole "waves".
        k_waves = ceil_div(max_splits, split_heur)
        split_wave_aware = ceil_div(max_splits, k_waves)
        max_split_kv = 32
        return min(split_wave_aware, max_split_kv)

    @staticmethod
    def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int:
        """Simplified split_kv heuristic that ignores the KV length (no wave rounding).

        :param B: Batch size
        :type B: int
        :param S: Sequence length
        :type S: int
        :param max_active_blocks: Maximum number of active blocks
        :type max_active_blocks: int
        :return: Split_kv value, capped at 32
        :rtype: int
        """
        blocks_per_batch = max(1, max_active_blocks // B // (S * 2))
        max_split_kv = 32
        return min(blocks_per_batch, max_split_kv)

    @cute.jit
    def get_k_tile_count(
        self,
        split_kv: cutlass.Int32,
        cache_seqs: cute.Tensor,
        block_split_kvs: cute.Tensor,
        blk_coord: cute.Coord,
    ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]:
        """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel.

        :param split_kv: Split_kv value
        :type split_kv: cutlass.Int32
        :param cache_seqs: Cache sequence lengths tensor
        :type cache_seqs: cute.Tensor
        :param block_split_kvs: Per-block split_kv values tensor
        :type block_split_kvs: cute.Tensor
        :param blk_coord: Block coordinate
        :type blk_coord: cute.Coord
        :return: k_index, k_tile_count, split_kv
        :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]
        """
        # blk_coord[2] indexes the batch entry; blk_coord[3] the split slot.
        K = cache_seqs[blk_coord[2]]
        if cutlass.const_expr(self.is_var_split_kv):
            # Variable split mode: each batch entry carries its own split count.
            split_kv = block_split_kvs[blk_coord[2]]

        k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1])
        # {$nv-internal-release begin}
        # TODO: figure out the error of make_tile with dynamic int_tuple
        # {$nv-internal-release end}
        k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv)
        k_index = blk_coord[3] * k_tile_per_cta
        # Clamp so trailing splits past the end of K get a zero tile count.
        k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index)
        return k_index, k_tile_count, split_kv

    @cute.jit
    def load_page_table(
        self,
        common_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_pt_producer_state: pipeline.PipelineState,
    ) -> pipeline.PipelineState:
        """Load warp to load page table. Updates the load pt producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_pt_producer_state: The load pt producer state
        :type load_pt_producer_state: pipeline.PipelineState

        :return: The load pt producer state
        :rtype: pipeline.PipelineState
        """
        # Page-table column for this batch entry.
        mPT = common_params.mPT[None, common_params.blk_coord[2]]
        page_per_tile = self.mma_qk_tiler[1] // self.page_size
        tidx = common_params.tidx % self.threads_per_warp

        load_pt_pipeline = common_params.load_pt_pipeline
        while k_tile_count > 0:
            load_pt_pipeline.producer_acquire(load_pt_producer_state)

            elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp)

            # atom_async_copy: async copy atom for page table load
            atom_async_copy = cute.make_copy_atom(
                cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
                cutlass.Int32,
                num_bits_per_copy=cutlass.Int32.width,
            )
            mPT_for_copy = cute.flat_divide(mPT, (1,))
            sPT_for_copy = cute.flat_divide(common_params.sPT, (1,))
            # elem_per_thread is a dynamic value depends on the page_size setting.
            for i in range(elem_per_thread):
                idx = i * self.threads_per_warp + tidx
                # Guard both against the end of the page table and the tile width.
                if cute.elem_less(
                    k_index * page_per_tile + idx, mPT.shape[0]
                ) and cute.elem_less(idx, page_per_tile):
                    cute.copy(
                        atom_async_copy,
                        mPT_for_copy[None, k_index * page_per_tile + idx],
                        sPT_for_copy[None, idx, load_pt_producer_state.index],
                    )
                else:
                    # Out-of-range slots are zero-filled so consumers read page 0.
                    sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0)
            mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state)  # noqa: F841
            load_pt_pipeline.producer_commit(load_pt_producer_state)
            load_pt_producer_state.advance()
            k_index += 1
            k_tile_count -= 1

        return load_pt_producer_state

    @cute.jit
    def load_tma(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_pt_consumer_state: pipeline.PipelineState,
        load_pt_release_state: pipeline.PipelineState,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Load warp to load Q/C latent/rope tensors. Updates the load qkv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param v_params: The v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_pt_consumer_state: The load pt consumer state
        :type load_pt_consumer_state: pipeline.PipelineState
        :param load_pt_release_state: The load pt release state
        :type load_pt_release_state: pipeline.PipelineState

        :return: The load q producer state, load kv producer state, load pt consumer state, and load pt release state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """
        # page table
        mPT = common_params.mPT[None, common_params.blk_coord[2]]

        # Flatten divide and partition global tensors for QK TMA load
        # (bM, bK, rM, rK, rL)
        mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2])
        gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk)
        mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2])
        gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope)

        thr_mma_qk = qk_params.tiled_mma_qk.get_slice(
            common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id)
        )
        tSgQL = thr_mma_qk.partition_A(gQL)
        tSgQR = thr_mma_qk.partition_A(gQR)

        # Per-CTA M extent, capped by the paging granularity.
        cta_m = min(
            qk_params.tiled_mma_qk.op.shape_mnk[0]
            // qk_params.tiled_mma_qk.thr_id.shape,
            self.page_size,
        )
        page_tile_size = min(self.page_size, cta_m)
        gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2]))
        # When a page spans both CTAs, each CTA takes its half of the page.
        tSgCL = (
            gCL[
                None,
                common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape,
                None,
                None,
            ]
            if cta_m < self.page_size
            else gCL[None, 0, None, None]
        )
        gKR = cute.tiled_divide(qk_params.mKR, (page_tile_size, self.mma_qk_tiler[2]))
        tSgKR = (
            gKR[
                None,
                common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape,
                None,
                None,
            ]
            if cta_m < self.page_size
            else gKR[None, 0, None, None]
        )

        # tma partition for q, k latent/rope
        # smem: ((atom_v, rest_v), STAGE)
        # gmem: ((atom_v, rest_v), RestM, RestK, RestL)
        tQsQ, tQLgQL_mkl = cpasync.tma_partition(
            qk_params.tma_atom_q_latent,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sQ, 0, 3),
            cute.group_modes(tSgQL, 0, 3),
        )

        tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition(
            qk_params.tma_atom_q_rope,
            0,
            cute.make_layout(1),
            cute.group_modes(qk_params.sQ_rope, 0, 3),
            cute.group_modes(tSgQR, 0, 3),
        )

        tKCsKC, tCLgCL = cpasync.tma_partition(
            qk_params.tma_atom_c_latent,
            0,
            cute.make_layout(1),
            qk_params.sKC,
            tSgCL,
        )

        # K rope shares the same smem staging as K latent (sKC).
        _, tKRgKR = cpasync.tma_partition(
            qk_params.tma_atom_c_rope,
            0,
            cute.make_layout(1),
            qk_params.sKC,
            tSgKR,
        )

        tQLgQL = tQLgQL_mkl[
            None, None, None, common_params.blk_coord[1], common_params.blk_coord[2]
        ]
        tQRgQR = tQRgQR_mkl[
            None, None, None, common_params.blk_coord[1], common_params.blk_coord[2]
        ]

        # Flatten divide and partition global tensors for V TMA load
        page_tile_size = min(self.page_size, self.mma_pv_tiler[2])
        gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size))
        cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape
        gCLT = cute.logical_divide(gCLT, (cta_n,))[
            (None, common_params.blk_coord[0]), None, None, None, None
        ]
        tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size))
        tOgCLT = tOgCLT[None, 0, 0, None, None, None]

        # tma partition for vc
        # smem: ((atom_v, rest_v), STAGE)
        # gmem: ((atom_v, rest_v), RestM, RestK, RestL)
        tVCsVC, tCLTgCLT = cpasync.tma_partition(
            v_params.tma_atom_c_latent_transpose,
            0,
            cute.make_layout(1),
            v_params.sVC,
            tOgCLT,
        )

        # set extra params
        common_params.mPT = mPT
        qk_params.tQLgQL = tQLgQL
        qk_params.tQRgQR = tQRgQR
        qk_params.tCLgCL = tCLgCL
        qk_params.tKRgKR = tKRgKR
        qk_params.tQsQ = tQsQ
        qk_params.tQsQ_rope = tQsQ_rope
        qk_params.tKCsKC = tKCsKC
        v_params.tCLTgCLT = tCLTgCLT
        v_params.tVCsVC = tVCsVC

        # First tile also loads Q; subsequent tiles reuse the staged Q.
        load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = (
            self.load_tma_qk_one_k_tile(
                common_params,
                qk_params,
                k_index,
                k_tile_count,
                load_q_producer_state,
                load_kv_producer_state,
                load_pt_consumer_state,
                load_q=True,
            )
        )
        k_index += 1
        k_tile_count -= 1
        while k_tile_count > 0:
            # {$nv-internal-release begin}
            # TODO: figure out how to support SingleNamespace/struct in ast
            # {$nv-internal-release end}
            load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = (
                self.load_tma_qk_one_k_tile(
                    common_params,
                    qk_params,
                    k_index,
                    k_tile_count,
                    load_q_producer_state,
                    load_kv_producer_state,
                    load_pt_consumer_state,
                    load_q=False,
                )
            )
            # V load trails the K load by one tile (k_index - 1).
            load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile(
                common_params,
                v_params,
                k_index - 1,
                load_kv_producer_state,
                load_pt_release_state,
            )
            k_index += 1
            k_tile_count -= 1

        # load last v tile
        load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile(
            common_params,
            v_params,
            k_index - 1,
            load_kv_producer_state,
            load_pt_release_state,
        )
        return (
            load_q_producer_state,
            load_kv_producer_state,
            load_pt_consumer_state,
            load_pt_release_state,
        )

    @cute.jit
    def load_tma_qk_one_k_tile(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        load_q_producer_state: pipeline.PipelineState,
        load_kv_producer_state: pipeline.PipelineState,
        load_pt_consumer_state: pipeline.PipelineState,
        load_q: bool,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]:
        """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param load_q_producer_state: The load q producer state
        :type load_q_producer_state: pipeline.PipelineState
        :param load_kv_producer_state: The load kv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_pt_consumer_state: The load pt consumer state
        :type load_pt_consumer_state: pipeline.PipelineState
        :param load_q: Whether to load q
        :type load_q: bool

        :return: The load q producer state, load kv producer state, and load pt consumer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """
        page_per_tile = ceil_div(
            self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape
        )
        common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state)
        page_table_stage = load_pt_consumer_state.index
        load_pt_consumer_state.advance()
        # Stage page indices for this tile in registers.
        k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32)
        for i in cutlass.range_constexpr(page_per_tile):
            k_idx[i] = (
                common_params.sPT[0, page_table_stage]
                if self.mma_qk_tiler[1] // self.page_size == 1
                else common_params.sPT[
                    i + common_params.blk_coord[0] * page_per_tile, page_table_stage
                ]
            )
        # load q once at first iteration
        if cutlass.const_expr(load_q):
            common_params.load_q_pipeline.producer_acquire(load_q_producer_state)
            # get the mbar ptr from pipeline.
            tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier(
                load_q_producer_state
            )
            for i in cutlass.range(self.iterations_qk_latent):
                # load q latent
                cute.copy(
                    qk_params.tma_atom_q_latent,
                    qk_params.tQLgQL[None, 0, i],
                    qk_params.tQsQ[None, (i, 0)],
                    tma_bar_ptr=tma_bar_ptr,
                )
            for i in cutlass.range(self.iterations_qk_rope):
                # load q rope
                cute.copy(
                    qk_params.tma_atom_q_rope,
                    qk_params.tQRgQR[None, 0, i],
                    qk_params.tQsQ_rope[None, i],
                    tma_bar_ptr=tma_bar_ptr,
                )
            load_q_producer_state.advance()
        load_kv_pipeline = common_params.load_kv_pipeline
        tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state)
        for i in cutlass.range(self.iterations_qk_latent):
            # get the mbar ptr from pipeline.
            tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state)
            load_kv_pipeline.producer_acquire(load_kv_producer_state)
            for k in cutlass.range(page_per_tile):
                # load k latent
                cute.copy(
                    qk_params.tma_atom_c_latent,
                    qk_params.tCLgCL[None, i, k_idx[k]],
                    qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            load_kv_producer_state.advance()

        for i in cutlass.range(self.iterations_qk_rope):
            # get the mbar ptr from pipeline.
            tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state)
            load_kv_pipeline.producer_acquire(load_kv_producer_state)
            for k in cutlass.range(page_per_tile):
                # load k rope
                cute.copy(
                    qk_params.tma_atom_c_rope,
                    qk_params.tKRgKR[None, i, k_idx[k]],
                    qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index],
                    tma_bar_ptr=tma_bar_ptr,
                )
            load_kv_producer_state.advance()

        return load_q_producer_state, load_kv_producer_state, load_pt_consumer_state

    @cute.jit
    def load_tma_v_one_k_tile(
        self,
        common_params: SimpleNamespace,
        v_params: SimpleNamespace,
        k_index: cutlass.Int32,
        load_kv_producer_state: pipeline.PipelineState,
        load_pt_release_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]:
        """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param v_params: The load tma v parameters
        :type v_params: SimpleNamespace
        :param k_index: The k index
        :type k_index: cutlass.Int32
        :param load_kv_producer_state: The load qkv producer state
        :type load_kv_producer_state: pipeline.PipelineState
        :param load_pt_release_state: The load pt release state
        :type load_pt_release_state: pipeline.PipelineState

        :return: The load kv producer state and load pt release state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState]
        """
        page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size
        page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k)
        # Re-read the page indices for this (already loaded) tile before
        # releasing the page-table stage back to the producer.
        k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32)
        page_table_stage = load_pt_release_state.index
        for i in cutlass.range(page_per_tile):
            k_idx[i] = (
                common_params.sPT[0, page_table_stage]
                if page_per_tile == 1
                else common_params.sPT[i, page_table_stage]
            )
        common_params.load_pt_pipeline.consumer_release(load_pt_release_state)
        load_pt_release_state.advance()
        load_kv_pipeline = common_params.load_kv_pipeline
        tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state)
        for i in cutlass.range(self.iterations_pv_k):
            for j in cutlass.range(self.iterations_pv_n):
                # get the mbar ptr from pipeline.
                tma_bar_ptr = load_kv_pipeline.producer_get_barrier(
                    load_kv_producer_state
                )
                load_kv_pipeline.producer_acquire(load_kv_producer_state)
                for k in cutlass.range(page_per_subtile):
                    k_idx_i = k_idx[
                        k
                        + i
                        // ceil_div(self.iterations_pv_k, page_per_tile)
                        * page_per_subtile
                    ]
                    cute.copy(
                        v_params.tma_atom_c_latent_transpose,
                        v_params.tCLTgCLT[
                            None,
                            j,
                            i % ceil_div(self.iterations_pv_k, page_per_tile),
                            k_idx_i,
                        ],
                        v_params.tVCsVC[None, 0, k, load_kv_producer_state.index],
                        tma_bar_ptr=tma_bar_ptr,
                    )

                load_kv_producer_state.advance()
        return load_kv_producer_state, load_pt_release_state

    @cute.jit
    def mma(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        pv_params: SimpleNamespace,
        k_tile_count: cutlass.Int32,
        tiled_mma_qk: cute.TiledMma,
        tiled_mma_pv: cute.TiledMma,
        load_q_consumer_state: pipeline.PipelineState,
        load_kv_consumer_state: pipeline.PipelineState,
        mma_s_producer_state: pipeline.PipelineState,
        p_mma_consumer_state: pipeline.PipelineState,
        mma_o_producer_state: pipeline.PipelineState,
    ) -> tuple[
        cute.TiledMma,
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states.

        :param common_params: The common parameters for mma qk and pv
        :type common_params: SimpleNamespace
        :param qk_params: The mma qk parameters
        :type qk_params: SimpleNamespace
        :param pv_params: The mma pv parameters
        :type pv_params: SimpleNamespace
        :param k_tile_count: The k tile count
        :type k_tile_count: cutlass.Int32
        :param tiled_mma_qk: The tiled mma qk
        :type tiled_mma_qk: cute.TiledMma
        :param tiled_mma_pv: The tiled mma pv
        :type tiled_mma_pv: cute.TiledMma
        :param load_q_consumer_state: The load q consumer state
        :type load_q_consumer_state: pipeline.PipelineState
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param mma_s_producer_state: The mma s producer state
        :type mma_s_producer_state: pipeline.PipelineState
        :param p_mma_consumer_state: The p mma consumer state
        :type p_mma_consumer_state: pipeline.PipelineState
        :param mma_o_producer_state: The mma o producer state
        :type mma_o_producer_state: pipeline.PipelineState

        :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state
        :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ)
        tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope)
        tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC)
        tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP)
        tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC)

        tStS_shape = tiled_mma_qk.partition_shape_C(
            cute.select(self.mma_qk_tiler, mode=[0, 1])
        )
        tStS_staged_fake = tiled_mma_qk.make_fragment_C(
            cute.append(tStS_shape, self.mma_s_stage)
        )
        # use real tmem ptr for tStS
        tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout)
        tOtO_shape = tiled_mma_pv.partition_shape_C(
            cute.select(self.mma_pv_tiler, mode=[0, 1])
        )
        # mma O has 1 stage.
        tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape)
        tOtO_layout = cute.append(
            tOtO.layout,
            cute.make_layout(
                common_params.L // self.mma_pv_tiler[1],
                stride=self.mma_pv_tiler[1] // self.warps_in_n,
            ),
        )
        # O accumulator lives at a fixed offset past the S stages in tmem.
        tOtO_staged = cute.make_tensor(
            tStS_staged.iterator + self.tmem_o_offset, tOtO_layout
        )

        # set more parameters
        qk_params.tSrQ = tSrQ
        qk_params.tSrQ_rope = tSrQ_rope
        qk_params.tSrKC = tSrKC
        qk_params.tStS_staged = tStS_staged
        pv_params.tOrP = tOrP
        pv_params.tOrVC = tOrVC
        pv_params.tOtO_staged = tOtO_staged

        # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks.
        tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False)
        load_q_pipeline = common_params.load_q_pipeline
        # NOTE(review): load_q_release_state is only bound when is_leader_cta
        # holds, yet it is released unconditionally below — presumably only the
        # leader CTA runs this release path; confirm against the kernel launch.
        if common_params.is_leader_cta:
            load_q_release_state = load_q_consumer_state.clone()

        (
            tiled_mma_qk,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
        ) = self.mma_qk(
            common_params,
            qk_params,
            tiled_mma_qk,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
            wait_q=True,
        )
        k_tile_count -= 1
        while k_tile_count > 0:
            (
                tiled_mma_qk,
                load_q_consumer_state,
                load_kv_consumer_state,
                mma_s_producer_state,
            ) = self.mma_qk(
                common_params,
                qk_params,
                tiled_mma_qk,
                load_q_consumer_state,
                load_kv_consumer_state,
                mma_s_producer_state,
                wait_q=False,
            )
            (
                tiled_mma_pv,
                load_kv_consumer_state,
                p_mma_consumer_state,
                mma_o_producer_state,
            ) = self.mma_pv(
                common_params,
                pv_params,
                tiled_mma_pv,
                load_kv_consumer_state,
                p_mma_consumer_state,
                mma_o_producer_state,
            )
            k_tile_count -= 1

        # release q consumer states
        load_q_pipeline.consumer_release(load_q_release_state)
        load_q_release_state.advance()
        # Drain the final P*V tile.
        (
            tiled_mma_pv,
            load_kv_consumer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        ) = self.mma_pv(
            common_params,
            pv_params,
            tiled_mma_pv,
            load_kv_consumer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        )

        return (  # type: ignore[return-value]
            tiled_mma_qk,
            tiled_mma_pv,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        )

    @cute.jit
    def mma_qk(
        self,
        common_params: SimpleNamespace,
        qk_params: SimpleNamespace,
        tiled_mma_qk: cute.TiledMma,
        load_q_consumer_state: pipeline.PipelineState,
        load_kv_consumer_state: pipeline.PipelineState,
        mma_s_producer_state: pipeline.PipelineState,
        wait_q: bool,
    ) -> tuple[
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param qk_params: The qk parameters
        :type qk_params: SimpleNamespace
        :param tiled_mma_qk: The tiled mma qk
        :type tiled_mma_qk: cute.TiledMma
        :param load_q_consumer_state: The load q consumer state
        :type load_q_consumer_state: pipeline.PipelineState
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param mma_s_producer_state: The mma s producer state
        :type mma_s_producer_state: pipeline.PipelineState
        :param wait_q: Whether to wait on the load q pipeline (first tile only)
        :type wait_q: bool

        :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state
        :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """
        tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index]

        qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state)
        # First gemm of the tile overwrites S; later gemms accumulate.
        tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False)
        load_q_pipeline = common_params.load_q_pipeline
        load_kv_pipeline = common_params.load_kv_pipeline
        if cutlass.const_expr(wait_q):
            load_q_pipeline.consumer_wait(load_q_consumer_state)
            load_q_consumer_state.advance()
        for q_stage in range(self.iterations_qk_latent):
            load_kv_pipeline.consumer_wait(load_kv_consumer_state)
            kc_stage = load_kv_consumer_state.index
            for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])):
                cute.gemm(
                    tiled_mma_qk,
                    tStS,
                    qk_params.tSrQ[None, None, k_block, q_stage],
                    qk_params.tSrKC[None, None, k_block, kc_stage],
                    tStS,
                )
                tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True)
            load_kv_pipeline.consumer_release(load_kv_consumer_state)
            load_kv_consumer_state.advance()
        for q_stage in range(self.iterations_qk_rope):
            load_kv_pipeline.consumer_wait(load_kv_consumer_state)
            kc_stage = load_kv_consumer_state.index
            for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]):
                cute.gemm(
                    tiled_mma_qk,
                    tStS,
                    qk_params.tSrQ_rope[None, None, k_block, q_stage],
                    qk_params.tSrKC[None, None, k_block, kc_stage],
                    tStS,
                )
                tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True)
            load_kv_pipeline.consumer_release(load_kv_consumer_state)
            load_kv_consumer_state.advance()

        qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state)
        mma_s_producer_state.advance()
        return (
            tiled_mma_qk,
            load_q_consumer_state,
            load_kv_consumer_state,
            mma_s_producer_state,
        )

    @cute.jit
    def mma_pv(
        self,
        common_params: SimpleNamespace,
        pv_params: SimpleNamespace,
        tiled_mma_pv: cute.TiledMma,
        load_kv_consumer_state: pipeline.PipelineState,
        p_mma_consumer_state: pipeline.PipelineState,
        mma_o_producer_state: pipeline.PipelineState,
    ) -> tuple[
        cute.TiledMma,
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
    ]:
        """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param pv_params: The pv parameters
        :type pv_params: SimpleNamespace
        :param tiled_mma_pv: The tiled mma pv
        :type tiled_mma_pv: cute.TiledMma
        :param load_kv_consumer_state: The load kv consumer state
        :type load_kv_consumer_state: pipeline.PipelineState
        :param p_mma_consumer_state: The P MMA consumer state
        :type p_mma_consumer_state: pipeline.PipelineState
        :param mma_o_producer_state: The MMA o producer state
        :type mma_o_producer_state: pipeline.PipelineState

        :return: The tiled mma pv, the load qkv consumer state, the P MMA consumer state, and the MMA o producer state
        :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state)
        pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state)
        load_kv_pipeline = common_params.load_kv_pipeline
        for p_stage in range(self.iterations_pv_k):
            # Remember the incoming accumulate flag so every acc_stage of this
            # p_stage starts from the same accumulate/overwrite decision.
            accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE)
            for acc_stage in range(self.iterations_pv_n):
                load_kv_pipeline.consumer_wait(load_kv_consumer_state)
                tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag)
                vc_stage = load_kv_consumer_state.index
                tOtO = pv_params.tOtO_staged[None, None, None, acc_stage]
                for k_block in cutlass.range(pv_params.tOrP.shape[2]):
                    cute.gemm(
                        tiled_mma_pv,
                        tOtO,
                        pv_params.tOrP[
                            None,
                            None,
                            k_block,
                            (p_stage, p_mma_consumer_state.index),
                        ],
                        pv_params.tOrVC[None, None, k_block, vc_stage],
                        tOtO,
                    )
                    tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True)
                load_kv_pipeline.consumer_release(load_kv_consumer_state)
                load_kv_consumer_state.advance()
        pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state)
        p_mma_consumer_state.advance()
        pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state)
        mma_o_producer_state.advance()

        return (
            tiled_mma_pv,
            load_kv_consumer_state,
            p_mma_consumer_state,
            mma_o_producer_state,
        )

    @cute.jit
    def compute(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        k_index: cutlass.Int32,
        k_tile_count: cutlass.Int32,
        mma_s_consumer_state: pipeline.PipelineState,
        p_mma_producer_state: pipeline.PipelineState,
        p_cor_producer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]:
        """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param softmax_params: The softmax parameters
        :type softmax_params: SimpleNamespace
        :param k_index: The index of the k-tile
        :type k_index: cutlass.Int32
        :param k_tile_count: The number of k-tiles
        :type k_tile_count: cutlass.Int32
        :param mma_s_consumer_state: The MMA s consumer state
        :type mma_s_consumer_state: pipeline.PipelineState
        :param p_mma_producer_state: The P MMA producer state
        :type p_mma_producer_state: pipeline.PipelineState
        :param p_cor_producer_state: The P correction producer state
        :type p_cor_producer_state: pipeline.PipelineState

        :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]
        """

        k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1])

        # Running softmax statistics for this row block.
        row_max = -self.acc_dtype.inf
        row_sum = self.acc_dtype(0)
        correction_factor = self.acc_dtype(1)
        common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state)

        # no mask applied
        while k_tile_count > 1:
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax(
                common_params,
                softmax_params,
                k_index,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
                False,
                False,
            )
            k_index = k_index + 1
            k_tile_count = k_tile_count - 1

        # mask applied
        if cutlass.const_expr(common_params.mAccO is not None):
            # Split-KV path: masking only needed when this split owns the
            # globally last k-tile.
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax(
                common_params,
                softmax_params,
                k_index,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
                k_index == k_tile_total - 1,
                True,
            )
        else:
            (
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
            ) = self.softmax(
                common_params,
                softmax_params,
                k_index,
                mma_s_consumer_state,
                p_mma_producer_state,
                p_cor_producer_state,
                row_max,
                row_sum,
                correction_factor,
                True,
                True,
            )

        return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state

    @cute.jit
    def correction(
        self,
        common_params: SimpleNamespace,
        epilogue_params: SimpleNamespace,
        k_tile_count: cutlass.Int32,
        p_cor_consumer_state: pipeline.PipelineState,
        mma_o_consumer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]:
        """Correction warp to rescale the O accumulator and run the epilogue. Updates the related pipeline states.

        :param common_params: The common parameters
        :type common_params: SimpleNamespace
        :param epilogue_params: The epilogue parameters
        :type epilogue_params: SimpleNamespace
        :param k_tile_count: The number of k-tiles
        :type k_tile_count: cutlass.Int32
        :param p_cor_consumer_state: The P correction consumer state
        :type p_cor_consumer_state: pipeline.PipelineState
        :param mma_o_consumer_state: The MMA o consumer state
        :type mma_o_consumer_state: pipeline.PipelineState

        :return: The P correction consumer state, and the MMA o consumer state
        :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState]
        """

        k_tile_count_init = k_tile_count
        while k_tile_count > 0:
            p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = (
                self.get_correction_factor(common_params, p_cor_consumer_state)
            )
            # Skip the rescale for the very first tile (no prior accumulator).
            if k_tile_count_init != k_tile_count:
                mma_o_consumer_state = self.rescale(
                    common_params,
                    mma_o_consumer_state,
                    correction_factor,
                    no_correction,
                )
            k_tile_count = k_tile_count - 1
            if k_tile_count == 0:
                mma_o_consumer_state = self.epilogue(
                    common_params,
                    epilogue_params,
                    mma_o_consumer_state,
                    row_sum,
                    row_max,
                )

        return p_cor_consumer_state, mma_o_consumer_state

    @cute.jit
    def exchange_p_cor_metadata(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        correction_factor: cutlass.Float32,
        row_sum: cutlass.Float32,
        row_max: cutlass.Float32,
        row_max_new: cutlass.Float32,
        tAcc: cute.Tensor,
        tidx: cutlass.Int32,
        p_cor_producer_state: pipeline.PipelineState,
    ) -> tuple[pipeline.PipelineState, cutlass.Float32]:
        """Publish softmax metadata (row_sum, row_max, correction factor, skip flag) to tmem for the correction warp.

        :return: The advanced P correction producer state and the (possibly clamped) new row max
        :rtype: tuple[pipeline.PipelineState, cutlass.Float32]
        """
        no_correction = 0
        # If the max barely moved, skip the O rescale and keep the old max.
        if (
            row_max_new - row_max
        ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold:
            no_correction = 1
            row_max_new = row_max

        # pad for 4x32b
        corr_layout = cute.make_layout(
            (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage),
            stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4),
        )
        tCor = cute.make_tensor(
            common_params.tmem_ptr + self.correction_factor_offset,
            corr_layout,
        )
        cCor = cute.make_identity_tensor(tCor.shape)
        corr_tmem_store_atom = cute.make_copy_atom(
            tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype
        )
        corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor)
        corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx)
        cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor)
        tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor)
        rCor = cute.make_fragment_like(
            cCor_for_copy[None, None, None, 0], self.acc_dtype
        )
        # Reinterpret the fragment as Int32 so the flag can be stored exactly.
        rCor_int = cute.make_tensor(
            cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout
        )
        rCor[0] = row_sum
        rCor[1] = row_max_new
        rCor[2] = correction_factor
        rCor_int[3] = no_correction

        cute.copy(
            corr_tmem_store_tiled_copy,
            rCor,
            tCor_for_copy[None, None, None, p_cor_producer_state.index],
        )
        # fence between tmem store and correction warp
        cute.arch.fence_view_async_tmem_store()
        common_params.p_cor_pipeline.producer_commit(p_cor_producer_state)
        p_cor_producer_state.advance()
        return p_cor_producer_state, row_max_new

    @cute.jit
    def softmax(
        self,
        common_params: SimpleNamespace,
        softmax_params: SimpleNamespace,
        k_index: cutlass.Int32,
        mma_s_consumer_state: pipeline.PipelineState,
        p_mma_producer_state: pipeline.PipelineState,
        p_cor_producer_state: pipeline.PipelineState,
        row_max: cutlass.Float32,
        row_sum: cutlass.Float32,
        correction_factor: cutlass.Float32,
        is_last_tile: bool,
        is_local_last_tile: cutlass.Boolean,
    ) -> tuple[
        pipeline.PipelineState,
        pipeline.PipelineState,
        pipeline.PipelineState,
        cutlass.Float32,
        cutlass.Float32,
        cutlass.Float32,
    ]:
        """Softmax for one k-tile. Updates the related pipeline states and returns the computed results.
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param softmax_params: The softmax parameters - :type softmax_params: SimpleNamespace - :param k_index: The index of the k-tile - :type k_index: cutlass.Int32 - :param mma_s_consumer_state: The MMA s consumer state - :type mma_s_consumer_state: pipeline.PipelineState - :param p_mma_producer_state: The P MMA producer state - :type p_mma_producer_state: pipeline.PipelineState - :param p_cor_producer_state: The P correction producer state - :type p_cor_producer_state: pipeline.PipelineState - :param row_max: The row max - :type row_max: cutlass.Float32 - :param row_sum: The row sum - :type row_sum: cutlass.Float32 - :param correction_factor: The correction factor - :type correction_factor: cutlass.Float32 - :param is_last_tile: Whether the last tile - :type is_last_tile: bool - :param is_local_last_tile: Whether the last tile is local - :type is_local_last_tile: cutlass.Boolean - - :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor - :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] - """ - - softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) - softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) - - # load S from tmem - tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( - cute.select(self.mma_qk_tiler, mode=[0, 1]) - ) - tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( - cute.append(tStS_shape, self.mma_s_stage) - ) - tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) - tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] - - tAcc = tStS[(None, None), 0, 0] - cta_qk_tiler = ( - self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], - self.mma_qk_tiler[1], - self.mma_qk_tiler[2], - ) - cS = 
cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) - - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - - tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) - tTR_tAcc = tmem_thr_copy.partition_S(tAcc) - tTR_tS = tmem_thr_copy.partition_D(cS) - - tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) - - row_max_new = row_max - arch = BaseDSL._get_dsl().get_arch_enum() - if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): - cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) - for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - if is_last_tile: - tTR_rAcc[i] = ( - tTR_rAcc[i] - if cute.elem_less( - tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, - ) - else -self.acc_dtype.inf - ) - # reduction for row_max - row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) - - elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): - tmem_load_red_atom = cute.make_copy_atom( - tcgen05.copy.LdRed32x32bOp( - tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX - ), - self.acc_dtype, - ) - tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) - tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) - tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) - tTR_tS_red = tmem_red_thr_copy.partition_D(cS) - tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) - tTR_rMax = cute.make_rmem_tensor( - cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), - self.acc_dtype, - ) - cute.copy( - tmem_red_tiled_copy, - tTR_tAcc_red, - (tTR_rAcc_red, tTR_rMax), - ) - tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) - if is_last_tile: - for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - tTR_rAcc[i] = ( - tTR_rAcc[i] - if 
cute.elem_less( - tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, - ) - else -self.acc_dtype.inf - ) - # reduction for row_max - row_max_new = tTR_rAcc.load().reduce( - cute.ReductionOp.MAX, row_max_new, 0 - ) - else: - row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) - - # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) - if cutlass.const_expr(self.warps_in_n == 2): - common_params.smem_exchange[tidx] = row_max_new - self.softmax_exchange_sync_bar.wait() - row_max_new = cute.arch.fmax( - row_max_new, - common_params.smem_exchange[ - (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) - ], - ) - - # find correction factor - correction_factor = cute.math.exp2( - (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True - ) - # split kv case - if cutlass.const_expr(not is_local_last_tile): - p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( - common_params, - softmax_params, - correction_factor, - row_sum, - row_max, - row_max_new, - tAcc, - tidx, - p_cor_producer_state, - ) - - # softmax - fma_b = softmax_params.softmax_scale_log2 - fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 - - for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): - tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c - tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) - - tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) - - # quantize - tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) - - # create sP - sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] - sP_mk_view = cute.make_tensor( - sP.iterator, - cute.make_layout( - ( - (sP.shape[0][0], sP.shape[1]), - (sP.shape[0][1], sP.shape[2], sP.shape[3]), - ), - stride=( - (sP.stride[0][0], sP.stride[1]), - (sP.stride[0][1], sP.stride[2], sP.stride[3]), - ), - ), - ) - # {$nv-internal-release begin} - # TODO: figure out if we could use A tmem for pv. 
- # {$nv-internal-release end} - # change to PISL - sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) - swizzle_bits = ( - int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 - ) - swizzle_base = 3 if self.q_dtype.width == 16 else 4 - sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) - sP_mk_view = cute.make_tensor( - sP_wo_swizzle_iter, - cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), - ) - universal_copy_bits = 128 - smem_copy_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=universal_copy_bits, - ) - smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) - smem_thr_copy = smem_tiled_copy.get_slice(tidx) - rP_copy_view = smem_thr_copy.retile(tTR_rS) - sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) - cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) - - # fence between smem store and mma o - cute.arch.fence_view_async_shared() - softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) - p_mma_producer_state.advance() - - # row_sum, using `add_packed_f32x2` to reduce the number of instructions - row_sum = row_sum * correction_factor - row_sum_vec = (0.0, 0.0) - for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): - row_sum_vec = cute.arch.add_packed_f32x2( - row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) - ) - row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum - - # split kv case - if cutlass.const_expr(is_local_last_tile): - p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( - common_params, - softmax_params, - correction_factor, - row_sum, - row_max, - row_max_new, - tAcc, - tidx, - p_cor_producer_state, - ) - - # store correction factor/row_sum/row_max to tmem for correction warp - common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - - # fence between tmem load and mma s - cute.arch.fence_view_async_tmem_load() - - 
softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) - mma_s_consumer_state.advance() - - return ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max_new, - row_sum, - correction_factor, - ) - - @cute.jit - def _tmem_load_partition( - self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int - ) -> tuple[ - cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma - ]: - """Tensor memory load partition for rescale and epilogue. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param tiled_mma_pv: The tiled mma pv - :type tiled_mma_pv: cute.TiledMma - :param iter_n: The iteration number - :type iter_n: int - - :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv - :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] - """ - - tOtO_shape = tiled_mma_pv.partition_shape_C( - cute.select(self.mma_pv_tiler, mode=[0, 1]) - ) - tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) - tOtO_layout = cute.append( - tOtO.layout, - cute.make_layout( - common_params.L // self.mma_pv_tiler[1], - stride=self.mma_pv_tiler[1] // self.warps_in_n, - ), - ) - tOtO = cute.make_tensor( - common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout - ) - tOtO = tOtO[None, None, None, iter_n] - - tAcc = tOtO[(None, None), 0, 0] - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) - # {$nv-internal-release begin} - # TODO: supports size() on tiled copy. 
- # {$nv-internal-release end} - tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( - common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - ) - - cta_pv_tiler = ( - self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], - self.mma_pv_tiler[1], - self.mma_pv_tiler[2], - ) - # Flatten divide and partition global tensors for O - cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) - - gO = None - if cutlass.const_expr(common_params.mAccO is not None): - gO = cute.local_tile( - common_params.mAccO[None, common_params.blk_coord[3], None, None, None], - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - cO = cute.local_tile( - cute.make_identity_tensor( - common_params.mAccO[ - None, common_params.blk_coord[3], None, None, None - ].shape - ), - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - else: - gO = cute.local_tile( - common_params.mO, - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - cO = cute.local_tile( - cute.make_identity_tensor(common_params.mO.shape), - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) - tTR_gO = tmem_load_thr_copy.partition_D(gO) - tTR_cO = tmem_load_thr_copy.partition_D(cO) - tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) - return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] - - def get_correction_factor( - self, - common_params: SimpleNamespace, - p_cor_consumer_state: pipeline.PipelineState, - ) -> tuple[ - pipeline.PipelineState, - cutlass.Float32, - cutlass.Float32, - cutlass.Float32, - cutlass.Int32, - ]: - """Get the correction factor from the P correction consumer state. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param p_cor_consumer_state: The P correction consumer state - :type p_cor_consumer_state: pipeline.PipelineState - - :return: The P correction consumer state, the row_sum, the row_max, and the correction factor - :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] - """ - common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - # load correction factor - _, tAcc, _, _, _, _ = self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, 0 - ) - corr_layout = cute.make_layout( - (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), - stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), - ) - tCor = cute.make_tensor( - common_params.tmem_ptr + self.correction_factor_offset, corr_layout - ) - cCor = cute.make_identity_tensor(tCor.shape) - corr_tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype - ) - corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) - corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) - tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) - cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) - rCor = cute.make_fragment_like( - cCor_for_copy[None, None, None, 0], self.acc_dtype - ) - rCor_int = cute.make_tensor( - cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout - ) - cute.copy( - corr_tmem_load_tiled_copy, - tCor_for_copy[None, None, None, p_cor_consumer_state.index], - rCor, - ) - row_sum = rCor[0] - row_max = rCor[1] - correction_factor = rCor[2] - no_correction = rCor_int[3] - - common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) - p_cor_consumer_state.advance() - return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction - - @cute.jit - def 
rescale( - self, - common_params: SimpleNamespace, - mma_o_consumer_state: pipeline.PipelineState, - correction_factor: cutlass.Float32, - no_correction: cutlass.Int32, - ) -> pipeline.PipelineState: - """Rescale for one k-tile. Updates the related pipeline state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param mma_o_consumer_state: The mma o consumer state - :type mma_o_consumer_state: pipeline.PipelineState - :param correction_factor: The correction factor - :type correction_factor: cutlass.Float32 - :param no_correction: Whether to apply correction factor - :type no_correction: cutlass.Int32 - - :return: The MMA o consumer state - :rtype: pipeline.PipelineState - """ - skip_correction = cute.arch.vote_all_sync(no_correction == 1) - common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) - if not skip_correction: - for iter_n in cutlass.range_constexpr(self.iterations_pv_n): - # tmem load tiled copy and partition results. - tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( - self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, iter_n - ) - ) - - # tmem store tiled copy - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) - - # load o - cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) - # rescale, using `mul_packed_f32x2` to reduce the number of instructions - for i in cutlass.range( - cute.size(tTR_rAcc), vectorize=True, unroll_full=True - ): - tTR_rAcc[i] = tTR_rAcc[i] * correction_factor - - # store o to tensor memory for next k tile - cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) - - cute.arch.fence_view_async_tmem_store() - common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) - mma_o_consumer_state.advance() - - return mma_o_consumer_state - - @cute.jit - def epilogue( - self, - common_params: SimpleNamespace, 
- epilogue_params: SimpleNamespace, - mma_o_consumer_state: pipeline.PipelineState, - row_sum: cutlass.Float32, - row_max: cutlass.Float32, - ) -> pipeline.PipelineState: - """Epilogue for one k-tile. Updates the related pipeline state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param epilogue_params: The epilogue parameters - :type epilogue_params: SimpleNamespace - :param mma_o_consumer_state: The mma o consumer state - :type mma_o_consumer_state: pipeline.PipelineState - :param row_sum: The row sum - :type row_sum: cutlass.Float32 - :param row_max: The row max - :type row_max: cutlass.Float32 - - :return: The MMA o consumer state - :rtype: pipeline.PipelineState - """ - - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - - # exchange row_sum between warps (0, 1) and (2, 3) - if cutlass.const_expr(self.warps_in_n == 2): - common_params.smem_exchange[tidx] = row_sum - self.epilogue_exchange_sync_bar.wait() - # (64, 2) - row_sum = ( - row_sum - + common_params.smem_exchange[ - (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) - ] - ) - # mma_o pipeline consumer wait - common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) - for iter_n in cutlass.range_constexpr(self.iterations_pv_n): - # tmem load tiled copy and partition results. 
- tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( - self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, iter_n - ) - ) - - # load o - cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) - - # apply output scale and normalize by row_sum - for i in cutlass.range( - cute.size(tTR_rAcc), vectorize=True, unroll_full=True - ): - tTR_rAcc[i] = ( - tTR_rAcc[i] - * epilogue_params.output_scale - * cute.arch.rcp_approx(row_sum) - ) - - # store o to global memory - tR2G_rO_src = None - tR2G_rO_dst = tTR_gO - if cutlass.const_expr(common_params.mAccO is None): - tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) - # using final output dtype for o - tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) - else: - # using accumulate dtype for o - tR2G_rO_src = tTR_rAcc - - if cute.elem_less(tTR_cO[0][0], common_params.H): - cute.autovec_copy( - tR2G_rO_src, - tR2G_rO_dst, - l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, - ) - - # store the lse to global memory - cta_pv_tiler = ( - self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], - self.mma_pv_tiler[1], - self.mma_pv_tiler[2], - ) - gLSE = None - cLSE = None - if cutlass.const_expr(epilogue_params.mAccLSE is None): - gLSE = cute.local_tile( - epilogue_params.mLSE, - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - cLSE = cute.local_tile( - cute.make_identity_tensor(epilogue_params.mLSE.shape), - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - - else: - gLSE = cute.local_tile( - epilogue_params.mAccLSE[ - None, common_params.blk_coord[3], None, None - ], - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - cLSE = cute.local_tile( - cute.make_identity_tensor( - epilogue_params.mAccLSE[ - 
None, common_params.blk_coord[3], None, None - ].shape - ), - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - lse = ( - cute.math.log2(row_sum, fastmath=True) - + epilogue_params.softmax_scale_log2 * row_max - ) - if cutlass.const_expr(self.warps_in_n == 2): - if cute.elem_less(cLSE[tidx][0], common_params.H): - gLSE[tidx] = lse - - cute.arch.fence_view_async_tmem_load() - common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) - mma_o_consumer_state.advance() - - return mma_o_consumer_state - - def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr): - """Create and initialize the load page table pipeline. - - :param load_pt_mbar_ptr: The load page table mbar pointer - :type load_pt_mbar_ptr: cute.Tensor - - :return: The load page table pipeline - :rtype: pipeline.PipelineAsync - """ - load_pt_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len([self.load_pt_warp_id]), - ) - load_pt_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - self.threads_per_warp * len([self.load_tma_warp_id]), - ) - return pipeline.PipelineCpAsync.create( - barrier_storage=load_pt_mbar_ptr, - num_stages=self.load_pt_stage, - producer_group=load_pt_producer_group, - consumer_group=load_pt_consumer_group, - defer_sync=True, - ) - - def make_and_init_load_qkv_pipeline( - self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count - ) -> pipeline.PipelineTmaUmma: - """Create and initialize the tma load qkv pipeline. 
- - :param load_qkv_mbar_ptr: The load qkv mbar pointer - :type load_qkv_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - :param load_stages: The load stages - :type load_stages: list[int] - :param tx_count: The tx count - :type tx_count: int - - :return: The tma load qkv pipeline - :rtype: pipeline.PipelineTmaUmma - """ - load_qkv_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_tma_warp_id]) - ) - load_qkv_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineTmaUmma.create( - barrier_storage=load_qkv_mbar_ptr, - num_stages=load_stages, - producer_group=load_qkv_producer_group, - consumer_group=load_qkv_consumer_group, - tx_count=tx_count, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_mma_s_pipeline( - self, mma_s_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineUmmaAsync: - """Create and initialize the mma s pipeline. 
- - :param mma_s_mbar_ptr: The mma s mbar pointer - :type mma_s_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The mma s pipeline - :rtype: pipeline.PipelineUmmaAsync - """ - - mma_s_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - consumer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - mma_s_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - consumer_thread_size, - ) - return pipeline.PipelineUmmaAsync.create( - barrier_storage=mma_s_mbar_ptr, - num_stages=self.mma_s_stage, - producer_group=mma_s_producer_group, - consumer_group=mma_s_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_p_mma_pipeline( - self, p_mma_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineAsyncUmma: - """Create and initialize the p mma pipeline. - - :param p_mma_mbar_ptr: The p mma mbar pointer - :type p_mma_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The p mma pipeline - :rtype: pipeline.PipelineAsyncUmma - """ - - producer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - p_mma_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - p_mma_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineAsyncUmma.create( - barrier_storage=p_mma_mbar_ptr, - num_stages=self.p_mma_stage, - producer_group=p_mma_producer_group, - consumer_group=p_mma_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_p_cor_pipeline( - self, p_cor_mbar_ptr - ) -> pipeline.PipelineAsyncUmma: - """Create and initialize the p correction pipeline. 
- - :param p_cor_mbar_ptr: The p correction mbar pointer - :type p_cor_mbar_ptr: cute.Tensor - - :return: The p correction pipeline - :rtype: pipeline.PipelineAsyncUmma - """ - - producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) - p_cor_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - p_cor_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - return pipeline.PipelineAsync.create( - barrier_storage=p_cor_mbar_ptr, - num_stages=self.p_cor_stage, - producer_group=p_cor_producer_group, - consumer_group=p_cor_consumer_group, - defer_sync=True, - ) - - def make_and_init_mma_o_pipeline( - self, mma_o_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineUmmaAsync: - """Create and initialize the mma o pipeline. - - :param mma_o_mbar_ptr: The mma o mbar pointer - :type mma_o_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The mma o pipeline - :rtype: pipeline.PipelineUmmaAsync - """ - - mma_o_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - consumer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - mma_o_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - consumer_thread_size, - ) - return pipeline.PipelineUmmaAsync.create( - barrier_storage=mma_o_mbar_ptr, - num_stages=self.mma_o_stage, - producer_group=mma_o_producer_group, - consumer_group=mma_o_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - @staticmethod - def _compute_grid( - o: cute.Tensor, - split_kv: cutlass.Int32, - cluster_shape_mnk: Tuple[int, int, int], - max_active_clusters: int, - is_persistent: bool, - ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: - """Compute grid shape for the output tensor C. 
- - :param c: The output tensor C - :type c: cute.Tensor - :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. - :type cta_tile_shape_mnk: tuple[int, int, int] - :param cluster_shape_mn: Shape of each cluster in M, N dimensions. - :type cluster_shape_mn: tuple[int, int] - - :return: Tile scheduler parameters and grid shape. - :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] - """ - o_shape = o.shape - tile_sched_params = create_mla_static_tile_scheduler_params( - is_persistent, - cute.size(o_shape[3]), - cute.size(o_shape[2]), - cluster_shape_mnk, - split_kv, - ) - grid = MLAStaticTileScheduler.get_grid_shape( - tile_sched_params, max_active_clusters - ) - - return tile_sched_params, grid - - @staticmethod - def get_workspace_size( - H: int, - S: int, - D: int, - B: int, - split_kv: int, - acc_dtype: Type[cutlass.Numeric], - ) -> int: - """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. - - :param H: The height of the output tensor C - :type H: int - :param S: The sequence length of the output tensor C - :type S: int - :param D: The depth of the output tensor C - :type D: int - :param B: The batch size of the output tensor C - :type B: int - :param split_kv: The split key-value of the output tensor C - :type split_kv: int - :param acc_dtype: The data type of the output tensor C - :type acc_dtype: Type[cutlass.Numeric] - - :return: The workspace size for the MLA kernel - :rtype: int - """ - if split_kv == 1: - return 0 - return B * H * S * split_kv * (D + 1) * acc_dtype.width // 8 - - @cute.jit - def initialize_workspace( - self, - H: cutlass.Int32, - D: cutlass.Int32, - S: cutlass.Int32, - B: cutlass.Int32, - split_kv: cutlass.Int32, - acc_dtype: Type[cutlass.Numeric], - workspace: cute.Tensor, - ) -> tuple[cute.Tensor, cute.Tensor]: - """Initialize the workspace for the MLA kernel. Construct the intermediate tensors - acc_o and acc_lse. 
- - :param H: The height of the output tensor C - :type H: cutlass.Int32 - :param D: The depth of the output tensor C - :type D: cutlass.Int32 - :param S: The sequence length of the output tensor C - :type S: cutlass.Int32 - :param B: The batch size of the output tensor C - :type B: cutlass.Int32 - :param split_kv: The split key-value of the output tensor C - :type split_kv: cutlass.Int32 - :param acc_dtype: The data type of the output tensor C - :type acc_dtype: Type[cutlass.Numeric] - :param workspace: The workspace tensor - :type workspace: cute.Tensor - - :return: The output tensor C and the workspace tensor - :rtype: tuple[cute.Tensor, cute.Tensor] - """ - acc_o, acc_lse = None, None - if cutlass.const_expr(workspace is not None): - align = 256 // self.q_dtype.width - acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), - stride=( - cute.assume(split_kv * D, align), - cute.assume(D, align), - 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), - ), - ) - acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) - acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) - acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), - ) - acc_lse_iter = cute.recast_ptr( - workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, - dtype=acc_dtype, - ) - acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) - return acc_o, acc_lse - - @staticmethod - def can_implement( - B: int, - S: int, - K: int, - H: int, - L: int, - R: int, - in_dtype: Type[cutlass.Numeric], - out_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - page_size: int, - ) -> bool: - """Check if the MLA kernel can be implemented. 
- - :param B: The batch size of the output tensor C - :type B: int - :param S: The sequence length of the output tensor C - :type S: int - :param K: The width of the output tensor KV - :type K: int - :param H: The number of heads of the output tensor C - :type H: int - :param L: The number of latent dimensions of the tensor KV - :type L: int - :param R: The number of rope dimensions of the tensor C_rope - :type R: int - :param in_dtype: The data type of the input tensor - :type in_dtype: Type[cutlass.Numeric] - :param out_dtype: The data type of the output tensor - :type out_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: The data type of the log-sum-exp - :type lse_dtype: Type[cutlass.Numeric] - :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication - :type mma_qk_tiler_mn: Tuple[int, int] - :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication - :type mma_pv_tiler_mn: Tuple[int, int] - :param split_kv: The split key-value of the output tensor C - :type split_kv: int - :param is_persistent: Whether to use persistent kernel optimization - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split_kv - :type is_var_split_kv: bool - :param page_size: The page size of the page table - :type page_size: int - - :return: Whether the MLA kernel can be implemented - :rtype: bool - """ - if L != 512 or R != 64: - return False - if in_dtype not in [cutlass.Float16, cutlass.BFloat16]: - return False - if out_dtype not in [cutlass.Float16, cutlass.BFloat16]: - return False - if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: - return False - # page size equals 1 is prohibited by tma specification, not 128B aligned. 
- if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: - return False - if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: - return False - if is_var_split_kv and not is_var_seq: - return False - if H > 128 or (H < 128 and split_kv != 1): - return False - if S < 1 or S > 4: - return False - if K <= 0: - return False - return True - - -def run( - batch_size: int, - seq_len_q: int, - seq_len_k: int, - num_heads: int, - latent_dim: int, - rope_dim: int, - in_dtype: Type[cutlass.Numeric], - out_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - page_size: int, - softmax_scale: float, - output_scale: float, - skip_correction_threshold: float, - tolerance: float, - warmup_iterations: int, - iterations: int, - skip_ref_check: bool, - use_cold_l2: bool, - enable_pdl: bool = False, - **kwargs, -): - """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. - - This function creates random input tensors for query latent/rope, compressed latent/rope, and value, - then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, - page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference - implementation or run multiple times for performance measurement. 
- - :param batch_size: Batch size - :type batch_size: int - :param seq_len_q: Sequence length of Q - :type seq_len_q: int - :param seq_len_k: Sequence length of K - :type seq_len_k: int - :param num_heads: Number of heads - :type num_heads: int - :param latent_dim: dimension of query/compressed latent - :type latent_dim: int - :param rope_dim: dimension of query/compressed rope - :type rope_dim: int - :param in_dtype: Input data type for query/compressed latent/rope tensors - :type in_dtype: Type[cutlass.Numeric] - :param out_dtype: Output data type for attention output - :type out_dtype: Type[cutlass.Numeric] - :param acc_dtype: Accumulator data type for query-key matrix multiplication - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: Accumulator data type for log-sum-exp - :type lse_dtype: Type[cutlass.Numeric] - :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication - :type mma_qk_tiler_mn: Tuple[int, int] - :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication - :type mma_pv_tiler_mn: Tuple[int, int] - :param split_kv: Split key-value - :type split_kv: int - :param is_persistent: Whether to use persistent kernel optimization - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split_kv - :type is_var_split_kv: bool - :param page_size: Page size of the page table - :type page_size: int - :param softmax_scale: Attention score scaling factor - :type softmax_scale: float - :param output_scale: Output scaling factor - :type output_scale: float - :param skip_correction_threshold: Threshold to skip correction - :type skip_correction_threshold: float - :param tolerance: Maximum acceptable error for validation - :type tolerance: float - :param warmup_iterations: Number of warmup iterations - :type warmup_iterations: int - :param iterations: 
Number of iterations to run for performance testing - :type iterations: int - :param skip_ref_check: Skip validation against reference implementation - :type skip_ref_check: bool - :param use_cold_l2: Whether to use cold L2 cache - :type use_cold_l2: bool - - :raises ValueError: If input shapes are incompatible or head dimension is unsupported - :raises RuntimeError: If GPU is unavailable for computation - """ - - print("Running Blackwell MLA test with:") - print(f" batch_size: {batch_size}") - print(f" seq_len_q: {seq_len_q}") - print(f" seq_len_k: {seq_len_k}") - print(f" num_heads: {num_heads}") - print(f" latent_dim: {latent_dim}") - print(f" rope_dim: {rope_dim}") - print(f" in_dtype: {in_dtype}") - print(f" out_dtype: {out_dtype}") - print(f" acc_dtype: {acc_dtype}") - print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") - print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") - print(f" split_kv: {split_kv}") - print(f" is_persistent: {is_persistent}") - print(f" is_var_seq: {is_var_seq}") - print(f" is_var_split_kv: {is_var_split_kv}") - print(f" page_size: {page_size}") - print(f" softmax_scale: {softmax_scale}") - print(f" output_scale: {output_scale}") - print(f" skip_correction_threshold: {skip_correction_threshold}") - print(f" tolerance: {tolerance}") - print(f" warmup_iterations: {warmup_iterations}") - print(f" iterations: {iterations}") - print(f" skip_ref_check: {skip_ref_check}") - print(f" use_cold_l2: {use_cold_l2}") - - # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) - if not torch.cuda.is_available(): - raise RuntimeError("GPU is required to run this example!") - - if not BlackwellMultiHeadLatentAttentionForwardFP16.can_implement( - batch_size, - seq_len_q, - seq_len_k, - num_heads, - latent_dim, - rope_dim, - in_dtype, - out_dtype, - acc_dtype, - lse_dtype, - mma_qk_tiler_mn, - mma_pv_tiler_mn, - split_kv, - is_persistent, - is_var_seq, - is_var_split_kv, - page_size, - ): - raise TypeError( - f"Unsupported testcase {batch_size}, 
{seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" - ) - - torch.manual_seed(1111) - - def create_data_tensor( - B, - HK, - D, - dtype, - is_dynamic_layout=True, - page_table=None, - cache_seqs=None, - is_lse=False, - seq_len_q=None, - ): - shape = (B, HK, D) - if page_table is not None: - if cache_seqs is not None: - max_seq_len = torch.max(cache_seqs) - shape = (B * ceil_div(max_seq_len, page_size), page_size, D) - else: - shape = (B * ceil_div(HK, page_size), page_size, D) - - if seq_len_q is not None: - shape = (B, seq_len_q, HK, D) - - # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) - if is_lse: - shape = (B, seq_len_q, HK) - leading_dim = 2 - stride_order = (0, 1, 2) - elif seq_len_q is not None: - leading_dim = 3 - stride_order = (0, 1, 2, 3) - else: - leading_dim = 2 - stride_order = (0, 1, 2) - - init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) - - torch_dtype = ( - cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 - ) - - # Create contiguous dtype torch tensor (cpu) — no permute - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - init_type=cutlass.torch.TensorInitType.RANDOM, - init_config=init_config, - ) - - # Create dtype torch tensor (gpu) - torch_tensor_gpu = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) - if not is_lse: - cute_tensor = cute_tensor.mark_compact_shape_dynamic( - mode=leading_dim, - stride_order=stride_order, - divisibility=(128 // 
dtype.width), - ) - - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor_gpu - - def create_cache_seqs(batch_size, seq_len_k, is_var_seq): - cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k - cache_seqs_gpu = cache_seqs_ref.cuda() - cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() - if is_var_seq: - max_seq_len = seq_len_k - min_seq_len = int(seq_len_k * 0.8) - cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( - (batch_size,), - torch.int32, - init_type=cutlass.torch.TensorInitType.RANDOM, - init_config=cutlass.torch.RandomInitConfig( - min_val=min_seq_len, max_val=max_seq_len + 1 - ), - ) - cache_seqs_gpu = cache_seqs_ref.cuda() - cache_seqs = from_dlpack( - cache_seqs_gpu, - assumed_align=16, - ).mark_layout_dynamic() - return cache_seqs_ref, cache_seqs, cache_seqs_gpu - - def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): - max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) - page_count = ceil_div(max_seq_len, page_size) - page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) - # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. 
- for b in range(batch_size): - for j in range(page_count): - page_table_ref[b, j] = b + j * batch_size - page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] - page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( - leading_dim=1 - ) - return page_table_ref, page_table, page_table_gpu - - def create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ): - block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None - # check if split_kv is valid otherwise do auto setting of split_kv - if is_var_split_kv: - block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) - for b in range(batch_size): - block_split_kvs_ref[b] = ( - BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( - batch_size, - seq_len_q, - cache_seqs_ref[b].item(), - mma_qk_tiler_mn, - max_active_clusters * cluster_shape_mnk[0], - ) - ) - split_kv = torch.max(block_split_kvs_ref).item() - block_split_kvs_gpu = block_split_kvs_ref.cuda() - block_split_kvs = from_dlpack( - block_split_kvs_gpu, assumed_align=16 - ).mark_layout_dynamic() - elif split_kv <= 0: - split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( - batch_size, - seq_len_q, - cache_seqs_ref[0].item(), - mma_qk_tiler_mn, - max_active_clusters * cluster_shape_mnk[0], - ) - return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu - - def create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype - ): - workspace_size = ( - BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( - num_heads, - seq_len_q, - latent_dim, - batch_size, - split_kv, - acc_dtype, - ) - ) - - workspace, workspace_torch = None, None - if workspace_size > 0: - workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() - workspace = from_dlpack(workspace_torch, assumed_align=32) - return workspace, workspace_torch - - 
cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( - batch_size, seq_len_k, is_var_seq - ) - page_table_ref, page_table, page_table_torch = create_page_table( - batch_size, seq_len_k, is_var_seq, page_size - ) - cluster_shape_mnk = (2, 1, 1) - hardware_info = utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mnk[0] * cluster_shape_mnk[1] - ) - split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( - create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ) - ) - - q_latent_ref, q_latent, q_latent_torch = create_data_tensor( - batch_size, - num_heads, - latent_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - q_rope_ref, q_rope, q_rope_torch = create_data_tensor( - batch_size, - num_heads, - rope_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - - c_latent_ref, c_latent, c_latent_torch = create_data_tensor( - batch_size, - seq_len_k, - latent_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - c_rope_ref, c_rope, c_rope_torch = create_data_tensor( - batch_size, - seq_len_k, - rope_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - o_ref, o, o_torch = create_data_tensor( - batch_size, - num_heads, - latent_dim, - out_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - lse_ref, lse, lse_torch = create_data_tensor( - batch_size, - num_heads, - 1, - lse_dtype, - is_dynamic_layout=True, - is_lse=True, - seq_len_q=seq_len_q, - ) - workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype - ) - - mla = BlackwellMultiHeadLatentAttentionForwardFP16( - acc_dtype, - lse_dtype, - mma_qk_tiler_mn, - mma_pv_tiler_mn, - max_active_clusters, - page_size, - skip_correction_threshold, - 
is_persistent, - is_var_seq, - is_var_split_kv, - enable_pdl, - ) - - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - stream = cuda.CUstream(torch_stream.cuda_stream) - - # compile mla kernel - compiled_mla = cute.compile( - mla, - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - options="--opt-level 2", - ) - - def torch_reference_mla( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - cache_seqs, - softmax_scale=1.0, - output_scale=1.0, - ): - # Ref tensors are now contiguous: - # q_latent/q_rope: [B, S_q, H, D] - # c_latent/c_rope: [num_pages, page_size, D] - # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] - q_ref = torch.cat([q_latent, q_rope], dim=3) - # KV cache: concat along last dim, already [num_pages, page_size, D_total] - page_count = page_table_ref.shape[1] - k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( - batch_size * page_count, page_size, latent_dim + rope_dim - ) - v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) - - if is_var_seq: - max_seq_len = torch.max(cache_seqs_ref) - else: - max_seq_len = seq_len_k - - k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) - v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) - k_ref = torch.index_select( - k_ref_paged, 0, torch.flatten(page_table_ref) - ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] - v_ref = torch.index_select( - v_ref_paged, 0, torch.flatten(page_table_ref) - ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] - for b in range(batch_size): - k_ref[b, :, cache_seqs_ref[b] :, :] = 0 - v_ref[b, :, cache_seqs_ref[b] :, :] = 0 - import torch.nn.functional as F - - o_ref = F.scaled_dot_product_attention( - q_ref, - k_ref, - v_ref, - attn_mask=None, - 
dropout_p=0.0, - scale=softmax_scale, - is_causal=False, - ) - s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) - s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) - softmax_scale_log2 = LOG2_E * softmax_scale - s_ref_sum = torch.sum( - torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True - ) - - lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) - lse_ref = lse_ref.squeeze(3) # [B, S_q, H] - o_ref = o_ref * output_scale - # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout - - return o_ref, lse_ref - - if skip_correction_threshold > 0.0: - print( - "Skipping correction verification since skip_correction_threshold is greater than 0.0..." - ) - skip_ref_check = True - if not skip_ref_check: - # Execute kernel once for reference checking - compiled_mla( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - ) - torch.cuda.synchronize() - - print("Verifying results...") - if in_dtype == cutlass.Float8E4M3FN: - tolerance = 0.13 - o_ref, lse_ref = torch_reference_mla( - q_latent_ref, - q_rope_ref, - c_latent_ref, - c_rope_ref, - page_table, - cache_seqs, - softmax_scale, - output_scale, - ) - - if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: - # {$nv-internal-release begin} - # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. 
- # {$nv-internal-release end} - # convert o back to f32 for comparison - o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( - torch.empty(*o_torch.shape, dtype=torch.float32), - cutlass.Float32, - is_dynamic_layout=True, - assumed_align=16, - ) - cute.testing.convert(o, o_fp32) - o = o_fp32_torch.cpu() - ref_fp8, _ = cutlass_torch.cute_tensor_like( - torch.empty(*o_ref.shape, dtype=torch.uint8), - out_dtype, - is_dynamic_layout=True, - assumed_align=16, - ) - o_ref_gpu = o_ref.cuda() - o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) - - # convert ref : f32 -> fp8 -> f32 - cute.testing.convert(o_ref_f32, ref_fp8) - cute.testing.convert(ref_fp8, o_ref_f32) - - o_ref = o_ref_gpu.cpu() - else: - o = o_torch.cpu().to(torch.float32) - lse = lse_torch.cpu() - lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) - # Assert close results - torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) - torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) - print("Results verified successfully!") - - def generate_tensors(): - _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) - _, page_table, _ = create_page_table( - batch_size, seq_len_k, is_var_seq, page_size - ) - _split_kv, _, block_split_kvs, _ = create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ) - - _, q_latent, _ = create_data_tensor( - batch_size, - num_heads, - latent_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - _, q_rope, _ = create_data_tensor( - batch_size, - num_heads, - rope_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - - _, c_latent, _ = create_data_tensor( - batch_size, - seq_len_k, - latent_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - _, c_rope, _ = create_data_tensor( - batch_size, - seq_len_k, - rope_dim, - in_dtype, - 
is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - _, o, _ = create_data_tensor( - batch_size, - num_heads, - latent_dim, - out_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - _, lse, _ = create_data_tensor( - batch_size, - num_heads, - 1, - lse_dtype, - is_dynamic_layout=True, - is_lse=True, - seq_len_q=seq_len_q, - ) - workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype - ) - return testing.JitArguments( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - _split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - ) - - workspace_count = 1 - if use_cold_l2: - one_workspace_bytes = ( - q_latent_torch.numel() * q_latent_torch.element_size() - + q_rope_torch.numel() * q_rope_torch.element_size() - + c_latent_torch.numel() * c_latent_torch.element_size() - + c_rope_torch.numel() * c_rope_torch.element_size() - + o_torch.numel() * o_torch.element_size() - + lse_torch.numel() * lse_torch.element_size() - + cache_seqs_torch.numel() * cache_seqs_torch.element_size() - ) - one_workspace_bytes += ( - page_table_torch.numel() * page_table_torch.element_size() - ) - if is_var_split_kv: - one_workspace_bytes += ( - block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() - ) - if workspace_torch is not None: - one_workspace_bytes += ( - workspace_torch.numel() * workspace_torch.element_size() - ) - workspace_count = testing.get_workspace_count( - one_workspace_bytes, warmup_iterations, iterations - ) - - avg_time_us = testing.benchmark( - compiled_mla, - workspace_generator=generate_tensors, - workspace_count=workspace_count, - stream=stream, - warmup_iterations=warmup_iterations, - iterations=iterations, - ) - - return avg_time_us # Return execution time in microseconds diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py deleted file mode 100644 
index 638cc8a5b0..0000000000 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ /dev/null @@ -1,4221 +0,0 @@ -# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- - -import math -from typing import Type, Tuple, Optional -from types import SimpleNamespace - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute -import cutlass.cute.testing as testing -from cutlass.cute.nvgpu import tcgen05 -from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode - -# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; -# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. -_setmaxregister_decrease = getattr( - cute.arch, - "setmaxregister_decrease", - getattr(cute.arch, "warpgroup_reg_dealloc", None), -) -_setmaxregister_increase = getattr( - cute.arch, - "setmaxregister_increase", - getattr(cute.arch, "warpgroup_reg_alloc", None), -) - -# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; -# older versions don't have it, so we provide a fallback implementation. -_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} - - -def _get_max_tmem_alloc_cols(compute_capability: str) -> int: - if hasattr(cute.arch, "get_max_tmem_alloc_cols"): - return cute.arch.get_max_tmem_alloc_cols(compute_capability) - if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: - raise ValueError(f"Unsupported compute capability: {compute_capability}") - return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] - - -import cutlass.cute.nvgpu.cpasync as cpasync -import cutlass.utils as utils -import cutlass.pipeline as pipeline -from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait -import cutlass.utils.blackwell_helpers as sm100_utils -from cutlass.cute.runtime import from_dlpack -from cutlass.cute.arch import Arch -from cutlass.cutlass_dsl import BaseDSL - - -from .mla_helpers import ( - ceil_div, - MAX_SPLITS, - LOG2_E, - MLAStaticTileScheduler, - MLAStaticTileSchedulerParams, - create_mla_static_tile_scheduler, - create_mla_static_tile_scheduler_params, -) - -""" -A Multi-Head Latent Attention (MLA) example using fp8 as input/output for the NVIDIA 
Blackwell SM100 architecture using CUTE DSL - -This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell -SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T -matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. -The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting -functionality to minimize latency when processing long KV sequences. - -The kernel implements key optimizations including: -- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) -- Pipeline stages between different warps for overlapping computation and memory access -- Support for different precision data types -- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing - -To run this example: - -.. code-block:: bash - - python examples/blackwell/mla_fp8.py \ - --batch_size 4 --latent_dim 512 --rope_dim 64 \ - --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ - --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ - --acc_dtype Float32 --lse_dtype Float32 \ - --is_var_seq --is_var_split_kv \ - --is_persistent - -The above example runs Multi-Head Latent Attention (MLA) with the following configuration: -- Batch size: 4 -- Sequence length of Q: 1 -- Sequence length of K: 1024 -- Latent dimension: 512 -- RoPE dimension: 64 -- Number of heads: 128 -- Data types: Float8E4M3FN (input), Float8E4M3FN (output), Float32 (accumulation and LSE) - -It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences -and variable split KV processing with persistent scheduling. - -To collect performance with NCU profiler: - -.. 
code-block:: bash - - ncu python examples/blackwell/mla_fp8.py \ - --batch_size 4 --latent_dim 512 --rope_dim 64 \ - --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ - --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ - --acc_dtype Float32 --lse_dtype Float32 \ - --is_var_seq --is_var_split_kv \ - --is_persistent --warmup_iterations 3 \ - --iterations 10 --skip_ref_check - -Constraints for this example: -* Data type requirements: - - Input/output: Float8E4M3FN - - Accumulation and LSE: Float32 -* Fixed architecture parameters: - - Number of attention heads: 128 - - Latent dimension: 512 - - RoPE dimension: 64 -* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) -* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) -* Query sequence length must be 1-4 -* Only supports 2-CTA instructions -* Variable sequence length requires page table storage enabled -""" - - -class BlackwellMultiHeadLatentAttentionForwardFP8: - def __init__( - self, - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - max_active_clusters: int, - page_size: int, - skip_correction_threshold: float, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - enable_pdl: bool, - ): - """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. 
- - :param acc_dtype: Data type for accumulation S and O - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: Data type for output LSE - :type lse_dtype: Type[cutlass.Numeric] - :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S - :type mma_s_tiler: Tuple[int, int] - :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P - :type mma_p_tiler: Tuple[int, int] - :param max_active_clusters: Maximum number of active clusters - :type max_active_clusters: int - :param page_size: The page size - :type page_size: int - :param skip_correction_threshold: Threshold to skip correction - :type skip_correction_threshold: float - :param is_persistent: Whether to use persistent kernel mode - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split KV - :type is_var_split_kv: bool - :param enable_pdl: Whether to use PDL - :type enable_pdl: bool - """ - - self.latent_dim = 512 - self.rope_dim = 64 - self.acc_dtype = acc_dtype - self.lse_dtype = lse_dtype - self.mma_qk_tiler_mn = mma_qk_tiler_mn - self.mma_pv_tiler_mn = mma_pv_tiler_mn - self.max_active_clusters = max_active_clusters - self.skip_correction_threshold = skip_correction_threshold - self.is_persistent = is_persistent - self.page_size = page_size - self.is_var_seq = is_var_seq - self.is_var_split_kv = is_var_split_kv - self.enable_pdl = enable_pdl - self.cluster_shape_mnk = (2, 1, 1) - self.use_2cta_instrs = True - # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), - # while warps 2-3 handle accumulation for second half [n/2, n) - self.warps_in_n = 2 - self.num_compute_warps = 4 - self.threads_per_warp = 32 - mma_qk_tiler_k = self.rope_dim * 2 - self.mma_qk_tiler = ( - self.mma_qk_tiler_mn[0], - self.mma_qk_tiler_mn[1], - mma_qk_tiler_k, - ) - self.mma_qk_rope_tiler = ( - self.mma_qk_tiler_mn[0], - self.mma_qk_tiler_mn[1], - 
self.rope_dim, - ) - self.mma_pv_tiler = ( - self.mma_pv_tiler_mn[0], - self.mma_pv_tiler_mn[1], - self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], - ) - self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] - self.iterations_qk_rope = 1 - self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope - self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] - self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] - - # Set specialized warp ids - self.compute_warp_ids = (0, 1, 2, 3) - self.correction_warp_ids = (4, 5, 6, 7) - self.mma_warp_id = 8 - self.load_tma_k_warp_id = 9 - self.load_tma_v_warp_id = 10 - self.empty_warp_ids = (11,) - self.threads_per_cta = self.threads_per_warp * len( - ( - self.mma_warp_id, - self.load_tma_k_warp_id, - self.load_tma_v_warp_id, - *self.compute_warp_ids, - *self.correction_warp_ids, - *self.empty_warp_ids, - ) - ) - - # register settings - self.softmax_reg_num = 192 - self.correction_reg_num = 256 - self.other_reg_num = 48 - # Named barriers - self.tmem_ptr_sync_bar = pipeline.NamedBarrier( - barrier_id=1, - num_threads=( - self.threads_per_warp - + self.threads_per_warp * self.num_compute_warps * 2 - ), - ) - self.softmax_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) - ) - self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) - ) - - def _setup_attributes(self): - """Set up configurations and parameters for the MLA kernel operation. 
- - This method initializes and configures various attributes required for the - execution of the multi-head latent attention kernel, mainly about the pipeline stages: - - - Sets up staging parameters for Q, K, V inputs and accumulator data - - Configures pipeline stages for softmax, correction, and epilogue operations - """ - - self.load_q_stage = 1 - self.load_k_stage = 3 - self.load_v_stage = 2 - self.mma_s_stage = 2 - self.p_mma_stage = 2 - self.p_cor_stage = 2 - self.mma_o_stage = 2 - - self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n - self.correction_factor_offset = ( - self.tmem_o_offset + self.latent_dim // self.warps_in_n - ) - - @cute.jit - def __call__( - self, - q_latent: cute.Tensor, - q_rope: cute.Tensor, - c_latent: cute.Tensor, - c_rope: cute.Tensor, - page_table: cute.Tensor, - o: cute.Tensor, - lse: cute.Tensor, - workspace: cute.Tensor, - split_kv: cutlass.Int32, - cache_seqs: Optional[cute.Tensor], - block_split_kvs: Optional[cute.Tensor], - softmax_scale: cutlass.Float32, - output_scale: cutlass.Float32, - stream: cuda.CUstream, - ): - """Execute the Multi-Head Latent Attention operation on the provided tensors. - - The method handles: - 1. Initialization of workspace for temporary split KV buffers - 2. Validation of tensor data types - 3. Initialization of hardware-specific parameters and memory layouts - 4. Configuration of TMA (Tensor Memory Access) operations - 5. Grid and work scheduling computation - 6. 
Kernel launch(split KV kernel and reduction kernel) with appropriate parameters - - :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) - :type q_latent: cute.Tensor - :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) - :type q_rope: cute.Tensor - :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) - :type c_latent: cute.Tensor - :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) - :type c_rope: cute.Tensor - :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) - :type page_table: cute.Tensor - :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) - :type o: cute.Tensor - :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) - :type lse: cute.Tensor - :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse - :type workspace: cute.Tensor - :param split_kv: The scalar factor for split KV - :type split_kv: cutlass.Int32 - :param cache_seqs: The cache sequences tensor with shape [batch_size] - :type cache_seqs: cute.Tensor - :param block_split_kvs: The block split KV tensor with shape [batch_size] - :type block_split_kvs: cute.Tensor - :param softmax_scale: The scale factor for softmax - :type softmax_scale: cutlass.Float32 - :param output_scale: The scale factor for the output - :type output_scale: cutlass.Float32 - :param stream: The CUDA stream to execute the kernel on - :type stream: cuda.CUstream - - :raises TypeError: If tensor data types don't match or aren't supported - """ - - # setup static attributes before smem/grid/tma computation - self.q_dtype = q_latent.element_type - self.k_dtype = c_latent.element_type - self.v_dtype = c_latent.element_type - self.o_dtype = o.element_type - - # check type consistency - if cutlass.const_expr( - 
self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype - ): - raise TypeError( - f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" - ) - - # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] - # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) - def _reinterpret_4d(t): - return cute.make_tensor( - t.iterator, - cute.make_layout( - (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), - stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), - ), - ) - - q_latent = _reinterpret_4d(q_latent) - q_rope = _reinterpret_4d(q_rope) - o = _reinterpret_4d(o) - - # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] - # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) - def _reinterpret_3d_kv(t): - return cute.make_tensor( - t.iterator, - cute.make_layout( - (t.shape[1], t.shape[2], t.shape[0]), - stride=(t.stride[1], t.stride[2], t.stride[0]), - ), - ) - - c_latent = _reinterpret_3d_kv(c_latent) - c_rope = _reinterpret_3d_kv(c_rope) - - # Reinterpret contiguous [B, page_count] as [page_count, B] - page_table = cute.make_tensor( - page_table.iterator, - cute.make_layout( - (page_table.shape[1], page_table.shape[0]), - stride=(page_table.stride[1], page_table.stride[0]), - ), - ) - - # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] - # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) - lse = cute.make_tensor( - lse.iterator, - cute.make_layout( - (lse.shape[2], lse.shape[1], lse.shape[0]), - stride=(lse.stride[2], lse.stride[1], lse.stride[0]), - ), - ) - - acc_o, acc_lse = self.initialize_workspace( - q_latent.shape[0], - q_latent.shape[1], - q_latent.shape[2], - q_latent.shape[3], - split_kv, - self.acc_dtype, - workspace, - ) - - c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) - c_latent_transpose = cute.make_tensor( - c_latent.iterator, c_latent_tranpose_layout - ) - - self.q_major_mode = OperandMajorMode.K - self.k_major_mode = OperandMajorMode.K - 
self.v_major_mode = OperandMajorMode.MN - - self._setup_attributes() - - cta_group = tcgen05.CtaGroup.TWO - # the intermediate tensor p is from smem & k-major - p_major_mode = OperandMajorMode.K - qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( - self.q_dtype, - self.q_major_mode, - self.k_major_mode, - self.acc_dtype, - cta_group, - self.mma_qk_tiler[:2], - ) - pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( - self.v_dtype, - p_major_mode, - self.v_major_mode, - self.acc_dtype, - cta_group, - self.mma_pv_tiler[:2], - ) - - cta_layout_vmnk = cute.tiled_divide( - cute.make_layout(self.cluster_shape_mnk), - (qk_tiled_mma.thr_id.shape,), - ) - - self.epi_tile = self.mma_pv_tiler[:2] - - q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( - qk_tiled_mma, - self.mma_qk_tiler, - self.q_dtype, - (self.iterations_qk_latent * self.load_q_stage), - ) - q_latent_smem_layout_staged = cute.logical_divide( - q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) - ) - q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( - qk_tiled_mma, - self.mma_qk_rope_tiler, - self.q_dtype, - self.load_q_stage, - ) - - kc_latent_smem_layout_staged = sm100_utils.make_smem_layout_b( - qk_tiled_mma, - self.mma_qk_tiler, - self.k_dtype, - (self.iterations_qk_latent * self.load_k_stage), - ) - kc_page_tile_size = min( - self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape - ) - kc_latent_smem_layout_staged = cute.logical_divide( - kc_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) - ) - - kc_latent_smem_layout_for_tma = sm100_utils.make_smem_layout( - OperandMajorMode.K, - (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), - self.k_dtype, - (self.iterations_qk_latent * self.load_k_stage), - ) - kc_latent_smem_layout_for_tma = cute.tiled_divide( - kc_latent_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) - ) - kc_latent_smem_layout_for_tma = cute.logical_divide( - 
kc_latent_smem_layout_for_tma, (None, None, None, self.iterations_qk_latent) - ) - - kc_rope_smem_layout_staged = sm100_utils.make_smem_layout_b( - qk_tiled_mma, - self.mma_qk_rope_tiler, - self.k_dtype, - self.load_k_stage, - ) - kc_rope_smem_layout_for_tma = sm100_utils.make_smem_layout( - OperandMajorMode.K, - ( - self.mma_qk_rope_tiler[0] // qk_tiled_mma.thr_id.shape, - self.mma_qk_rope_tiler[2], - ), - self.k_dtype, - (self.iterations_qk_rope * self.load_k_stage), - ) - kc_rope_smem_layout_for_tma = cute.tiled_divide( - kc_rope_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_rope_tiler[2]) - ) - - p_smem_layout_staged = sm100_utils.make_smem_layout_a( - pv_tiled_mma, - self.mma_pv_tiler, - self.q_dtype, - (self.iterations_pv_k * self.p_mma_stage), - ) - p_smem_layout_staged = cute.logical_divide( - p_smem_layout_staged, (None, None, None, self.iterations_pv_k) - ) - - vc_smem_layout_staged = sm100_utils.make_smem_layout_b( - pv_tiled_mma, - self.mma_pv_tiler, - self.v_dtype, - (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), - ) - vc_smem_layout_staged = cute.logical_divide( - cute.logical_divide( - vc_smem_layout_staged, - (None, None, None, self.iterations_pv_k * self.iterations_pv_n), - ), - (None, None, None, (self.iterations_pv_n, None)), - ) - vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) - vc_smem_layout_for_tma = sm100_utils.make_smem_layout( - OperandMajorMode.MN, - (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), - self.v_dtype, - (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), - ) - vc_smem_layout_for_tma = cute.tiled_divide( - vc_smem_layout_for_tma, - ( - pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, - vc_page_tile_size, - ), - ) - vc_smem_layout_for_tma = cute.logical_divide( - cute.logical_divide( - vc_smem_layout_for_tma, - (None, None, None, self.iterations_pv_k * self.iterations_pv_n), - ), - (None, None, None, (self.iterations_pv_n, None)), - ) - 
# TMA load for Q latent and rope - tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) - - q_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) - - tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - q_latent, - q_smem_layout, - self.mma_qk_tiler, - qk_tiled_mma, - cta_layout_vmnk.shape, - ) - q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) - tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( - tma_load_op, - q_rope, - q_rope_smem_layout, - self.mma_qk_rope_tiler, - qk_tiled_mma, - cta_layout_vmnk.shape, - ) - # TMA load for c latent and k rope - kc_smem_layout = cute.select(kc_latent_smem_layout_for_tma, mode=[0]) - tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( - tma_load_op, - c_latent, - kc_smem_layout, - (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), - qk_tiled_mma, - is_k_load=True, - ) - kc_rope_smem_layout = cute.select(kc_rope_smem_layout_for_tma, mode=[0]) - tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( - tma_load_op, - c_rope, - kc_rope_smem_layout, - (self.mma_qk_rope_tiler[1], self.mma_qk_rope_tiler[2]), - qk_tiled_mma, - is_k_load=True, - ) - - # TMA load for c latent transpose - vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) - tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( - self.make_paged_tiled_tma_atom( - tma_load_op, - c_latent_transpose, - vc_smem_layout, - (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), - pv_tiled_mma, - is_k_load=False, - ) - ) - - q_latent_copy_size = ( - cute.size_in_bytes(self.q_dtype, q_smem_layout) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_latent - ) - q_rope_copy_size = ( - cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_rope - ) - kc_latent_copy_size = ( - cute.size_in_bytes( - self.k_dtype, - 
cute.select(kc_latent_smem_layout_staged, mode=[0, 1, 2]), - ) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_latent - ) - kc_rope_copy_size = ( - cute.size_in_bytes( - self.k_dtype, - cute.select(kc_rope_smem_layout_staged, mode=[0, 1, 2]), - ) - * cute.size(qk_tiled_mma.thr_id.shape) - * self.iterations_qk_rope - ) - vc_copy_size = ( - cute.size_in_bytes( - self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) - ) - * cute.size(pv_tiled_mma.thr_id.shape) - * self.iterations_pv_n - * self.iterations_pv_k - ) - - self.tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size - self.tma_copy_kc_bytes = kc_latent_copy_size + kc_rope_copy_size - self.tma_copy_vc_bytes = vc_copy_size - - tile_sched_params, grid = self._compute_grid( - o, - split_kv, - self.cluster_shape_mnk, - self.max_active_clusters, - self.is_persistent, - ) - - @cute.struct - class SplitKVKernelSharedStorage: - # Pipeline barriers - load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] - load_k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_k_stage * 2] - load_v_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_v_stage * 2] - mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] - p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] - p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] - mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] - - # Smem tensors - smem_p: cute.struct.Align[ - cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], - 1024, - ] - smem_kc_latent: cute.struct.Align[ - cute.struct.MemRange[ - self.k_dtype, cute.cosize(kc_latent_smem_layout_staged) - ], - 1024, - ] - - smem_kc_rope: cute.struct.Align[ - cute.struct.MemRange[ - self.k_dtype, cute.cosize(kc_rope_smem_layout_staged) - ], - 1024, - ] - smem_q_latent: cute.struct.Align[ - cute.struct.MemRange[ - self.q_dtype, cute.cosize(q_latent_smem_layout_staged) - 
], - 1024, - ] - smem_q_rope: cute.struct.Align[ - cute.struct.MemRange[ - self.q_dtype, cute.cosize(q_rope_smem_layout_staged) - ], - 1024, - ] - smem_vc: cute.struct.Align[ - cute.struct.MemRange[self.v_dtype, cute.cosize(vc_smem_layout_staged)], - 1024, - ] - softmax_smem_exchange: cute.struct.MemRange[ - self.acc_dtype, self.num_compute_warps * self.threads_per_warp - ] - epilogue_smem_exchange: cute.struct.MemRange[ - self.acc_dtype, self.num_compute_warps * self.threads_per_warp - ] - - # Tmem dealloc cluster barrier - tmem_dealloc_mbar_ptr: cutlass.Int64 - - # Tmem holding buffer - tmem_holding_buf: cutlass.Int32 - - softmax_scale_log2 = softmax_scale * LOG2_E - - self.split_kv_kernel( - qk_tiled_mma, - pv_tiled_mma, - tma_atom_q_latent, - tma_tensor_q_latent, - tma_atom_q_rope, - tma_tensor_q_rope, - tma_atom_c_latent, - tma_tensor_c_latent, - tma_atom_c_rope, - tma_tensor_c_rope, - tma_atom_c_latent_transpose, - tma_tensor_c_latent_transpose, - page_table, - o, - lse, - acc_o, - acc_lse, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale_log2, - output_scale, - q_latent_smem_layout_staged, - q_rope_smem_layout_staged, - kc_latent_smem_layout_staged, - kc_rope_smem_layout_staged, - p_smem_layout_staged, - vc_smem_layout_staged, - kc_latent_smem_layout_for_tma, - kc_rope_smem_layout_for_tma, - vc_smem_layout_for_tma, - cta_layout_vmnk, - tile_sched_params, - SplitKVKernelSharedStorage, - ).launch( - grid=grid, - block=[self.threads_per_cta, 1, 1], - cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] - stream=stream, - min_blocks_per_mp=1, - use_pdl=self.enable_pdl, - ) - if cutlass.const_expr(acc_o is not None): - self.reduction_kernel( - o, - lse, - acc_o, - acc_lse, - split_kv, - cache_seqs, - block_split_kvs, - ).launch( - grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), - block=[self.threads_per_warp * self.num_compute_warps, 1, 1], - smem=MAX_SPLITS * 
self.acc_dtype.width // 8, - stream=stream, - min_blocks_per_mp=1, - use_pdl=self.enable_pdl, - ) - - @cute.jit - def make_paged_tiled_tma_atom( - self, - tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, - gmem: cute.Tensor, - smem_layout: cute.Layout, - mma_tiler, - tiled_mma: cute.TiledMma, - is_k_load: bool, - ): - ident = cute.make_identity_layout(gmem.shape) - g_tile = cute.composition(ident, mma_tiler) - cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape - cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) - cta_v_map = cute.select(cta_v_map, mode=[0, 2]) - page_tile_size = ( - min(self.page_size, cta_mn) - if is_k_load - else min(self.page_size, mma_tiler[1]) - ) - cta_v_map = cute.zipped_divide( - cta_v_map, - (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), - ) - cta_v_map = cute.select(cta_v_map, mode=[0]) - from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir - - res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem.value, - smem_layout.value, - cta_v_map, - tma_load_op._to_ir(), - num_multicast=1, - ) - return cute.CopyAtom( - tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) - ), res[1] - - @cute.kernel - def split_kv_kernel( - self, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - tma_atom_q_latent: Optional[cute.CopyAtom], - mQL: cute.Tensor, - tma_atom_q_rope: Optional[cute.CopyAtom], - mQR: cute.Tensor, - tma_atom_c_latent: Optional[cute.CopyAtom], - mCL: cute.Tensor, - tma_atom_c_rope: Optional[cute.CopyAtom], - mKR: cute.Tensor, - tma_atom_c_latent_transpose: Optional[cute.CopyAtom], - mCLT: cute.Tensor, - mPT: cute.Tensor, - mO: Optional[cute.Tensor], - mLSE: Optional[cute.Tensor], - mAccO: Optional[cute.Tensor], - mAccLSE: Optional[cute.Tensor], - split_kv: cutlass.Int32, - cache_seqs: cute.Tensor, - block_split_kvs: cute.Tensor, - softmax_scale_log2: cutlass.Float32, - output_scale: cutlass.Float32, - q_latent_smem_layout_staged: cute.ComposedLayout, - 
q_rope_smem_layout_staged: cute.ComposedLayout, - kc_latent_smem_layout_staged: cute.ComposedLayout, - kc_rope_smem_layout_staged: cute.ComposedLayout, - p_smem_layout_staged: cute.ComposedLayout, - vc_smem_layout_staged: cute.ComposedLayout, - kc_latent_smem_layout_for_tma: Optional[cute.ComposedLayout], - kc_rope_smem_layout_for_tma: Optional[cute.ComposedLayout], - vc_smem_layout_for_tma: Optional[cute.ComposedLayout], - cta_layout_vmnk: cute.Layout, - tile_sched_params: MLAStaticTileSchedulerParams, - SharedStorage: cutlass.Constexpr, - ): - """The device split_kv kernel implementation of the Multi-Head Latent Attention. - - This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: - 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA - 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) - 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results - to global memory - - The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. - When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results - that will later be combined by a reduction kernel. - - The kernel implements a complex pipeline with overlapping computation and memory operations, - using tensor memory access (TMA) for efficient data loading, warp specialization for different - computation phases. 
- - :param tiled_mma_qk: Tiled MMA for Q*K^T - :type tiled_mma_qk: cute.TiledMma - :param tiled_mma_pv: Tiled MMA for P*V - :type tiled_mma_pv: cute.TiledMma - :param tma_atom_q_latent: TMA copy atom for query latent tensor - :type tma_atom_q_latent: cute.CopyAtom - :param mQL: query latent tensor - :type mQL: cute.Tensor - :param tma_atom_q_rope: TMA copy atom for query rope tensor - :type tma_atom_q_rope: cute.CopyAtom - :param mKR: Compressed rope tensor - :type mKR: cute.Tensor - :param tma_atom_c_latent: TMA copy atom for c latent tensor - :type tma_atom_c_latent: cute.CopyAtom - :param mCL: Compressed latent tensor - :type mCL: cute.Tensor - :param tma_atom_c_rope: TMA copy atom for c rope tensor - :type tma_atom_c_rope: cute.CopyAtom - :param mCLT: Compressed latent transpose tensor - :type mCLT: cute.Tensor - :param mPT: Page table tensor - :type mPT: cute.Tensor - :param mO: Output tensor - :type mO: cute.Tensor - :param mLSE: Log-sum-exp tensor - :type mLSE: cute.Tensor - :param mAccO: Intermediate accumulator output tensor - :type mAccO: cute.Tensor - :param mAccLSE: Intermediate accumulator log-sum-exp tensor - :type mAccLSE: cute.Tensor - :param split_kv: The split_kv parameter - :type split_kv: cutlass.Int32 - :param cache_seqs: The variable sequence length tensor - :type cache_seqs: cute.Tensor - :param block_split_kvs: The per-block split_kv values tensor - :type block_split_kvs: cute.Tensor - :param softmax_scale_log2: The log2 scale factor for softmax - :type softmax_scale_log2: cutlass.Float32 - :param output_scale: The scale factor for the output - :type output_scale: cutlass.Float32 - :param q_latent_smem_layout_staged: Shared memory layout for query tensor - :type q_latent_smem_layout_staged: cute.ComposedLayout - :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor - :type q_rope_smem_layout_staged: cute.ComposedLayout - :param kc_latent_smem_layout_staged: Shared memory layout for key tensor - :type 
kc_latent_smem_layout_staged: cute.ComposedLayout - :param kc_rope_smem_layout_staged: Shared memory layout for key rope tensor - :type kc_rope_smem_layout_staged: cute.ComposedLayout - :param p_smem_layout_staged: Shared memory layout for probability matrix - :type p_smem_layout_staged: cute.ComposedLayout - :param vc_smem_layout_staged: Shared memory layout for value tensor - :type vc_smem_layout_staged: cute.ComposedLayout - :param cta_layout_vmnk: Layout for compute threads - :type cta_layout_vmnk: cute.Layout - :param tile_sched_params: Scheduling parameters for work distribution - :type tile_sched_params: MLAStaticTileSchedulerParams - :param SharedStorage: Shared storage for the kernel - :type SharedStorage: cutlass.Constexpr - """ - - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) - is_leader_cta = mma_tile_coord_v == 0 - - # Prefetch tma descriptor - if warp_idx == self.mma_warp_id: - cpasync.prefetch_descriptor(tma_atom_q_latent) - cpasync.prefetch_descriptor(tma_atom_q_rope) - cpasync.prefetch_descriptor(tma_atom_c_latent) - cpasync.prefetch_descriptor(tma_atom_c_rope) - cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) - - # Alloc - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - # Tensor memory dealloc barrier init - tmem = utils.TmemAllocator( - storage.tmem_holding_buf, - barrier_for_retrieve=self.tmem_ptr_sync_bar, - allocator_warp_id=self.mma_warp_id, - is_two_cta=self.use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, - ) - - load_q_pipeline = self.make_and_init_load_qkv_pipeline( - storage.load_q_mbar_ptr.data_ptr(), - cta_layout_vmnk, - self.load_q_stage, - self.tma_copy_q_bytes, - ) - load_k_pipeline = self.make_and_init_load_qkv_pipeline( - storage.load_k_mbar_ptr.data_ptr(), - cta_layout_vmnk, - self.load_k_stage, - 
self.tma_copy_kc_bytes, - ) - load_v_pipeline = self.make_and_init_load_qkv_pipeline( - storage.load_v_mbar_ptr.data_ptr(), - cta_layout_vmnk, - self.load_v_stage, - self.tma_copy_vc_bytes, - ) - mma_s_pipeline = self.make_and_init_mma_s_pipeline( - storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - p_mma_pipeline = self.make_and_init_p_mma_pipeline( - storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - p_cor_pipeline = self.make_and_init_p_cor_pipeline( - storage.p_cor_mbar_ptr.data_ptr() - ) - mma_o_pipeline = self.make_and_init_mma_o_pipeline( - storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk - ) - - # Cluster arrive after barrier init - pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) - - # Generate smem tensor Q/KC/VC/exchange - # (MMA, MMA_H, MMA_R, PIPE) - sQ = storage.smem_q_latent.get_tensor( - q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner - ) - sQ_rope = storage.smem_q_rope.get_tensor( - q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner - ) - # (MMA, MMA_K, MMA_R, PIPE) - sKC = storage.smem_kc_latent.get_tensor( - kc_latent_smem_layout_staged.outer, - swizzle=kc_latent_smem_layout_staged.inner, - ) - sKC_rope = storage.smem_kc_rope.get_tensor( - kc_rope_smem_layout_staged.outer, swizzle=kc_rope_smem_layout_staged.inner - ) - sKC_for_tma = storage.smem_kc_latent.get_tensor( - kc_latent_smem_layout_for_tma.outer, - swizzle=kc_latent_smem_layout_for_tma.inner, - ) - sKC_rope_for_tma = storage.smem_kc_rope.get_tensor( - kc_rope_smem_layout_for_tma.outer, swizzle=kc_rope_smem_layout_for_tma.inner - ) - # (MMA, MMA_D, MMA_K, PIPE) - sVC = storage.smem_vc.get_tensor( - vc_smem_layout_staged.outer, swizzle=vc_smem_layout_staged.inner - ) - sVC_for_tma = storage.smem_vc.get_tensor( - vc_smem_layout_for_tma.outer, swizzle=vc_smem_layout_for_tma.inner - ) - # (MMA, MMA_H, MMA_K) - sP = storage.smem_p.get_tensor( - p_smem_layout_staged.outer, 
swizzle=p_smem_layout_staged.inner - ) - # (compute_threads,) - softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( - cute.make_layout(self.num_compute_warps * self.threads_per_warp) - ) - epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( - cute.make_layout(self.num_compute_warps * self.threads_per_warp) - ) - - # - # Cluster wait before tensor memory alloc - # - pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) - - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_wait() - - # /////////////////////////////////////////////////////////////////////////////// - # Load warps, including page table and data tensors - # /////////////////////////////////////////////////////////////////////////////// - if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: - _setmaxregister_decrease(self.other_reg_num) - - if warp_idx == self.load_tma_k_warp_id: - _setmaxregister_decrease(self.other_reg_num) - load_q_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_q_stage - ) - load_k_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_k_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, - cache_seqs, - block_split_kvs, - blk_coord, - ) - if k_tile_count > 0: - # Construct fixed common/tma_qk/tma_pv params for load_tma - tma_common_params = SimpleNamespace( - blk_coord=blk_coord, - local_split_kv=local_split_kv, - load_q_pipeline=load_q_pipeline, - load_k_pipeline=load_k_pipeline, - load_v_pipeline=load_v_pipeline, - mPT=mPT, - ) - tma_qk_params = SimpleNamespace( - tiled_mma_qk=tiled_mma_qk, - tma_atom_q_latent=tma_atom_q_latent, - 
tma_atom_q_rope=tma_atom_q_rope, - tma_atom_c_latent=tma_atom_c_latent, - tma_atom_c_rope=tma_atom_c_rope, - mQL=mQL, - mQR=mQR, - mCL=mCL, - mKR=mKR, - sQ=sQ, - sQ_rope=sQ_rope, - sKC=sKC_for_tma, - sKC_rope=sKC_rope_for_tma, - ) - # Load tma - load_q_producer_state, load_k_producer_state = self.load_tma_qk( - tma_common_params, - tma_qk_params, - k_index, - k_tile_count, - load_q_producer_state, - load_k_producer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - load_q_pipeline.producer_tail(load_q_producer_state) - load_k_pipeline.producer_tail(load_k_producer_state) - - if warp_idx == self.load_tma_v_warp_id: - _setmaxregister_decrease(self.other_reg_num) - load_v_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.load_v_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, - cache_seqs, - block_split_kvs, - blk_coord, - ) - if k_tile_count > 0: - # Construct fixed common/tma_qk/tma_pv params for load_tma - tma_common_params = SimpleNamespace( - blk_coord=blk_coord, - local_split_kv=local_split_kv, - load_v_pipeline=load_v_pipeline, - mPT=mPT, - ) - tma_pv_params = SimpleNamespace( - tiled_mma_pv=tiled_mma_pv, - tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, - mCLT=mCLT, - sVC=sVC_for_tma, - ) - # Load tma - load_v_producer_state = self.load_tma_v( - tma_common_params, - tma_pv_params, - k_index, - k_tile_count, - load_v_producer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - load_v_pipeline.producer_tail(load_v_producer_state) - - # /////////////////////////////////////////////////////////////////////////////// - # MMA warp - # 
/////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.mma_warp_id: - _setmaxregister_decrease(self.other_reg_num) - # Alloc tensor memory buffer - tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - load_q_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_q_stage - ) - load_k_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_k_stage - ) - load_v_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_v_stage - ) - mma_s_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_s_stage - ) - p_mma_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.p_mma_stage - ) - mma_o_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_o_stage - ) - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count > 0: - mma_common_params = SimpleNamespace( - blk_coord=blk_coord, - local_split_kv=local_split_kv, - load_q_pipeline=load_q_pipeline, - load_k_pipeline=load_k_pipeline, - load_v_pipeline=load_v_pipeline, - tmem_ptr=tmem_ptr, - is_leader_cta=is_leader_cta, - L=mCL.shape[1], - ) - mma_qk_params = SimpleNamespace( - mma_s_pipeline=mma_s_pipeline, - sQ=sQ, - sQ_rope=sQ_rope, - sKC=sKC, - sKC_rope=sKC_rope, - ) - mma_pv_params = SimpleNamespace( - p_mma_pipeline=p_mma_pipeline, - mma_o_pipeline=mma_o_pipeline, - sP=sP, - sVC=sVC, - ) - ( - tiled_mma_qk, - tiled_mma_pv, - load_q_consumer_state, - 
load_k_consumer_state, - load_v_consumer_state, - mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) = self.mma( - mma_common_params, - mma_qk_params, - mma_pv_params, - k_tile_count, - tiled_mma_qk, - tiled_mma_pv, - load_q_consumer_state, - load_k_consumer_state, - load_v_consumer_state, - mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - mma_s_pipeline.producer_tail(mma_s_producer_state) - mma_o_pipeline.producer_tail(mma_o_producer_state) - - tmem.relinquish_alloc_permit() - tmem.free(tmem_ptr) - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_launch_dependents() - - # /////////////////////////////////////////////////////////////////////////////// - # Compute warp - # /////////////////////////////////////////////////////////////////////////////// - if ( - warp_idx >= self.compute_warp_ids[0] - and warp_idx <= self.compute_warp_ids[-1] - ): - _setmaxregister_increase(self.softmax_reg_num) - mma_s_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_s_stage - ) - p_mma_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.p_mma_stage - ) - p_cor_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.p_cor_stage - ) - mma_o_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_o_stage - ) - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count > 0: - compute_common_params = 
SimpleNamespace( - blk_coord=blk_coord, - split_kv=split_kv, - local_split_kv=local_split_kv, - smem_exchange=softmax_smem_exchange, - mAccO=mAccO, - mO=mO, - K=cache_seqs[blk_coord[2]], - L=mCL.shape[1], - tmem_ptr=tmem_ptr, - tidx=tidx, - p_cor_pipeline=p_cor_pipeline, - ) - compute_softmax_params = SimpleNamespace( - tiled_mma_qk=tiled_mma_qk, - sP=sP, - mma_s_pipeline=mma_s_pipeline, - p_mma_pipeline=p_mma_pipeline, - softmax_scale_log2=softmax_scale_log2, - ) - mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( - self.compute( - compute_common_params, - compute_softmax_params, - k_index=k_index, - k_tile_count=k_tile_count, - mma_s_consumer_state=mma_s_consumer_state, - p_mma_producer_state=p_mma_producer_state, - p_cor_producer_state=p_cor_producer_state, - ) - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - p_cor_pipeline.producer_tail(p_cor_producer_state) - - # /////////////////////////////////////////////////////////////////////////////// - # Correction warp - # /////////////////////////////////////////////////////////////////////////////// - if ( - warp_idx >= self.correction_warp_ids[0] - and warp_idx <= self.correction_warp_ids[-1] - ): - _setmaxregister_increase(self.correction_reg_num) - p_cor_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.p_cor_stage - ) - mma_o_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.mma_o_stage - ) - # sync with mma warp before retrieving tmem ptr - tmem.wait_for_alloc() - - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - - tile_sched = create_mla_static_tile_scheduler( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - blk_coord = work_tile.tile_idx - k_index, k_tile_count, local_split_kv = self.get_k_tile_count( - split_kv, cache_seqs, block_split_kvs, blk_coord - ) - if k_tile_count 
> 0: - compute_common_params = SimpleNamespace( - blk_coord=blk_coord, - split_kv=split_kv, - local_split_kv=local_split_kv, - smem_exchange=epilogue_smem_exchange, - mAccO=mAccO, - mO=mO, - K=cache_seqs[blk_coord[2]], - L=mCL.shape[1], - H=mQL.shape[0], - tmem_ptr=tmem_ptr, - tidx=tidx, - tiled_mma_pv=tiled_mma_pv, - p_cor_pipeline=p_cor_pipeline, - mma_o_pipeline=mma_o_pipeline, - ) - compute_epilogue_params = SimpleNamespace( - output_scale=output_scale, - softmax_scale_log2=softmax_scale_log2, - mAccLSE=mAccLSE, - mLSE=mLSE, - ) - p_cor_consumer_state, mma_o_consumer_state = self.correction( - compute_common_params, - compute_epilogue_params, - k_tile_count=k_tile_count, - p_cor_consumer_state=p_cor_consumer_state, - mma_o_consumer_state=mma_o_consumer_state, - ) - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - return - - @cute.kernel - def reduction_kernel( - self, - mO: cute.Tensor, - mLSE: cute.Tensor, - mAccO: cute.Tensor, - mAccLSE: cute.Tensor, - split_kv: cutlass.Int32, - cache_seqs: cute.Tensor, - block_split_kvs: cute.Tensor, - ): - """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results - from multiple split_kv blocks into final outputs. 
- - :param mO: Output tensor for storing final results - :type mO: cute.Tensor - :param mLSE: Log-sum-exp tensor for storing final LSE values - :type mLSE: cute.Tensor - :param mAccO: Accumulated output tensor from split_kv blocks - :type mAccO: cute.Tensor - :param mAccLSE: Accumulated LSE tensor from split_kv blocks - :type mAccLSE: cute.Tensor - :param split_kv: Number of split_kv blocks - :type split_kv: cutlass.Int32 - :param cache_seqs: Cache sequence lengths tensor - :type cache_seqs: cute.Tensor - :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) - :type block_split_kvs: cute.Tensor - """ - bidx, bidy, bidz = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - blk_coord = (bidx, bidy, bidz) - local_split_kv = ( - block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv - ) - k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) - k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) - local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) - - # Alloc shared memory - smem = utils.SmemAllocator() - storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) - lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) - smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) - - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_wait() - gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if warp_idx == 0: - # calculate the global lse and exp ^ (local_lse - global_lse) - lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) - - local_lse = cute.make_rmem_tensor( - cute.make_layout(lse_per_thread), self.lse_dtype - ) - lse_max = -self.lse_dtype.inf - # find the max lse - for i in cutlass.range_constexpr(lse_per_thread): - split_kv_idx = tidx + i * self.threads_per_warp - local_lse[i] = ( - gLSE[split_kv_idx] - if 
cute.elem_less(split_kv_idx, local_split_kv) - else -self.lse_dtype.inf - ) - # reduce the local lse - lse_max = cute.arch.fmax(lse_max, local_lse[i]) - lse_max = cute.arch.warp_reduction_max(lse_max) - lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 - # calculate sum_lse - sum_lse = 0.0 - for i in cutlass.range_constexpr(lse_per_thread): - sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) - sum_lse = cute.arch.warp_reduction_sum(sum_lse) - # calculate the global_lse - global_lse = ( - lse_max + cute.math.log2(sum_lse, fastmath=True) - if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 - else self.lse_dtype.inf - ) - if tidx == 0: - mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse - # store the scale to shared memory - for i in cutlass.range_constexpr(lse_per_thread): - split_kv_idx = tidx + i * self.threads_per_warp - if cute.elem_less(split_kv_idx, local_split_kv): - smem_lse_scale[split_kv_idx] = cute.math.exp2( - local_lse[i] - global_lse, fastmath=True - ) - - pipeline.sync(barrier_id=4) - - elements_per_thread = cute.ceil_div( - self.latent_dim, self.threads_per_warp * self.num_compute_warps - ) - gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] - rAccO = cute.make_rmem_tensor( - cute.make_layout(elements_per_thread), self.acc_dtype - ) - rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) - rAccO.fill(0.0) - for i in range(local_split_kv): - for j in cutlass.range_constexpr(elements_per_thread): - element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps - rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] - rO.store(rAccO.load().to(self.o_dtype)) - for j in cutlass.range_constexpr(elements_per_thread): - element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps - mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] - if cutlass.const_expr(self.enable_pdl): - cute.arch.griddepcontrol_launch_dependents() - 
return - - @staticmethod - def get_split_kv( - B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int - ) -> int: - """Get the proper split_kv value for the MLA kernel based on parameters. - - :param B: Batch size - :type B: int - :param S: Sequence length - :type S: int - :param K: Sequence length - :type K: int - :param mma_qk_tiler_mn: MLA tiling parameters - :type mma_qk_tiler_mn: tuple - :param max_active_blocks: Maximum number of active blocks - :type max_active_blocks: int - :return: Split_kv value - :rtype: int - """ - max_splits = ceil_div(K, mma_qk_tiler_mn[1]) - blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) - split_heur = min(max_splits, blocks_per_batch) - # {$nv-internal-release begin} - # TODO: figure out the error of make_tile with dynamic int_tuple - # {$nv-internal-release end} - k_waves = ceil_div(max_splits, split_heur) - split_wave_aware = ceil_div(max_splits, k_waves) - max_split_kv = 32 - return min(split_wave_aware, max_split_kv) - - @staticmethod - def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: - blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) - max_split_kv = 32 - return min(blocks_per_batch, max_split_kv) - - @cute.jit - def get_k_tile_count( - self, - split_kv: cutlass.Int32, - cache_seqs: cute.Tensor, - block_split_kvs: cute.Tensor, - blk_coord: cute.Coord, - ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: - """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. 
- - :param split_kv: Split_kv value - :type split_kv: cutlass.Int32 - :param cache_seqs: Cache sequence lengths tensor - :type cache_seqs: cute.Tensor - :param block_split_kvs: Per-block split_kv values tensor - :type block_split_kvs: cute.Tensor - :param blk_coord: Block coordinate - :type blk_coord: cute.Coord - :return: k_index, k_tile_count, split_kv - :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] - """ - K = cache_seqs[blk_coord[2]] - if cutlass.const_expr(self.is_var_split_kv): - split_kv = block_split_kvs[blk_coord[2]] - - k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) - # {$nv-internal-release begin} - # TODO: figure out the error of make_tile with dynamic int_tuple - # {$nv-internal-release end} - k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) - k_index = blk_coord[3] * k_tile_per_cta - k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) - return k_index, k_tile_count, split_kv - - @cute.jit - def load_tma_qk( - self, - common_params: SimpleNamespace, - qk_params: SimpleNamespace, - k_index: cutlass.Int32, - k_tile_count: cutlass.Int32, - load_q_producer_state: pipeline.PipelineState | None = None, - load_k_producer_state: pipeline.PipelineState | None = None, - ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: - """Load wrap to load Q/K tensors. Updates the load qk producer state. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param qk_params: The qk parameters - :type qk_params: SimpleNamespace - :param k_index: The k index - :type k_index: cutlass.Int32 - :param k_tile_count: The k tile count - :type k_tile_count: cutlass.Int32 - :param load_q_producer_state: The load q producer state - :type load_q_producer_state: pipeline.PipelineState - :param load_k_producer_state: The load k producer state - :type load_k_producer_state: pipeline.PipelineState - - :return: The load q producer state and load k producer state - :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] - """ - # page table - mPT = common_params.mPT[None, common_params.blk_coord[2]] - - # Flatten divide and partition global tensors for QK TMA load - # (bM, bK, rM, rK, rL) - mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) - gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) - mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) - gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) - - thr_mma_qk = qk_params.tiled_mma_qk.get_slice( - common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) - ) - tSgQL = thr_mma_qk.partition_A(gQL) - tSgQR = thr_mma_qk.partition_A(gQR) - - cta_m = min( - qk_params.tiled_mma_qk.op.shape_mnk[0] - // qk_params.tiled_mma_qk.thr_id.shape, - self.page_size, - ) - page_tile_size = min(self.page_size, cta_m) - gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) - tSgCL = ( - gCL[ - None, - common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, - None, - None, - ] - if cta_m < self.page_size - else gCL[None, 0, None, None] - ) - gKR = cute.tiled_divide( - qk_params.mKR, (page_tile_size, self.mma_qk_rope_tiler[2]) - ) - tSgKR = ( - gKR[ - None, - common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, - None, - None, - ] - if cta_m < self.page_size - else gKR[None, 0, None, None] - ) - # tma partition 
for q, k latent/rope - - # smem: ((atom_v, rest_v), STAGE) - # gmem: ((atom_v, rest_v), RestM, RestK, RestL) - tQsQ, tQLgQL_mkl = cpasync.tma_partition( - qk_params.tma_atom_q_latent, - 0, - cute.make_layout(1), - cute.group_modes(qk_params.sQ, 0, 3), - cute.group_modes(tSgQL, 0, 3), - ) - - tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( - qk_params.tma_atom_q_rope, - 0, - cute.make_layout(1), - cute.group_modes(qk_params.sQ_rope, 0, 3), - cute.group_modes(tSgQR, 0, 3), - ) - tKCsKC, tCLgCL = cpasync.tma_partition( - qk_params.tma_atom_c_latent, - 0, - cute.make_layout(1), - qk_params.sKC, - tSgCL, - ) - - tKCsKC_rope, tKRgKR = cpasync.tma_partition( - qk_params.tma_atom_c_rope, - 0, - cute.make_layout(1), - qk_params.sKC_rope, - tSgKR, - ) - - tQLgQL = tQLgQL_mkl[ - None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] - ] - tQRgQR = tQRgQR_mkl[ - None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] - ] - - # set extra params - common_params.mPT = mPT - qk_params.tQLgQL = tQLgQL - qk_params.tQRgQR = tQRgQR - qk_params.tCLgCL = tCLgCL - qk_params.tKRgKR = tKRgKR - qk_params.tQsQ = tQsQ - qk_params.tQsQ_rope = tQsQ_rope - qk_params.tKCsKC = tKCsKC - qk_params.tKCsKC_rope = tKCsKC_rope - - k_tile_count_init = k_tile_count - while k_tile_count > 0: - # {$nv-internal-release begin} - # TODO: figure out how to support SingleNamespace/struct in ast - # {$nv-internal-release end} - load_q_producer_state, load_k_producer_state = self.load_tma_qk_one_k_tile( - common_params, - qk_params, - k_index, - k_tile_count, - load_q_producer_state, - load_k_producer_state, - load_q=k_tile_count_init == k_tile_count, - ) - k_index += 1 - k_tile_count -= 1 - - return load_q_producer_state, load_k_producer_state - - @cute.jit - def load_tma_v( - self, - common_params: SimpleNamespace, - v_params: SimpleNamespace, - k_index: cutlass.Int32, - k_tile_count: cutlass.Int32, - load_v_producer_state: pipeline.PipelineState, - ) -> pipeline.PipelineState: 
- """Load wrap to load V tensors. Updates the load v producer state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param v_params: The v parameters - :type v_params: SimpleNamespace - :param k_index: The k index - :type k_index: cutlass.Int32 - :param k_tile_count: The k tile count - :type k_tile_count: cutlass.Int32 - :param load_v_producer_state: The load v producer state - :type load_v_producer_state: pipeline.PipelineState - - :return: The load v producer state - :rtype: pipeline.PipelineState - """ - # page table - mPT = common_params.mPT[None, common_params.blk_coord[2]] - - # Flatten divide and partition global tensors for V TMA load - page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) - gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) - cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape - gCLT = cute.logical_divide(gCLT, (cta_n,))[ - (None, common_params.blk_coord[0]), None, None, None, None - ] - tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) - tOgCLT = tOgCLT[None, 0, 0, None, None, None] - # tma partition for vc - # smem: ((atom_v, rest_v), STAGE) - # gmem: ((atom_v, rest_v), RestM, RestK, RestL) - tVCsVC, tCLTgCLT = cpasync.tma_partition( - v_params.tma_atom_c_latent_transpose, - 0, - cute.make_layout(1), - v_params.sVC, - tOgCLT, - ) - - # set extra params - common_params.mPT = mPT - v_params.tCLTgCLT = tCLTgCLT - v_params.tVCsVC = tVCsVC - - while k_tile_count > 0: - # {$nv-internal-release begin} - # TODO: figure out how to support SingleNamespace/struct in ast - # {$nv-internal-release end} - load_v_producer_state = self.load_tma_v_one_k_tile( - common_params, - v_params, - k_index, - load_v_producer_state, - ) - k_index += 1 - k_tile_count -= 1 - return load_v_producer_state - - @cute.jit - def load_tma_qk_one_k_tile( - self, - common_params: SimpleNamespace, - qk_params: SimpleNamespace, - k_index: cutlass.Int32, - k_tile_count: 
cutlass.Int32, - load_q_producer_state: pipeline.PipelineState, - load_k_producer_state: pipeline.PipelineState, - load_q: bool, - ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: - """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param qk_params: The qk parameters - :type qk_params: SimpleNamespace - :param k_index: The k index - :type k_index: cutlass.Int32 - :param k_tile_count: The k tile count - :type k_tile_count: cutlass.Int32 - :param load_q_producer_state: The load q producer state - :type load_q_producer_state: pipeline.PipelineState - :param load_k_producer_state: The load kv producer state - :type load_k_producer_state: pipeline.PipelineState - :param load_q: Whether to load q - :type load_q: bool - - :return: The load q producer state and load kv producer state - :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] - """ - page_per_tile = ceil_div( - self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape - ) - k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) - for i in cutlass.range_constexpr(page_per_tile): - k_idx[i] = ( - common_params.mPT[k_index] - if self.mma_qk_tiler[1] // self.page_size == 1 - else common_params.mPT[ - ( - k_index * qk_params.tiled_mma_qk.thr_id.shape - + common_params.blk_coord[0] - ) - * page_per_tile - + i - ] - ) - # load q once at first iteration - load_q_pipeline = common_params.load_q_pipeline - if load_q: - # get the mbar ptr from pipeline. - tma_bar_ptr = load_q_pipeline.producer_get_barrier(load_q_producer_state) - # expect the extra bytes for q. 
- load_q_pipeline.producer_acquire(load_q_producer_state) - for i in cutlass.range_constexpr(self.iterations_qk_latent): - # load q latent - cute.copy( - qk_params.tma_atom_q_latent, - qk_params.tQLgQL[None, 0, i], - qk_params.tQsQ[None, (i, 0)], - tma_bar_ptr=tma_bar_ptr, - ) - for i in cutlass.range_constexpr(self.iterations_qk_rope): - # load q rope - cute.copy( - qk_params.tma_atom_q_rope, - qk_params.tQRgQR[None, 0, i], - qk_params.tQsQ_rope[None, i], - tma_bar_ptr=tma_bar_ptr, - ) - load_q_producer_state.advance() - # get the mbar ptr from pipeline. - tma_bar_ptr = common_params.load_k_pipeline.producer_get_barrier( - load_k_producer_state - ) - common_params.load_k_pipeline.producer_acquire(load_k_producer_state) - for i in range(self.iterations_qk_latent): - for k in range(page_per_tile): - # load k latent - cute.copy( - qk_params.tma_atom_c_latent, - qk_params.tCLgCL[None, i, k_idx[k]], - qk_params.tKCsKC[None, k, 0, (i, load_k_producer_state.index)], - tma_bar_ptr=tma_bar_ptr, - ) - - for i in cutlass.range_constexpr(self.iterations_qk_rope): - for k in cutlass.range_constexpr(page_per_tile): - # load k rope - cute.copy( - qk_params.tma_atom_c_rope, - qk_params.tKRgKR[None, i, k_idx[k]], - qk_params.tKCsKC_rope[None, k, 0, load_k_producer_state.index], - tma_bar_ptr=tma_bar_ptr, - ) - load_k_producer_state.advance() - - return load_q_producer_state, load_k_producer_state - - @cute.jit - def load_tma_v_one_k_tile( - self, - common_params: SimpleNamespace, - v_params: SimpleNamespace, - k_index: cutlass.Int32, - load_v_producer_state: pipeline.PipelineState, - ) -> pipeline.PipelineState: - """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param v_params: The load tma v parameters - :type v_params: SimpleNamespace - :param k_index: The k index - :type k_index: cutlass.Int32 - :param load_v_producer_state: The load v producer state - :type load_v_producer_state: pipeline.PipelineState - - :return: The load qkv producer state - :rtype: pipeline.PipelineState - """ - page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size - page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) - k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) - for i in cutlass.range_constexpr(page_per_tile): - k_idx[i] = ( - common_params.mPT[k_index] - if page_per_tile == 1 - else common_params.mPT[k_index * page_per_tile + i] - ) - # get the mbar ptr from pipeline. - tma_bar_ptr = common_params.load_v_pipeline.producer_get_barrier( - load_v_producer_state - ) - common_params.load_v_pipeline.producer_acquire(load_v_producer_state) - for j in cutlass.range_constexpr(self.iterations_pv_n): - for i in cutlass.range_constexpr(self.iterations_pv_k): - if cutlass.const_expr(page_per_tile > 1): - for k in cutlass.range_constexpr(page_per_subtile): - k_idx_i = k_idx[k + i * page_per_subtile] - cute.copy( - v_params.tma_atom_c_latent_transpose, - v_params.tCLTgCLT[None, j, 0, k_idx_i], - v_params.tVCsVC[ - None, 0, k, ((j, i), load_v_producer_state.index) - ], - tma_bar_ptr=tma_bar_ptr, - ) - else: - cute.copy( - v_params.tma_atom_c_latent_transpose, - v_params.tCLTgCLT[None, j, i, k_idx[0]], - v_params.tVCsVC[ - None, 0, 0, ((j, i), load_v_producer_state.index) - ], - tma_bar_ptr=tma_bar_ptr, - ) - load_v_producer_state.advance() - return load_v_producer_state - - @cute.jit - def mma( - self, - common_params: SimpleNamespace, - qk_params: SimpleNamespace, - pv_params: SimpleNamespace, - k_tile_count: cutlass.Int32, - tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, - 
load_q_consumer_state: pipeline.PipelineState, - load_k_consumer_state: pipeline.PipelineState, - load_v_consumer_state: pipeline.PipelineState, - mma_s_producer_state: pipeline.PipelineState, - p_mma_consumer_state: pipeline.PipelineState, - mma_o_producer_state: pipeline.PipelineState, - ) -> tuple[ - cute.TiledMma, - cute.TiledMma, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - ]: - """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. - - :param common_params: The common parameters for mma qk and pv - :type common_params: SimpleNamespace - :param qk_params: The mma qk parameters - :type qk_params: SimpleNamespace - :param pv_params: The mma pv parameters - :type pv_params: SimpleNamespace - :param k_tile_count: The k tile count - :type k_tile_count: cutlass.Int32 - :param tiled_mma_qk: The tiled mma qk - :type tiled_mma_qk: cute.TiledMma - :param tiled_mma_pv: The tiled mma pv - :type tiled_mma_pv: cute.TiledMma - :param load_q_consumer_state: The load q consumer state - :type load_q_consumer_state: pipeline.PipelineState - :param load_k_consumer_state: The load k consumer state - :type load_k_consumer_state: pipeline.PipelineState - :param load_v_consumer_state: The load v consumer state - :type load_v_consumer_state: pipeline.PipelineState - :param mma_s_producer_state: The mma s producer state - :type mma_s_producer_state: pipeline.PipelineState - :param p_mma_consumer_state: The p mma consumer state - :type p_mma_consumer_state: pipeline.PipelineState - :param mma_o_producer_state: The mma o producer state - :type mma_o_producer_state: pipeline.PipelineState - - :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load k consumer state, the load v consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state - :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, 
pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] - """ - - tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) - tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) - tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) - tSrKC_rope = tiled_mma_qk.make_fragment_B(qk_params.sKC_rope) - tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) - tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) - - tStS_shape = tiled_mma_qk.partition_shape_C( - cute.select(self.mma_qk_tiler, mode=[0, 1]) - ) - tStS_staged_fake = tiled_mma_qk.make_fragment_C( - cute.append(tStS_shape, self.mma_s_stage) - ) - # use real tmem ptr for tStS - tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) - tOtO_shape = tiled_mma_pv.partition_shape_C( - cute.select(self.mma_pv_tiler, mode=[0, 1]) - ) - # mma O has 1 stage. - tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) - tOtO_layout = cute.append( - tOtO.layout, - cute.make_layout( - common_params.L // self.mma_pv_tiler[1], - stride=self.mma_pv_tiler[1] // self.warps_in_n, - ), - ) - tOtO_staged = cute.make_tensor( - tStS_staged.iterator + self.tmem_o_offset, tOtO_layout - ) - - # set more parameters - qk_params.tSrQ = tSrQ - qk_params.tSrQ_rope = tSrQ_rope - qk_params.tSrKC = tSrKC - qk_params.tSrKC_rope = tSrKC_rope - qk_params.tStS_staged = tStS_staged - pv_params.tOrP = tOrP - pv_params.tOrVC = tOrVC - pv_params.tOtO_staged = tOtO_staged - - # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. 
- tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) - load_q_pipeline = common_params.load_q_pipeline - if common_params.is_leader_cta: - load_q_release_state = load_q_consumer_state.clone() - ( - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - ) = self.mma_qk( - common_params, - qk_params, - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - wait_q=True, - ) - k_tile_count -= 1 - - while k_tile_count > 0: - ( - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - ) = self.mma_qk( - common_params, - qk_params, - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - wait_q=False, - ) - ( - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) = self.mma_pv( - common_params, - pv_params, - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - k_tile_count -= 1 - # release q consumer states - load_q_pipeline.consumer_release(load_q_release_state) - load_q_release_state.advance() - ( - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) = self.mma_pv( - common_params, - pv_params, - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - - return ( # type: ignore[return-value] - tiled_mma_qk, - tiled_mma_pv, - load_q_consumer_state, - load_k_consumer_state, - load_v_consumer_state, - mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - - @cute.jit - def mma_qk( - self, - common_params: SimpleNamespace, - qk_params: SimpleNamespace, - tiled_mma_qk: cute.TiledMma, - load_q_consumer_state: pipeline.PipelineState, - load_k_consumer_state: pipeline.PipelineState, - mma_s_producer_state: pipeline.PipelineState, - wait_q: bool, - ) -> tuple[ - cute.TiledMma, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - ]: - """Compute 
one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. - - :param qk_params: The qk parameters - :type qk_params: SimpleNamespace - :param tiled_mma_qk: The tiled mma qk - :type tiled_mma_qk: cute.TiledMma - :param load_q_consumer_state: The load q consumer state - :type load_q_consumer_state: pipeline.PipelineState - :param load_k_consumer_state: The load k consumer state - :type load_k_consumer_state: pipeline.PipelineState - :param mma_s_producer_state: The mma s producer state - :type mma_s_producer_state: pipeline.PipelineState - - :return: The tiled mma qk, the load q consumer state, the load k consumer state, and the mma s producer state - :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] - """ - tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] - - qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) - tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) - load_q_pipeline = common_params.load_q_pipeline - load_k_pipeline = common_params.load_k_pipeline - if cutlass.const_expr(wait_q): - load_q_pipeline.consumer_wait(load_q_consumer_state) - load_k_pipeline.consumer_wait(load_k_consumer_state) - for q_stage in range(self.iterations_qk_latent): - kc_stage = load_k_consumer_state.index - for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): - cute.gemm( - tiled_mma_qk, - tStS, - qk_params.tSrQ[None, None, k_block, (q_stage, 0)], - qk_params.tSrKC[None, None, k_block, (q_stage, kc_stage)], - tStS, - ) - tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) - - for q_stage in range(self.iterations_qk_rope): - kc_stage = load_k_consumer_state.index - for k_block in cutlass.range_constexpr( - self.rope_dim // tiled_mma_qk.shape_mnk[2] - ): - cute.gemm( - tiled_mma_qk, - tStS, - qk_params.tSrQ_rope[None, None, k_block, q_stage], - qk_params.tSrKC_rope[None, None, k_block, kc_stage], - tStS, - ) - tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) - 
load_k_pipeline.consumer_release(load_k_consumer_state) - load_k_consumer_state.advance() - if cutlass.const_expr(wait_q): - load_q_consumer_state.advance() - - qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) - mma_s_producer_state.advance() - return ( - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - ) - - @cute.jit - def mma_pv( - self, - common_params: SimpleNamespace, - pv_params: SimpleNamespace, - tiled_mma_pv: cute.TiledMma, - load_v_consumer_state: pipeline.PipelineState, - p_mma_consumer_state: pipeline.PipelineState, - mma_o_producer_state: pipeline.PipelineState, - ) -> tuple[ - cute.TiledMma, - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - ]: - """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param pv_params: The pv parameters - :type pv_params: SimpleNamespace - :param tiled_mma_pv: The tiled mma pv - :type tiled_mma_pv: cute.TiledMma - :param load_v_consumer_state: The load v consumer state - :type load_v_consumer_state: pipeline.PipelineState - :param p_mma_consumer_state: The P MMA consumer state - :type p_mma_consumer_state: pipeline.PipelineState - :param mma_o_producer_state: The MMA o producer state - :type mma_o_producer_state: pipeline.PipelineState - - :return: The tiled mma pv, the load v consumer state, the P MMA consumer state, and the MMA o producer state - :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] - """ - - pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) - load_v_pipeline = common_params.load_v_pipeline - accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) - mma_o_pipeline = pv_params.mma_o_pipeline - - load_v_pipeline.consumer_wait(load_v_consumer_state) - vc_stage = load_v_consumer_state.index - for acc_stage in range(self.iterations_pv_n): - 
mma_o_pipeline.producer_acquire(mma_o_producer_state) - tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) - for p_stage in range(self.iterations_pv_k): - tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] - for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): - cute.gemm( - tiled_mma_pv, - tOtO, - pv_params.tOrP[ - None, - None, - k_block, - (p_stage, p_mma_consumer_state.index), - ], - pv_params.tOrVC[ - None, None, k_block, ((acc_stage, p_stage), vc_stage) - ], - tOtO, - ) - tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) - - mma_o_pipeline.producer_commit(mma_o_producer_state) - mma_o_producer_state.advance() - load_v_pipeline.consumer_release(load_v_consumer_state) - load_v_consumer_state.advance() - pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) - p_mma_consumer_state.advance() - - return ( - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) - - @cute.jit - def compute( - self, - common_params: SimpleNamespace, - softmax_params: SimpleNamespace, - k_index: cutlass.Int32, - k_tile_count: cutlass.Int32, - mma_s_consumer_state: pipeline.PipelineState, - p_mma_producer_state: pipeline.PipelineState, - p_cor_producer_state: pipeline.PipelineState, - ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: - """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param softmax_params: The softmax parameters - :type softmax_params: SimpleNamespace - :param k_index: The index of the k-tile - :type k_index: cutlass.Int32 - :param k_tile_count: The number of k-tiles - :type k_tile_count: cutlass.Int32 - :param mma_s_consumer_state: The MMA s consumer state - :type mma_s_consumer_state: pipeline.PipelineState - :param p_mma_producer_state: The P MMA producer state - :type p_mma_producer_state: pipeline.PipelineState - :param p_cor_producer_state: The P correction producer state - :type p_cor_producer_state: pipeline.PipelineState - - :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state - :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] - """ - - k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) - - row_max = -self.acc_dtype.inf - row_sum = self.acc_dtype(0) - correction_factor = self.acc_dtype(1) - common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - - # no mask applied - while k_tile_count > 1: - ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - ) = self.softmax( - common_params, - softmax_params, - k_index, - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - False, - False, - ) - k_index = k_index + 1 - k_tile_count = k_tile_count - 1 - - # mask applied - if cutlass.const_expr(common_params.mAccO is not None): - ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - ) = self.softmax( - common_params, - softmax_params, - k_index, - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - k_index == k_tile_total - 1, - True, - ) - else: - ( - mma_s_consumer_state, - 
p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - ) = self.softmax( - common_params, - softmax_params, - k_index, - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - True, - True, - ) - - return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state - - @cute.jit - def correction( - self, - common_params: SimpleNamespace, - epilogue_params: SimpleNamespace, - k_tile_count: cutlass.Int32, - p_cor_consumer_state: pipeline.PipelineState, - mma_o_consumer_state: pipeline.PipelineState, - ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: - """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param epilogue_params: The epilogue parameters - :type epilogue_params: SimpleNamespace - :param k_index: The index of the k-tile - :type k_index: cutlass.Int32 - :param k_tile_count: The number of k-tiles - :type k_tile_count: cutlass.Int32 - :param p_cor_consumer_state: The P correction consumer state - :type p_cor_consumer_state: pipeline.PipelineState - :param mma_o_consumer_state: The MMA o consumer state - :type mma_o_consumer_state: pipeline.PipelineState - - :return: The P correction consumer state, and the MMA o consumer state - :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] - """ - - k_tile_count_init = k_tile_count - while k_tile_count > 0: - p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( - self.get_correction_factor(common_params, p_cor_consumer_state) - ) - if k_tile_count_init != k_tile_count: - mma_o_consumer_state = self.rescale( - common_params, - mma_o_consumer_state, - correction_factor, - no_correction, - ) - k_tile_count = k_tile_count - 1 - if k_tile_count == 0: - mma_o_consumer_state = self.epilogue( - common_params, - epilogue_params, - 
mma_o_consumer_state, - row_sum, - row_max, - ) - return p_cor_consumer_state, mma_o_consumer_state - - @cute.jit - def exchange_p_cor_metadata( - self, - common_params: SimpleNamespace, - softmax_params: SimpleNamespace, - correction_factor: cutlass.Float32, - row_sum: cutlass.Float32, - row_max: cutlass.Float32, - row_max_new: cutlass.Float32, - tAcc: cute.Tensor, - tidx: cutlass.Int32, - p_cor_producer_state: pipeline.PipelineState, - ) -> tuple[pipeline.PipelineState, cutlass.Float32]: - """Compute the correction factor for the last k tile.""" - no_correction = 0 - if ( - row_max_new - row_max - ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: - no_correction = 1 - row_max_new = row_max - - # pad for 4x32b - corr_layout = cute.make_layout( - (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), - stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), - ) - tCor = cute.make_tensor( - common_params.tmem_ptr + self.correction_factor_offset, - corr_layout, - ) - cCor = cute.make_identity_tensor(tCor.shape) - corr_tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype - ) - corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) - corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) - cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) - tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) - rCor = cute.make_fragment_like( - cCor_for_copy[None, None, None, 0], self.acc_dtype - ) - rCor_int = cute.make_tensor( - cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout - ) - rCor[0] = row_sum - rCor[1] = row_max_new - rCor[2] = correction_factor - rCor_int[3] = no_correction - - cute.copy( - corr_tmem_store_tiled_copy, - rCor, - tCor_for_copy[None, None, None, p_cor_producer_state.index], - ) - # fence between tmem store and correction warp - cute.arch.fence_view_async_tmem_store() - 
common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) - p_cor_producer_state.advance() - return p_cor_producer_state, row_max_new - - @cute.jit - def softmax( - self, - common_params: SimpleNamespace, - softmax_params: SimpleNamespace, - k_index: cutlass.Int32, - mma_s_consumer_state: pipeline.PipelineState, - p_mma_producer_state: pipeline.PipelineState, - p_cor_producer_state: pipeline.PipelineState, - row_max: cutlass.Float32, - row_sum: cutlass.Float32, - correction_factor: cutlass.Float32, - is_last_tile: bool, - is_local_last_tile: cutlass.Boolean, - ) -> tuple[ - pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, - cutlass.Float32, - cutlass.Float32, - cutlass.Float32, - ]: - """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param softmax_params: The softmax parameters - :type softmax_params: SimpleNamespace - :param k_index: The index of the k-tile - :type k_index: cutlass.Int32 - :param mma_s_consumer_state: The MMA s consumer state - :type mma_s_consumer_state: pipeline.PipelineState - :param p_mma_producer_state: The P MMA producer state - :type p_mma_producer_state: pipeline.PipelineState - :param p_cor_producer_state: The P correction producer state - :type p_cor_producer_state: pipeline.PipelineState - :param row_max: The row max - :type row_max: cutlass.Float32 - :param row_sum: The row sum - :type row_sum: cutlass.Float32 - :param correction_factor: The correction factor - :type correction_factor: cutlass.Float32 - :param is_last_tile: Whether the last tile - :type is_last_tile: bool - :param is_local_last_tile: Whether the last tile is local - :type is_local_last_tile: cutlass.Boolean - - :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor - :rtype: tuple[pipeline.PipelineState, 
pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] - """ - - softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) - softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) - - # load S from tmem - tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( - cute.select(self.mma_qk_tiler, mode=[0, 1]) - ) - tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( - cute.append(tStS_shape, self.mma_s_stage) - ) - tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) - tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] - - tAcc = tStS[(None, None), 0, 0] - cta_qk_tiler = ( - self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], - self.mma_qk_tiler[1], - self.mma_qk_tiler[2], - ) - cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) - - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - - tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) - tTR_tAcc = tmem_thr_copy.partition_S(tAcc) - tTR_tS = tmem_thr_copy.partition_D(cS) - - tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) - - row_max_new = row_max - arch = BaseDSL._get_dsl().get_arch_enum() - if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): - cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) - for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - if is_last_tile: - tTR_rAcc[i] = ( - tTR_rAcc[i] - if cute.elem_less( - tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, - ) - else -self.acc_dtype.inf - ) - # reduction for row_max - row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) - elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): - tmem_load_red_atom = cute.make_copy_atom( - 
tcgen05.copy.LdRed32x32bOp( - tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX - ), - self.acc_dtype, - ) - tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) - tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) - tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) - tTR_tS_red = tmem_red_thr_copy.partition_D(cS) - tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) - tTR_rMax = cute.make_rmem_tensor( - cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), - self.acc_dtype, - ) - cute.copy( - tmem_red_tiled_copy, - tTR_tAcc_red, - (tTR_rAcc_red, tTR_rMax), - ) - tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) - if is_last_tile: - for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - tTR_rAcc[i] = ( - tTR_rAcc[i] - if cute.elem_less( - tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, - ) - else -self.acc_dtype.inf - ) - # reduction for row_max - row_max_new = tTR_rAcc.load().reduce( - cute.ReductionOp.MAX, row_max_new, 0 - ) - else: - row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) - - # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) - if cutlass.const_expr(self.warps_in_n == 2): - common_params.smem_exchange[tidx] = row_max_new - self.softmax_exchange_sync_bar.wait() - row_max_new = cute.arch.fmax( - row_max_new, - common_params.smem_exchange[ - (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) - ], - ) - - # find correction factor - correction_factor = cute.math.exp2( - (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True - ) - # split kv case - if cutlass.const_expr(not is_local_last_tile): - p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( - common_params, - softmax_params, - correction_factor, - row_sum, - row_max, - row_max_new, - tAcc, - tidx, - p_cor_producer_state, - ) - - # softmax - fma_b = softmax_params.softmax_scale_log2 - fma_c = (0.0 - row_max_new) * 
softmax_params.softmax_scale_log2 - - for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): - tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c - tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) - - tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) - - # quantize - tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) - - # create sP - sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] - sP_mk_view = cute.make_tensor( - sP.iterator, - cute.make_layout( - ( - (sP.shape[0][0], sP.shape[1]), - (sP.shape[0][1], sP.shape[2], sP.shape[3]), - ), - stride=( - (sP.stride[0][0], sP.stride[1]), - (sP.stride[0][1], sP.stride[2], sP.stride[3]), - ), - ), - ) - # {$nv-internal-release begin} - # TODO: figure out if we could use A tmem for pv. - # {$nv-internal-release end} - # change to PISL - sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) - swizzle_bits = ( - int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 - ) - swizzle_base = 3 if self.q_dtype.width == 16 else 4 - sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) - sP_mk_view = cute.make_tensor( - sP_wo_swizzle_iter, - cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), - ) - universal_copy_bits = 128 - smem_copy_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=universal_copy_bits, - ) - smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) - smem_thr_copy = smem_tiled_copy.get_slice(tidx) - rP_copy_view = smem_thr_copy.retile(tTR_rS) - sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) - cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) - - # fence between smem store and mma o - cute.arch.fence_view_async_shared() - softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) - p_mma_producer_state.advance() - - # row_sum, using `add_packed_f32x2` to reduce the number of instructions - row_sum = row_sum * correction_factor - 
row_sum_vec = (0.0, 0.0) - for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): - row_sum_vec = cute.arch.add_packed_f32x2( - row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) - ) - row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum - - # split kv case - if cutlass.const_expr(is_local_last_tile): - p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( - common_params, - softmax_params, - correction_factor, - row_sum, - row_max, - row_max_new, - tAcc, - tidx, - p_cor_producer_state, - ) - - # store correction factor/row_sum/row_max to tmem for correction warp - common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - - # fence between tmem load and mma s - cute.arch.fence_view_async_tmem_load() - - softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) - mma_s_consumer_state.advance() - - return ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max_new, - row_sum, - correction_factor, - ) - - @cute.jit - def _tmem_load_partition( - self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int - ) -> tuple[ - cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma - ]: - """Tensor memory load partition for rescale and epilogue. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param tiled_mma_pv: The tiled mma pv - :type tiled_mma_pv: cute.TiledMma - :param iter_n: The iteration number - :type iter_n: int - - :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv - :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] - """ - - tOtO_shape = tiled_mma_pv.partition_shape_C( - cute.select(self.mma_pv_tiler, mode=[0, 1]) - ) - tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) - tOtO_layout = cute.append( - tOtO.layout, - cute.make_layout( - common_params.L // self.mma_pv_tiler[1], - stride=self.mma_pv_tiler[1] // self.warps_in_n, - ), - ) - tOtO = cute.make_tensor( - common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout - ) - tOtO = tOtO[None, None, None, iter_n] - - tAcc = tOtO[(None, None), 0, 0] - - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) - # {$nv-internal-release begin} - # TODO: supports size() on tiled copy. 
- # {$nv-internal-release end} - tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( - common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - ) - - cta_pv_tiler = ( - self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], - self.mma_pv_tiler[1], - self.mma_pv_tiler[2], - ) - # Flatten divide and partition global tensors for O - cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) - - gO = None - if cutlass.const_expr(common_params.mAccO is not None): - gO = cute.local_tile( - common_params.mAccO[None, common_params.blk_coord[3], None, None, None], - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - cO = cute.local_tile( - cute.make_identity_tensor( - common_params.mAccO[ - None, common_params.blk_coord[3], None, None, None - ].shape - ), - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - else: - gO = cute.local_tile( - common_params.mO, - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - cO = cute.local_tile( - cute.make_identity_tensor(common_params.mO.shape), - cta_pv_tiler_mn, - ( - common_params.blk_coord[0], - iter_n, - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - ) - tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) - tTR_gO = tmem_load_thr_copy.partition_D(gO) - tTR_cO = tmem_load_thr_copy.partition_D(cO) - tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) - return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] - - def get_correction_factor( - self, - common_params: SimpleNamespace, - p_cor_consumer_state: pipeline.PipelineState, - ) -> tuple[ - pipeline.PipelineState, - cutlass.Float32, - cutlass.Float32, - cutlass.Float32, - cutlass.Int32, - ]: - """Get the correction factor from the P correction consumer state. 
- - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param p_cor_consumer_state: The P correction consumer state - :type p_cor_consumer_state: pipeline.PipelineState - - :return: The P correction consumer state, the row_sum, the row_max, and the correction factor - :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] - """ - common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - # load correction factor - _, tAcc, _, _, _, _ = self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, 0 - ) - corr_layout = cute.make_layout( - (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), - stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), - ) - tCor = cute.make_tensor( - common_params.tmem_ptr + self.correction_factor_offset, corr_layout - ) - cCor = cute.make_identity_tensor(tCor.shape) - corr_tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype - ) - corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) - corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) - tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) - cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) - rCor = cute.make_fragment_like( - cCor_for_copy[None, None, None, 0], self.acc_dtype - ) - rCor_int = cute.make_tensor( - cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout - ) - cute.copy( - corr_tmem_load_tiled_copy, - tCor_for_copy[None, None, None, p_cor_consumer_state.index], - rCor, - ) - row_sum = rCor[0] - row_max = rCor[1] - correction_factor = rCor[2] - no_correction = rCor_int[3] - - common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) - p_cor_consumer_state.advance() - return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction - - @cute.jit - def 
rescale( - self, - common_params: SimpleNamespace, - mma_o_consumer_state: pipeline.PipelineState, - correction_factor: cutlass.Float32, - no_correction: cutlass.Int32, - ) -> pipeline.PipelineState: - """Rescale for one k-tile. Updates the related pipeline state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param mma_o_consumer_state: The mma o consumer state - :type mma_o_consumer_state: pipeline.PipelineState - :param correction_factor: The correction factor - :type correction_factor: cutlass.Float32 - :param no_correction: Whether to apply correction factor - :type no_correction: cutlass.Int32 - - :return: The MMA o consumer state - :rtype: pipeline.PipelineState - """ - skip_correction = cute.arch.vote_all_sync(no_correction == 1) - for iter_n in cutlass.range_constexpr(self.iterations_pv_n): - common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) - if not skip_correction: - # tmem load tiled copy and partition results. - tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( - self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, iter_n - ) - ) - - # tmem store tiled copy - tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype - ) - tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) - - # load o - cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) - # rescale, using `mul_packed_f32x2` to reduce the number of instructions - for i in cutlass.range( - cute.size(tTR_rAcc), vectorize=True, unroll_full=True - ): - tTR_rAcc[i] = tTR_rAcc[i] * correction_factor - - # store o to tensor memory for next k tile - cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) - - cute.arch.fence_view_async_tmem_store() - common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) - mma_o_consumer_state.advance() - - return mma_o_consumer_state - - @cute.jit - def epilogue( - self, - common_params: SimpleNamespace, 
- epilogue_params: SimpleNamespace, - mma_o_consumer_state: pipeline.PipelineState, - row_sum: cutlass.Float32, - row_max: cutlass.Float32, - ) -> pipeline.PipelineState: - """Epilogue for one k-tile. Updates the related pipeline state. - - :param common_params: The common parameters - :type common_params: SimpleNamespace - :param epilogue_params: The epilogue parameters - :type epilogue_params: SimpleNamespace - :param mma_o_consumer_state: The mma o consumer state - :type mma_o_consumer_state: pipeline.PipelineState - :param row_sum: The row sum - :type row_sum: cutlass.Float32 - :param row_max: The row max - :type row_max: cutlass.Float32 - - :return: The MMA o consumer state - :rtype: pipeline.PipelineState - """ - - tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) - - # exchange row_sum between warps (0, 1) and (2, 3) - if cutlass.const_expr(self.warps_in_n == 2): - common_params.smem_exchange[tidx] = row_sum - self.epilogue_exchange_sync_bar.wait() - # (64, 2) - row_sum = ( - row_sum - + common_params.smem_exchange[ - (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) - ] - ) - # mma_o pipeline consumer wait - for iter_n in cutlass.range_constexpr(self.iterations_pv_n): - common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) - # tmem load tiled copy and partition results. 
- tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( - self._tmem_load_partition( - common_params, common_params.tiled_mma_pv, iter_n - ) - ) - - # load o - cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) - - # apply output scale and normalize by row_sum - for i in cutlass.range( - cute.size(tTR_rAcc), vectorize=True, unroll_full=True - ): - tTR_rAcc[i] = ( - tTR_rAcc[i] - * epilogue_params.output_scale - * cute.arch.rcp_approx(row_sum) - ) - - # store o to global memory - tR2G_rO_src = None - tR2G_rO_dst = tTR_gO - if cutlass.const_expr(common_params.mAccO is None): - tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) - # using final output dtype for o - tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) - else: - # using accumulate dtype for o - tR2G_rO_src = tTR_rAcc - - if cute.elem_less(tTR_cO[0][0], common_params.H): - cute.autovec_copy( - tR2G_rO_src, - tR2G_rO_dst, - l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, - ) - - # store the lse to global memory - cta_pv_tiler = ( - self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], - self.mma_pv_tiler[1], - self.mma_pv_tiler[2], - ) - gLSE = None - cLSE = None - if cutlass.const_expr(epilogue_params.mAccLSE is None): - gLSE = cute.local_tile( - epilogue_params.mLSE, - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - cLSE = cute.local_tile( - cute.make_identity_tensor(epilogue_params.mLSE.shape), - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - - else: - gLSE = cute.local_tile( - epilogue_params.mAccLSE[ - None, common_params.blk_coord[3], None, None - ], - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - cLSE = cute.local_tile( - cute.make_identity_tensor( - epilogue_params.mAccLSE[ - 
None, common_params.blk_coord[3], None, None - ].shape - ), - (cta_pv_tiler[0], 1, 1), - ( - common_params.blk_coord[0], - common_params.blk_coord[1], - common_params.blk_coord[2], - ), - (1, 1, 1), - ) - lse = ( - cute.math.log2(row_sum, fastmath=True) - + epilogue_params.softmax_scale_log2 * row_max - ) - if cutlass.const_expr(self.warps_in_n == 2): - if cute.elem_less(cLSE[tidx][0], common_params.H): - gLSE[tidx] = lse - - cute.arch.fence_view_async_tmem_load() - common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) - mma_o_consumer_state.advance() - - return mma_o_consumer_state - - def make_and_init_load_qkv_pipeline( - self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count - ) -> pipeline.PipelineTmaUmma: - """Create and initialize the tma load qkv pipeline. - - :param load_qkv_mbar_ptr: The load qkv mbar pointer - :type load_qkv_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - :param load_stages: The load stages - :type load_stages: list[int] - :param tx_count: The tx count - :type tx_count: int - - :return: The tma load qkv pipeline - :rtype: pipeline.PipelineTmaUmma - """ - load_qkv_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.load_tma_k_warp_id]) - ) - load_qkv_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineTmaUmma.create( - barrier_storage=load_qkv_mbar_ptr, - num_stages=load_stages, - producer_group=load_qkv_producer_group, - consumer_group=load_qkv_consumer_group, - tx_count=tx_count, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_mma_s_pipeline( - self, mma_s_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineUmmaAsync: - """Create and initialize the mma s pipeline. 
- - :param mma_s_mbar_ptr: The mma s mbar pointer - :type mma_s_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The mma s pipeline - :rtype: pipeline.PipelineUmmaAsync - """ - - mma_s_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - consumer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - mma_s_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - consumer_thread_size, - ) - return pipeline.PipelineUmmaAsync.create( - barrier_storage=mma_s_mbar_ptr, - num_stages=self.mma_s_stage, - producer_group=mma_s_producer_group, - consumer_group=mma_s_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_p_mma_pipeline( - self, p_mma_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineAsyncUmma: - """Create and initialize the p mma pipeline. - - :param p_mma_mbar_ptr: The p mma mbar pointer - :type p_mma_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The p mma pipeline - :rtype: pipeline.PipelineAsyncUmma - """ - - producer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - p_mma_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - p_mma_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - return pipeline.PipelineAsyncUmma.create( - barrier_storage=p_mma_mbar_ptr, - num_stages=self.p_mma_stage, - producer_group=p_mma_producer_group, - consumer_group=p_mma_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - def make_and_init_p_cor_pipeline( - self, p_cor_mbar_ptr - ) -> pipeline.PipelineAsyncUmma: - """Create and initialize the p correction pipeline. 
- - :param p_cor_mbar_ptr: The p correction mbar pointer - :type p_cor_mbar_ptr: cute.Tensor - - :return: The p correction pipeline - :rtype: pipeline.PipelineAsyncUmma - """ - - producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) - p_cor_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - p_cor_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - producer_thread_size, - ) - return pipeline.PipelineAsync.create( - barrier_storage=p_cor_mbar_ptr, - num_stages=self.p_cor_stage, - producer_group=p_cor_producer_group, - consumer_group=p_cor_consumer_group, - defer_sync=True, - ) - - def make_and_init_mma_o_pipeline( - self, mma_o_mbar_ptr, cta_layout_vmnk - ) -> pipeline.PipelineUmmaAsync: - """Create and initialize the mma o pipeline. - - :param mma_o_mbar_ptr: The mma o mbar pointer - :type mma_o_mbar_ptr: cute.Tensor - :param cta_layout_vmnk: The cta layout vmnk - :type cta_layout_vmnk: tuple[int, int, int] - - :return: The mma o pipeline - :rtype: pipeline.PipelineUmmaAsync - """ - - mma_o_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) - ) - consumer_thread_size = ( - self.threads_per_warp - * len(self.compute_warp_ids) - * self.cluster_shape_mnk[0] - ) - mma_o_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - consumer_thread_size, - ) - return pipeline.PipelineUmmaAsync.create( - barrier_storage=mma_o_mbar_ptr, - num_stages=self.mma_o_stage, - producer_group=mma_o_producer_group, - consumer_group=mma_o_consumer_group, - cta_layout_vmnk=cta_layout_vmnk, - defer_sync=True, - ) - - @staticmethod - def _compute_grid( - o: cute.Tensor, - split_kv: cutlass.Int32, - cluster_shape_mnk: Tuple[int, int, int], - max_active_clusters: int, - is_persistent: bool, - ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: - """Compute grid shape for the output tensor C. 
- - :param c: The output tensor C - :type c: cute.Tensor - :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. - :type cta_tile_shape_mnk: tuple[int, int, int] - :param cluster_shape_mn: Shape of each cluster in M, N dimensions. - :type cluster_shape_mn: tuple[int, int] - - :return: Tile scheduler parameters and grid shape. - :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] - """ - o_shape = o.shape - tile_sched_params = create_mla_static_tile_scheduler_params( - is_persistent, - cute.size(o_shape[3]), - cute.size(o_shape[2]), - cluster_shape_mnk, - split_kv, - ) - grid = MLAStaticTileScheduler.get_grid_shape( - tile_sched_params, max_active_clusters - ) - - return tile_sched_params, grid - - @staticmethod - def get_workspace_size( - H: int, - S: int, - D: int, - B: int, - split_kv: int, - acc_dtype: Type[cutlass.Numeric], - ) -> int: - """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. - - :param H: The height of the output tensor C - :type H: int - :param S: The sequence length of the output tensor C - :type S: int - :param D: The depth of the output tensor C - :type D: int - :param B: The batch size of the output tensor C - :type B: int - :param split_kv: The split key-value of the output tensor C - :type split_kv: int - :param acc_dtype: The data type of the output tensor C - :type acc_dtype: Type[cutlass.Numeric] - - :return: The workspace size for the MLA kernel - :rtype: int - """ - if split_kv == 1: - return 0 - return B * H * S * split_kv * (D + 1) * acc_dtype.width // 8 - - @cute.jit - def initialize_workspace( - self, - H: cutlass.Int32, - D: cutlass.Int32, - S: cutlass.Int32, - B: cutlass.Int32, - split_kv: cutlass.Int32, - acc_dtype: Type[cutlass.Numeric], - workspace: cute.Tensor, - ) -> tuple[cute.Tensor, cute.Tensor]: - """Initialize the workspace for the MLA kernel. Construct the intermediate tensors - acc_o and acc_lse. 
- - :param H: The height of the output tensor C - :type H: cutlass.Int32 - :param D: The depth of the output tensor C - :type D: cutlass.Int32 - :param S: The sequence length of the output tensor C - :type S: cutlass.Int32 - :param B: The batch size of the output tensor C - :type B: cutlass.Int32 - :param split_kv: The split key-value of the output tensor C - :type split_kv: cutlass.Int32 - :param acc_dtype: The data type of the output tensor C - :type acc_dtype: Type[cutlass.Numeric] - :param workspace: The workspace tensor - :type workspace: cute.Tensor - - :return: The output tensor C and the workspace tensor - :rtype: tuple[cute.Tensor, cute.Tensor] - """ - acc_o, acc_lse = None, None - if cutlass.const_expr(workspace is not None): - align = 256 // self.q_dtype.width - acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), - stride=( - cute.assume(split_kv * D, align), - cute.assume(D, align), - 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), - ), - ) - acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) - acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) - acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), - ) - acc_lse_iter = cute.recast_ptr( - workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, - dtype=acc_dtype, - ) - acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) - return acc_o, acc_lse - - @staticmethod - def can_implement( - B: int, - S: int, - K: int, - H: int, - L: int, - R: int, - in_dtype: Type[cutlass.Numeric], - out_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - page_size: int, - ) -> bool: - """Check if the MLA kernel can be implemented. 
- - :param B: The batch size of the output tensor C - :type B: int - :param S: The sequence length of the output tensor C - :type S: int - :param K: The width of the output tensor KV - :type K: int - :param H: The number of heads of the output tensor C - :type H: int - :param L: The number of latent dimensions of the tensor KV - :type L: int - :param R: The number of rope dimensions of the tensor C_rope - :type R: int - :param in_dtype: The data type of the input tensor - :type in_dtype: Type[cutlass.Numeric] - :param out_dtype: The data type of the output tensor - :type out_dtype: Type[cutlass.Numeric] - :param acc_dtype: The data type of the accumulator - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: The data type of the log-sum-exp - :type lse_dtype: Type[cutlass.Numeric] - :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication - :type mma_qk_tiler_mn: Tuple[int, int] - :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication - :type mma_pv_tiler_mn: Tuple[int, int] - :param split_kv: The split key-value of the output tensor C - :type split_kv: int - :param is_persistent: Whether to use persistent kernel optimization - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split_kv - :type is_var_split_kv: bool - :param page_size: The page size of the page table - :type page_size: int - - :return: Whether the MLA kernel can be implemented - :rtype: bool - """ - if L != 512 or R != 64: - return False - if in_dtype not in [cutlass.Float8E4M3FN]: - return False - if out_dtype not in [cutlass.Float8E4M3FN, cutlass.BFloat16]: - return False - if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: - return False - # page size equals 1 is prohibited by tma specification, not 128B aligned. 
- if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: - return False - if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: - return False - if is_var_split_kv and not is_var_seq: - return False - if H > 128 or (H < 128 and split_kv != 1): - return False - if S <= 0 or S > 4: - return False - if K <= 0: - return False - return True - - -def run( - batch_size: int, - seq_len_q: int, - seq_len_k: int, - num_heads: int, - latent_dim: int, - rope_dim: int, - in_dtype: Type[cutlass.Numeric], - out_dtype: Type[cutlass.Numeric], - acc_dtype: Type[cutlass.Numeric], - lse_dtype: Type[cutlass.Numeric], - mma_qk_tiler_mn: Tuple[int, int], - mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, - is_persistent: bool, - is_var_seq: bool, - is_var_split_kv: bool, - page_size: int, - softmax_scale: float, - output_scale: float, - skip_correction_threshold: float, - tolerance: float, - warmup_iterations: int, - iterations: int, - skip_ref_check: bool, - use_cold_l2: bool, - enable_pdl: bool = False, - **kwargs, -): - """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. - - This function creates random input tensors for query latent/rope, compressed latent/rope, and value, - then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, - page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference - implementation or run multiple times for performance measurement. 
- - :param batch_size: Batch size - :type batch_size: int - :param seq_len_q: Sequence length of Q - :type seq_len_q: int - :param seq_len_k: Sequence length of K - :type seq_len_k: int - :param num_heads: Number of heads - :type num_heads: int - :param latent_dim: dimension of query/compressed latent - :type latent_dim: int - :param rope_dim: dimension of query/compressed rope - :type rope_dim: int - :param in_dtype: Input data type for query/compressed latent/rope tensors - :type in_dtype: Type[cutlass.Numeric] - :param out_dtype: Output data type for attention output - :type out_dtype: Type[cutlass.Numeric] - :param acc_dtype: Accumulator data type for query-key matrix multiplication - :type acc_dtype: Type[cutlass.Numeric] - :param lse_dtype: Accumulator data type for log-sum-exp - :type lse_dtype: Type[cutlass.Numeric] - :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication - :type mma_qk_tiler_mn: Tuple[int, int] - :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication - :type mma_pv_tiler_mn: Tuple[int, int] - :param split_kv: Split key-value - :type split_kv: int - :param is_persistent: Whether to use persistent kernel optimization - :type is_persistent: bool - :param is_var_seq: Whether to use variable sequence length - :type is_var_seq: bool - :param is_var_split_kv: Whether to use variable split_kv - :type is_var_split_kv: bool - :param page_size: Page size of the page table - :type page_size: int - :param softmax_scale: Attention score scaling factor - :type softmax_scale: float - :param output_scale: Output scaling factor - :type output_scale: float - :param skip_correction_threshold: Threshold to skip correction - :type skip_correction_threshold: float - :param tolerance: Maximum acceptable error for validation - :type tolerance: float - :param warmup_iterations: Number of warmup iterations - :type warmup_iterations: int - :param iterations: 
Number of iterations to run for performance testing - :type iterations: int - :param skip_ref_check: Skip validation against reference implementation - :type skip_ref_check: bool - :param use_cold_l2: Whether to use cold L2 cache - :type use_cold_l2: bool - - :raises ValueError: If input shapes are incompatible or head dimension is unsupported - :raises RuntimeError: If GPU is unavailable for computation - """ - - print("Running Blackwell MLA test with:") - print(f" batch_size: {batch_size}") - print(f" seq_len_q: {seq_len_q}") - print(f" seq_len_k: {seq_len_k}") - print(f" num_heads: {num_heads}") - print(f" latent_dim: {latent_dim}") - print(f" rope_dim: {rope_dim}") - print(f" in_dtype: {in_dtype}") - print(f" out_dtype: {out_dtype}") - print(f" acc_dtype: {acc_dtype}") - print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") - print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") - print(f" split_kv: {split_kv}") - print(f" is_persistent: {is_persistent}") - print(f" is_var_seq: {is_var_seq}") - print(f" is_var_split_kv: {is_var_split_kv}") - print(f" page_size: {page_size}") - print(f" softmax_scale: {softmax_scale}") - print(f" output_scale: {output_scale}") - print(f" skip_correction_threshold: {skip_correction_threshold}") - print(f" tolerance: {tolerance}") - print(f" warmup_iterations: {warmup_iterations}") - print(f" iterations: {iterations}") - print(f" skip_ref_check: {skip_ref_check}") - print(f" use_cold_l2: {use_cold_l2}") - - import torch - import cutlass.torch as cutlass_torch - - # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) - if not torch.cuda.is_available(): - raise RuntimeError("GPU is required to run this example!") - - if not BlackwellMultiHeadLatentAttentionForwardFP8.can_implement( - batch_size, - seq_len_q, - seq_len_k, - num_heads, - latent_dim, - rope_dim, - in_dtype, - out_dtype, - acc_dtype, - lse_dtype, - mma_qk_tiler_mn, - mma_pv_tiler_mn, - split_kv, - is_persistent, - is_var_seq, - is_var_split_kv, - page_size, - ): - 
raise TypeError( - f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" - ) - - torch.manual_seed(1111) - - def create_data_tensor( - B, - HK, - D, - dtype, - is_dynamic_layout=True, - page_table=None, - cache_seqs=None, - is_lse=False, - seq_len_q=None, - ): - shape = (B, HK, D) - if page_table is not None: - if cache_seqs is not None: - max_seq_len = torch.max(cache_seqs) - shape = (B * ceil_div(max_seq_len, page_size), page_size, D) - else: - shape = (B * ceil_div(HK, page_size), page_size, D) - - if seq_len_q is not None: - shape = (B, seq_len_q, HK, D) - - # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) - if is_lse: - shape = (B, seq_len_q, HK) - leading_dim = 2 - stride_order = (0, 1, 2) - elif seq_len_q is not None: - leading_dim = 3 - stride_order = (0, 1, 2, 3) - else: - leading_dim = 2 - stride_order = (0, 1, 2) - - init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) - - torch_dtype = ( - cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 - ) - - # Create contiguous dtype torch tensor (cpu) — no permute - torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( - shape, - torch_dtype, - init_type=cutlass.torch.TensorInitType.RANDOM, - init_config=init_config, - ) - - # Create dtype torch tensor (gpu) - torch_tensor_gpu = torch_tensor_cpu.cuda() - - # Create f32 torch tensor (cpu) - f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) - - # Create dtype cute tensor (gpu) - cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) - cute_tensor.element_type = dtype - if is_dynamic_layout: - cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) - if not is_lse: - cute_tensor = cute_tensor.mark_compact_shape_dynamic( - mode=leading_dim, 
- stride_order=stride_order, - divisibility=(128 // dtype.width), - ) - - cute_tensor = cutlass_torch.convert_cute_tensor( - f32_torch_tensor, - cute_tensor, - dtype, - is_dynamic_layout=is_dynamic_layout, - ) - - return f32_torch_tensor, cute_tensor, torch_tensor_gpu - - def create_cache_seqs(batch_size, seq_len_k, is_var_seq): - cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k - cache_seqs_gpu = cache_seqs_ref.cuda() - cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() - if is_var_seq: - max_seq_len = seq_len_k - min_seq_len = int(seq_len_k * 0.8) - cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( - (batch_size,), - torch.int32, - init_type=cutlass.torch.TensorInitType.RANDOM, - init_config=cutlass.torch.RandomInitConfig( - min_val=min_seq_len, max_val=max_seq_len + 1 - ), - ) - cache_seqs_gpu = cache_seqs_ref.cuda() - cache_seqs = from_dlpack( - cache_seqs_gpu, - assumed_align=16, - ).mark_layout_dynamic() - return cache_seqs_ref, cache_seqs, cache_seqs_gpu - - def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): - max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) - page_count = ceil_div(max_seq_len, page_size) - page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) - # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. 
- for b in range(batch_size): - for j in range(page_count): - page_table_ref[b, j] = b + j * batch_size - page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] - page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( - leading_dim=1 - ) - return page_table_ref, page_table, page_table_gpu - - def create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ): - block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None - # check if split_kv is valid otherwise do auto setting of split_kv - if is_var_split_kv: - block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) - for b in range(batch_size): - block_split_kvs_ref[b] = ( - BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( - batch_size, - seq_len_q, - cache_seqs_ref[b].item(), - mma_qk_tiler_mn, - max_active_clusters * cluster_shape_mnk[0], - ) - ) - split_kv = torch.max(block_split_kvs_ref).item() - block_split_kvs_gpu = block_split_kvs_ref.cuda() - block_split_kvs = from_dlpack( - block_split_kvs_gpu, assumed_align=16 - ).mark_layout_dynamic() - elif split_kv <= 0: - split_kv = BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( - batch_size, - seq_len_q, - cache_seqs_ref[0].item(), - mma_qk_tiler_mn, - max_active_clusters * cluster_shape_mnk[0], - ) - return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu - - def create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype - ): - workspace_size = BlackwellMultiHeadLatentAttentionForwardFP8.get_workspace_size( - num_heads, - seq_len_q, - latent_dim, - batch_size, - split_kv, - acc_dtype, - ) - - workspace, workspace_torch = None, None - if workspace_size > 0: - workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() - workspace = from_dlpack(workspace_torch, assumed_align=32) - return workspace, workspace_torch - - cache_seqs_ref, 
cache_seqs, cache_seqs_torch = create_cache_seqs( - batch_size, seq_len_k, is_var_seq - ) - page_table_ref, page_table, page_table_torch = create_page_table( - batch_size, seq_len_k, is_var_seq, page_size - ) - cluster_shape_mnk = (2, 1, 1) - hardware_info = utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mnk[0] * cluster_shape_mnk[1] - ) - split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( - create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ) - ) - - q_latent_ref, q_latent, q_latent_torch = create_data_tensor( - batch_size, - num_heads, - latent_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - q_rope_ref, q_rope, q_rope_torch = create_data_tensor( - batch_size, - num_heads, - rope_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - - c_latent_ref, c_latent, c_latent_torch = create_data_tensor( - batch_size, - seq_len_k, - latent_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - c_rope_ref, c_rope, c_rope_torch = create_data_tensor( - batch_size, - seq_len_k, - rope_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - o_ref, o, o_torch = create_data_tensor( - batch_size, - num_heads, - latent_dim, - out_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - lse_ref, lse, lse_torch = create_data_tensor( - batch_size, - num_heads, - 1, - lse_dtype, - is_dynamic_layout=True, - is_lse=True, - seq_len_q=seq_len_q, - ) - workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype - ) - - mla = BlackwellMultiHeadLatentAttentionForwardFP8( - acc_dtype, - lse_dtype, - mma_qk_tiler_mn, - mma_pv_tiler_mn, - max_active_clusters, - page_size, - skip_correction_threshold, - is_persistent, - 
is_var_seq, - is_var_split_kv, - enable_pdl, - ) - - # Get current CUDA stream from PyTorch - torch_stream = torch.cuda.current_stream() - # Get the raw stream pointer as a CUstream - stream = cuda.CUstream(torch_stream.cuda_stream) - - # compile mla kernel - compiled_mla = cute.compile( - mla, - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - options="--opt-level 2", - ) - - def torch_reference_mla( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - cache_seqs, - softmax_scale=1.0, - output_scale=1.0, - ): - # Ref tensors are now contiguous: - # q_latent/q_rope: [B, S_q, H, D] - # c_latent/c_rope: [num_pages, page_size, D] - # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] - q_ref = torch.cat([q_latent, q_rope], dim=3) - # KV cache: concat along last dim, already [num_pages, page_size, D_total] - page_count = page_table_ref.shape[1] - k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( - batch_size * page_count, page_size, latent_dim + rope_dim - ) - v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) - - if is_var_seq: - max_seq_len = torch.max(cache_seqs_ref) - else: - max_seq_len = seq_len_k - - k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) - v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) - k_ref = torch.index_select( - k_ref_paged, 0, torch.flatten(page_table_ref) - ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] - v_ref = torch.index_select( - v_ref_paged, 0, torch.flatten(page_table_ref) - ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] - for b in range(batch_size): - k_ref[b, :, cache_seqs_ref[b] :, :] = 0 - v_ref[b, :, cache_seqs_ref[b] :, :] = 0 - import torch.nn.functional as F - - o_ref = F.scaled_dot_product_attention( - q_ref, - k_ref, - v_ref, - attn_mask=None, - dropout_p=0.0, - 
scale=softmax_scale, - is_causal=False, - ) - s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) - s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) - softmax_scale_log2 = LOG2_E * softmax_scale - s_ref_sum = torch.sum( - torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True - ) - - lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) - lse_ref = lse_ref.squeeze(3) # [B, S_q, H] - o_ref = o_ref * output_scale - # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout - - return o_ref, lse_ref - - if skip_correction_threshold > 0.0: - print( - "Skipping correction verification since skip_correction_threshold is greater than 0.0..." - ) - skip_ref_check = True - if not skip_ref_check: - # Execute kernel once for reference checking - compiled_mla( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - ) - torch.cuda.synchronize() - - print("Verifying results...") - if in_dtype == cutlass.Float8E4M3FN: - tolerance = 0.13 - o_ref, lse_ref = torch_reference_mla( - q_latent_ref, - q_rope_ref, - c_latent_ref, - c_rope_ref, - page_table, - cache_seqs, - softmax_scale, - output_scale, - ) - - if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: - # {$nv-internal-release begin} - # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. 
- # {$nv-internal-release end} - # convert o back to f32 for comparison - o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( - torch.empty(*o_torch.shape, dtype=torch.float32), - cutlass.Float32, - is_dynamic_layout=True, - assumed_align=16, - ) - cute.testing.convert(o, o_fp32) - o = o_fp32_torch.cpu() - ref_fp8, _ = cutlass_torch.cute_tensor_like( - torch.empty(*o_ref.shape, dtype=torch.uint8), - out_dtype, - is_dynamic_layout=True, - assumed_align=16, - ) - o_ref_gpu = o_ref.cuda() - o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) - - # convert ref : f32 -> fp8 -> f32 - cute.testing.convert(o_ref_f32, ref_fp8) - cute.testing.convert(ref_fp8, o_ref_f32) - - o_ref = o_ref_gpu.cpu() - else: - o = o_torch.cpu().to(torch.float32) - lse = lse_torch.cpu() - lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) - # Assert close results - torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) - torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) - print("Results verified successfully!") - - def generate_tensors(): - _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) - _, page_table, _ = create_page_table( - batch_size, seq_len_k, is_var_seq, page_size - ) - _split_kv, _, block_split_kvs, _ = create_block_split_kvs( - batch_size, - split_kv, - cache_seqs_ref, - is_var_split_kv, - mma_qk_tiler_mn, - cluster_shape_mnk, - max_active_clusters, - ) - - _, q_latent, _ = create_data_tensor( - batch_size, - num_heads, - latent_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - _, q_rope, _ = create_data_tensor( - batch_size, - num_heads, - rope_dim, - in_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - - _, c_latent, _ = create_data_tensor( - batch_size, - seq_len_k, - latent_dim, - in_dtype, - is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - _, c_rope, _ = create_data_tensor( - batch_size, - seq_len_k, - rope_dim, - in_dtype, - 
is_dynamic_layout=True, - page_table=page_table, - cache_seqs=cache_seqs_ref, - ) - _, o, _ = create_data_tensor( - batch_size, - num_heads, - latent_dim, - out_dtype, - is_dynamic_layout=True, - seq_len_q=seq_len_q, - ) - _, lse, _ = create_data_tensor( - batch_size, - num_heads, - 1, - lse_dtype, - is_dynamic_layout=True, - is_lse=True, - seq_len_q=seq_len_q, - ) - workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype - ) - return testing.JitArguments( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - _split_kv, - cache_seqs, - block_split_kvs, - softmax_scale, - output_scale, - stream, - ) - - workspace_count = 1 - if use_cold_l2: - one_workspace_bytes = ( - q_latent_torch.numel() * q_latent_torch.element_size() - + q_rope_torch.numel() * q_rope_torch.element_size() - + c_latent_torch.numel() * c_latent_torch.element_size() - + c_rope_torch.numel() * c_rope_torch.element_size() - + o_torch.numel() * o_torch.element_size() - + lse_torch.numel() * lse_torch.element_size() - + cache_seqs_torch.numel() * cache_seqs_torch.element_size() - ) - one_workspace_bytes += ( - page_table_torch.numel() * page_table_torch.element_size() - ) - if is_var_split_kv: - one_workspace_bytes += ( - block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() - ) - if workspace_torch is not None: - one_workspace_bytes += ( - workspace_torch.numel() * workspace_torch.element_size() - ) - workspace_count = testing.get_workspace_count( - one_workspace_bytes, warmup_iterations, iterations - ) - - avg_time_us = testing.benchmark( - compiled_mla, - workspace_generator=generate_tensors, - workspace_count=workspace_count, - stream=stream, - warmup_iterations=warmup_iterations, - iterations=iterations, - ) - - return avg_time_us # Return execution time in microseconds diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 52d6a43e09..08456c804b 100755 --- a/flashinfer/prefill.py 
+++ b/flashinfer/prefill.py @@ -1548,7 +1548,13 @@ def __init__( """ _check_kv_layout(kv_layout) - if jit_args is not None: + if backend == "cute-dsl": + raise NotImplementedError( + "cute-dsl backend is not yet supported for paged KV cache. " + "Use BatchPrefillWithRaggedKVCacheWrapper instead." + ) + + if jit_args is not None and backend != "cute-dsl": if jit_kwargs is None: jit_kwargs = {} self._jit_module = get_batch_prefill_jit_module( @@ -2619,10 +2625,12 @@ def __init__( will be used in attention computation. backend : str - The implementation backend, could be ``auto``/``fa2``/``fa3``/``cudnn`` or ``cutlass``. + The implementation backend, could be ``auto``/``fa2``/``fa3``/``cudnn``/``cutlass`` + or ``cute-dsl``. Defaults to ``auto``. If set to ``auto``, the wrapper will automatically choose the backend based on the device architecture and kernel availability. + The ``cute-dsl`` backend uses the CuTe DSL attention kernel for Blackwell (SM100+). jit_args : Optional[List[Any]] If provided, the wrapper will use the provided arguments to create the JIT module, @@ -2632,7 +2640,7 @@ def __init__( The keyword arguments to create the JIT module, defaults to None. 
""" _check_kv_layout(kv_layout) - if jit_args is not None: + if jit_args is not None and backend != "cute-dsl": if jit_kwargs is None: jit_kwargs = {} self._jit_module = get_batch_prefill_jit_module( @@ -2642,6 +2650,14 @@ def __init__( else: self._jit_module = None + self._cute_dsl_wrapper = None + if backend == "cute-dsl": + from .cute_dsl.attention import BatchPrefillCuteDSLWrapper + + self._cute_dsl_wrapper = BatchPrefillCuteDSLWrapper( + float_workspace_buffer, use_cuda_graph=use_cuda_graph + ) + self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device @@ -2968,7 +2984,57 @@ def plan( self._max_token_per_sequence = max_token_per_sequence self._max_sequence_kv = max_sequence_kv - if self._jit_module is not None: + if self._backend == "cute-dsl": + if custom_mask is not None or packed_custom_mask is not None: + raise NotImplementedError( + "cute-dsl backend does not support custom_mask" + ) + if head_dim_vo is not None and head_dim_vo != head_dim_qk: + raise NotImplementedError( + "cute-dsl backend requires head_dim_vo == head_dim_qk" + ) + if self._kv_layout == "HND": + raise NotImplementedError("cute-dsl backend only supports NHD layout") + if pos_encoding_mode not in ("NONE", "ALIBI"): + raise NotImplementedError( + f"cute-dsl backend does not support pos_encoding_mode={pos_encoding_mode!r}. " + "For RoPE, apply rotary embeddings to Q/K before calling the kernel." 
+ ) + + variant: Optional[Any] = None + if pos_encoding_mode == "ALIBI": + from .cute_dsl.attention import ALiBiAttention + + slopes = ALiBiAttention.get_slopes(num_qo_heads) + variant = ALiBiAttention( + torch.tensor(slopes, dtype=torch.float32, device=qo_indptr.device) + ) + if logits_soft_cap is not None and logits_soft_cap > 0: + from .cute_dsl.attention import SoftCappingAttention + + if variant is not None: + raise NotImplementedError( + "cute-dsl backend does not support combining ALiBi with logits_soft_cap" + ) + variant = SoftCappingAttention(cap=logits_soft_cap) + + self._cute_dsl_wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_qk, + causal=causal, + sm_scale=sm_scale + if sm_scale is not None + else 1.0 / math.sqrt(head_dim_qk), + q_data_type=q_data_type, + kv_data_type=kv_data_type if kv_data_type is not None else q_data_type, + window_left=window_left, + variant=variant, + ) + elif self._jit_module is not None: self._cached_module = self._jit_module else: if self._backend == "auto": @@ -3009,7 +3075,7 @@ def plan( self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal ) self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item() - elif self._backend != "cudnn": + elif self._backend not in ("cudnn", "cute-dsl"): assert self._cached_module is not None, "cached module is not initialized" args = [ self._float_workspace_buffer, @@ -3213,7 +3279,21 @@ def run( q.device, "out", ) - if self._backend == "cutlass": + if self._backend == "cute-dsl": + # These checks live here (not in plan()) because return_lse and + # scale parameters are run()-time arguments that can vary between + # calls on the same planned wrapper. 
+ if return_lse: + raise NotImplementedError( + "cute-dsl backend does not support return_lse" + ) + if any(s is not None for s in (q_scale, k_scale, v_scale, o_scale)): + raise NotImplementedError( + "cute-dsl backend does not support FP8 scale parameters" + ) + out = self._cute_dsl_wrapper.run(q, k, v, out=out) + return out + elif self._backend == "cutlass": out, lse = fmha_varlen( q, k, diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 8a5e572a96..3abe867595 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -96,9 +96,9 @@ def torch_reference_mla( return torch.stack(outputs, dim=0) # [B, q_len, H, latent_dim] -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) -@pytest.mark.parametrize("page_size", [128]) +@pytest.mark.parametrize("batch_size", [1, 4, 32]) +@pytest.mark.parametrize("seq_len_k", [128, 512, 2048, 8192]) +@pytest.mark.parametrize("page_size", [32, 128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("q_len", [1, 2]) @pytest.mark.parametrize("enable_pdl", [True, False]) @@ -108,7 +108,7 @@ def test_cute_dsl_mla_decode_fp16( """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() - from flashinfer.mla.cute_dsl import cute_dsl_mla_decode + from flashinfer.cute_dsl.attention import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") @@ -189,15 +189,15 @@ def test_cute_dsl_mla_decode_fp16( torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) -@pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_mla_decode_variable_seq_len( - batch_size, seq_len_k, page_size=128, enable_pdl=False -): +@pytest.mark.parametrize("batch_size", [1, 4, 16]) +@pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) +@pytest.mark.parametrize("page_size", [32, 128]) 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, dtype): """Test MLA decode with variable sequence lengths across the batch.""" skip_if_unsupported() - from flashinfer.mla.cute_dsl import cute_dsl_mla_decode + from flashinfer.cute_dsl.attention import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") @@ -210,11 +210,8 @@ def test_cute_dsl_mla_decode_variable_seq_len( output_scale = 1.0 D_qk = latent_dim + rope_dim - query = torch.randn( - batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device - ) + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device) - # Variable sequence lengths max_seq_len = seq_len_k seq_lens = torch.randint( page_size, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device @@ -222,9 +219,7 @@ def test_cute_dsl_mla_decode_variable_seq_len( max_pages_per_batch = (max_seq_len + page_size - 1) // page_size total_pages = max_pages_per_batch * batch_size + 10 - kv_cache = torch.randn( - total_pages, page_size, D_qk, dtype=torch.float16, device=device - ) + kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=dtype, device=device) block_tables = torch.zeros( batch_size, max_pages_per_batch, dtype=torch.int32, device=device @@ -247,10 +242,8 @@ def test_cute_dsl_mla_decode_variable_seq_len( softmax_scale=softmax_scale, output_scale=output_scale, is_var_seq=True, - enable_pdl=enable_pdl, ) - # Reference kv_flat = kv_cache.reshape(-1, D_qk) c_latent_ref = kv_flat[:, :latent_dim] c_rope_ref = kv_flat[:, latent_dim:] @@ -268,9 +261,9 @@ def test_cute_dsl_mla_decode_variable_seq_len( output_scale, page_size, ) - ref_out_fp16 = ref_out.to(torch.float16) + ref_out_cast = ref_out.to(dtype) - torch.testing.assert_close(out, ref_out_fp16, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) @pytest.mark.parametrize("batch_size", [1, 4]) @@ 
-403,14 +396,14 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64) @pytest.mark.parametrize("batch_size", [1, 4]) -@pytest.mark.parametrize("seq_len_k", [128, 512]) -@pytest.mark.parametrize("page_size", [128]) -@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) +@pytest.mark.parametrize("page_size", [64, 128]) +@pytest.mark.parametrize("enable_pdl", [False]) def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() - from flashinfer.mla.cute_dsl import cute_dsl_mla_decode + from flashinfer.cute_dsl.attention import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") @@ -487,3 +480,799 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): torch.testing.assert_close( out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 ) + + +# --------------------------------------------------------------------------- +# Variant tests: score_mod, update_statistics, transform_output +# --------------------------------------------------------------------------- + + +def torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale, + output_scale, + page_size, + score_mod_fn=None, + sink=None, +): + """PyTorch reference for MLA decode with variant hooks. 
+ + Args: + score_mod_fn: callable(score, batch_idx, qo_idx, kv_idx, head_idx) -> score + sink: (num_heads,) tensor for attention sink + """ + B, q_len, H, latent_dim = q_nope.shape + + outputs = [] + for b in range(B): + seq_len = cache_seqs[b].item() + num_pages_needed = (seq_len + page_size - 1) // page_size + + page_indices = page_table[b, :num_pages_needed] + kv_indices = [] + for p in page_indices: + start = p.item() * page_size + kv_indices.extend(range(start, start + page_size)) + kv_indices = kv_indices[:seq_len] + kv_indices_t = torch.tensor(kv_indices, device=q_nope.device) + + k_latent = c_latent[kv_indices_t] + k_rope = c_rope[kv_indices_t] + + q_lat_b = q_nope[b] + q_rope_b = q_rope[b] + + attn_latent = torch.einsum("qhd,kd->qhk", q_lat_b.float(), k_latent.float()) + attn_rope = torch.einsum("qhd,kd->qhk", q_rope_b.float(), k_rope.float()) + attn = attn_latent + attn_rope + + if score_mod_fn is not None: + for qi in range(q_len): + for hi in range(H): + for ki in range(seq_len): + attn[qi, hi, ki] = score_mod_fn(attn[qi, hi, ki], b, qi, ki, hi) + + attn = attn * softmax_scale + + if sink is not None: + sink_dev = sink.to(q_nope.device).float() + for qi in range(q_len): + for hi in range(H): + scores = attn[qi, hi, :] + # sink[hi] is in natural-log domain: effective weight = exp(sink[hi]). + # scores are already multiplied by softmax_scale, so place + # sink[hi] directly as the virtual score (torch.softmax + # computes exp(x_i) / sum(exp(x_j))). 
+ virtual_scores = torch.cat([sink_dev[hi].unsqueeze(0), scores]) + weights = torch.softmax(virtual_scores, dim=-1) + real_weights = weights[1:] + out_qh = torch.einsum("k,kd->d", real_weights, k_latent.float()) + out_qh = out_qh * output_scale + if qi == 0 and hi == 0: + out_b = torch.zeros(q_len, H, latent_dim, device=q_nope.device) + out_b[qi, hi] = out_qh + outputs.append(out_b) + continue + + attn = F.softmax(attn, dim=-1) + out_b = torch.einsum("qhk,kd->qhd", attn, k_latent.float()) + out_b = out_b * output_scale + outputs.append(out_b) + + return torch.stack(outputs, dim=0) + + +def _make_mla_test_data(batch_size, seq_len_k, page_size, dtype, q_len=1): + """Create standard MLA test data (query, kv_cache, block_tables, seq_lens).""" + device = torch.device("cuda") + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + D_qk = latent_dim + rope_dim + + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device) + + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + kv_cache = torch.randn( + total_pages, + page_size, + D_qk, + dtype=dtype, + device=device, + ) + + block_tables = torch.zeros( + batch_size, + num_pages_per_batch, + dtype=torch.int32, + device=device, + ) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) + + return ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [256, 2048]) +@pytest.mark.parametrize("page_size", [64, 128]) +def test_cute_dsl_mla_decode_alibi(batch_size, seq_len_k, page_size): + """Test MLA decode with ALiBi variant (score_mod with per-head 
slopes).""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import ALiBiAttention + + torch.manual_seed(42) + dtype = torch.bfloat16 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + alibi_slopes = ALiBiAttention.get_slopes(num_heads).cuda() + variant = ALiBiAttention(alibi_slopes) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + slopes_cpu = alibi_slopes.float() + + def alibi_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + return score + slopes_cpu[head_idx].item() * (kv_idx - qo_idx) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=alibi_score_mod, + ) + ref_out_cast = ref_out.to(dtype) + + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [256, 2048]) +@pytest.mark.parametrize("page_size", [64, 128]) +def test_cute_dsl_mla_decode_soft_capping(batch_size, seq_len_k, page_size): + """Test MLA decode with 
SoftCapping variant (score_mod, no extra_params).""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import SoftCappingAttention + + torch.manual_seed(42) + dtype = torch.bfloat16 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + cap = 50.0 + + variant = SoftCappingAttention(cap=cap) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + import math + + def soft_capping_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + return cap * math.tanh(score.item() / cap) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=soft_capping_score_mod, + ) + ref_out_cast = ref_out.to(dtype) + + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [256, 2048]) +@pytest.mark.parametrize("page_size", [64, 128]) +def test_cute_dsl_mla_decode_attention_sink(batch_size, seq_len_k, page_size): + """Test MLA decode with AttentionWithSink 
(update_statistics + transform_output).""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import AttentionWithSink + + torch.manual_seed(42) + dtype = torch.bfloat16 + num_heads = 128 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + sink = torch.randn((num_heads,), dtype=dtype, device="cuda") + variant = AttentionWithSink(sink) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + sink=sink.cpu(), + ) + ref_out_cast = ref_out.to(dtype) + + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [256, 2048]) +@pytest.mark.parametrize("page_size", [64, 128]) +def test_cute_dsl_mla_decode_rpe(batch_size, seq_len_k, page_size): + """Test MLA decode with RPEAttention (score_mod with 2-D per-head bias table).""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import 
( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import RPEAttention + + torch.manual_seed(42) + dtype = torch.bfloat16 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + max_rel_dist = 64 + table_size = 2 * max_rel_dist + 1 + rpe_table = ( + torch.randn((num_heads, table_size), dtype=torch.float32, device="cuda") * 0.1 + ) + variant = RPEAttention(rpe_table, max_rel_dist=max_rel_dist) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + rpe_cpu = rpe_table.float().cpu() + + def rpe_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + rel_pos = kv_idx - qo_idx + max_rel_dist + rel_pos = max(0, min(rel_pos, table_size - 1)) + return score + rpe_cpu[head_idx, rel_pos].item() + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=rpe_score_mod, + ) + ref_out_cast = ref_out.to(dtype) + + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + + +# --------------------------------------------------------------------------- +# FP8 variant tests +# 
--------------------------------------------------------------------------- + + +def _make_fp8_mla_inputs( + batch_size, seq_len_k, page_size, num_heads=128, latent_dim=512, rope_dim=64 +): + """Helper to create FP8 query/kv/block_tables for variant tests.""" + device = torch.device("cuda") + D_qk = latent_dim + rope_dim + query = ( + torch.randn(batch_size, 1, num_heads, D_qk, dtype=torch.float16, device=device) + * 0.1 + ).to(torch.float8_e4m3fn) + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + kv_cache = ( + torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + * 0.1 + ).to(torch.float8_e4m3fn) + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) + return query, kv_cache, block_tables, seq_lens, workspace_buffer + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("page_size", [64]) +def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size): + """Test FP8 MLA decode with ALiBi variant.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import ALiBiAttention + + torch.manual_seed(42) + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs( + batch_size, seq_len_k, page_size + ) + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + alibi_slopes = ALiBiAttention.get_slopes(num_heads).cuda() + variant = 
ALiBiAttention(alibi_slopes) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=query.dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + D_qk = latent_dim + rope_dim + kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim].to(torch.float32) + q_rope = query[..., latent_dim:].to(torch.float32) + + slopes_cpu = alibi_slopes.cpu().float() + + def alibi_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + return score + slopes_cpu[head_idx].item() * (kv_idx - qo_idx) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=alibi_score_mod, + ) + torch.testing.assert_close( + out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 + ) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("page_size", [64]) +def test_cute_dsl_mla_decode_fp8_soft_capping(batch_size, seq_len_k, page_size): + """Test FP8 MLA decode with SoftCapping variant.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import SoftCappingAttention + + torch.manual_seed(42) + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs( + batch_size, seq_len_k, page_size + ) + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + cap = 
50.0 + variant = SoftCappingAttention(cap=cap) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=query.dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + import math + + D_qk = latent_dim + rope_dim + kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim].to(torch.float32) + q_rope = query[..., latent_dim:].to(torch.float32) + + def capping_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + return cap * math.tanh(score.item() / cap) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=capping_score_mod, + ) + torch.testing.assert_close( + out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 + ) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("page_size", [64]) +def test_cute_dsl_mla_decode_fp8_attention_sink(batch_size, seq_len_k, page_size): + """Test FP8 MLA decode with AttentionWithSink variant.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import AttentionWithSink + + torch.manual_seed(42) + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs( + batch_size, seq_len_k, page_size + ) + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + sink = 
torch.randn((num_heads,), dtype=torch.bfloat16, device="cuda") + variant = AttentionWithSink(sink) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=query.dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + D_qk = latent_dim + rope_dim + kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim].to(torch.float32) + q_rope = query[..., latent_dim:].to(torch.float32) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + sink=sink.cpu().to(torch.float32), + ) + torch.testing.assert_close( + out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 + ) + + +# --------------------------------------------------------------------------- +# Regression: SoftCapping with non-tile-aligned seq_len +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("seq_len_k", [200]) +@pytest.mark.parametrize("page_size", [64]) +def test_cute_dsl_mla_decode_soft_capping_small_cap(batch_size, seq_len_k, page_size): + """Expose SoftCapping + last-tile masking interaction. + + With a small cap and seq_len_k not aligned to the 128-element MMA tile, + masked-out positions (beyond seq_len_k) are transformed from -inf to -cap + by score_mod. When cap is small (e.g. 1.0), -cap sits within the range of + valid scores, giving masked positions non-negligible softmax probability. 
+ Those positions carry garbage KV data, corrupting the output. + + This test uses cap=1.0 and seq_len_k=200 (last tile has 72 valid + 56 + masked elements). The reference only sums over valid positions, so any + leakage from masked positions shows up as a numerical mismatch. + """ + skip_if_unsupported() + + from flashinfer.cute_dsl.attention.wrappers.batch_mla import ( + BatchMLADecodeCuteDSLWrapper, + ) + from flashinfer.cute_dsl.attention.fusion.variant import SoftCappingAttention + + torch.manual_seed(42) + dtype = torch.bfloat16 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + cap = 1.0 + + variant = SoftCappingAttention(cap=cap) + + wrapper = BatchMLADecodeCuteDSLWrapper(workspace_buffer) + wrapper.plan( + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + num_heads=num_heads, + page_size=page_size, + q_dtype=dtype, + is_var_seq=False, + variant=variant, + ) + out = wrapper.run( + q=query, + kv_cache=kv_cache, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + import math + + def soft_capping_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): + return cap * math.tanh(score.item() / cap) + + ref_out = torch_reference_mla_with_variant( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + score_mod_fn=soft_capping_score_mod, + ) + ref_out_cast = ref_out.to(dtype) + + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) diff --git a/tests/attention/test_modular_fmha_prefill.py 
b/tests/attention/test_modular_fmha_prefill.py new file mode 100644 index 0000000000..2a6f847a0b --- /dev/null +++ b/tests/attention/test_modular_fmha_prefill.py @@ -0,0 +1,1666 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the refactored flashinfer.cute_dsl.attention package. + +Covers: basic prefill (various q/kv/batch combos, GQA vs MHA, causal), +variable-length sequences, output transform, logits transform (sigmoid), +and attention sink. + +Each unique (mask_type, fusion) combination triggers one JIT compilation +(~30s). The test matrix is designed to reuse compiled kernels across +many runtime configurations so the full suite runs in a few minutes. +""" + +import math + +import pytest +import torch + +from flashinfer.cute_dsl import is_cute_dsl_available +from flashinfer.utils import is_sm100a_supported + +if not is_cute_dsl_available(): + pytest.skip("CuTe DSL not available", allow_module_level=True) + +from tests.test_helpers.sink_attention_reference import sink_softmax +import cutlass.cute as cute + +from flashinfer.cute_dsl.attention import ( + BatchPrefillCuteDSLWrapper, + AttentionVariant, + AttentionWithSink, + SigmoidAttention, + SigmoidTanhAttention, + ALiBiAttention, + RPEAttention, +) + + +# --------------------------------------------------------------------------- +# Reference implementations +# --------------------------------------------------------------------------- + + +def attention_ref( + batch_size, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sm_scale: float, + sink: torch.Tensor | None = None, +) -> torch.Tensor: + qo_len = q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_qo_heads = q.shape[1] + num_kv_heads = k.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + + if num_qo_heads > num_kv_heads: + assert num_qo_heads % num_kv_heads == 0 + group_size = num_qo_heads // num_kv_heads + k = 
torch.repeat_interleave(k, group_size, dim=1) + v = torch.repeat_interleave(v, group_size, dim=1) + + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + ) + * sm_scale + ) + + if causal: + mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + else: + mask = torch.ones(qo_len, kv_len, device=q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) + if sink is not None: + p = sink_softmax(logits, sink) + else: + p = torch.softmax(logits, dim=-1) + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_qo_heads, head_dim_vo) + .to(q) + ) + + return o_ref, lse_ref * math.log2(math.e) + + +def attention_varlen_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + causal: bool, + sm_scale: float, +) -> torch.Tensor: + batch_size = qo_indptr.shape[0] - 1 + nnz_qo = qo_indptr[-1].item() + o = torch.empty(nnz_qo, *q.shape[1:-1], v.shape[-1], device=q.device, dtype=q.dtype) + lse = torch.empty(nnz_qo, q.shape[1], device=q.device, dtype=torch.float32) + + for i in range(batch_size): + o_i, lse_i = attention_ref( + 1, + q[qo_indptr[i] : qo_indptr[i + 1]], + k[kv_indptr[i] : kv_indptr[i + 1]], + v[kv_indptr[i] : kv_indptr[i + 1]], + causal, + sm_scale, + ) + + lse_i = lse_i.flatten(0, 1) + o[qo_indptr[i] : qo_indptr[i + 1]] = o_i + lse[qo_indptr[i] : qo_indptr[i + 1]] = lse_i + + return o, lse + + +def attention_sigmoid_ref( + batch_size, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sigmoid_scale: float, + sigmoid_bias: float = 0.0, +) -> torch.Tensor: + qo_len = 
q.shape[0] // batch_size + kv_len = k.shape[0] // batch_size + num_qo_heads = q.shape[1] + num_kv_heads = k.shape[1] + head_dim_qk = q.shape[2] + head_dim_vo = v.shape[2] + + if num_qo_heads > num_kv_heads: + assert num_qo_heads % num_kv_heads == 0 + group_size = num_qo_heads // num_kv_heads + k = torch.repeat_interleave(k, group_size, dim=1) + v = torch.repeat_interleave(v, group_size, dim=1) + + logits = ( + torch.einsum( + "bmhd,bnhd->bhmn", + q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), + k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), + ) + * sigmoid_scale + ) + + if causal: + mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( + 1 + ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) + else: + mask = torch.ones(qo_len, kv_len, device=q.device) + + logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) + + p = torch.sigmoid(logits + sigmoid_bias) + + o_ref = ( + torch.einsum( + "bhmn,bnhd->bmhd", + p, + v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), + ) + .contiguous() + .view(batch_size * qo_len, num_qo_heads, head_dim_vo) + .to(q) + ) + + return o_ref + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _skip_if_unsupported(qo_len, kv_len, causal): + if not is_sm100a_supported(torch.device("cuda")): + pytest.skip("SM100A is not supported on this device") + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + +HEAD_DIM = 128 +NUM_QO_HEADS = 32 +SM_SCALE = 1.0 / math.sqrt(HEAD_DIM) +DTYPE = torch.bfloat16 +ATOL = 1e-2 +RTOL = 1e-2 + + +# --------------------------------------------------------------------------- +# 1. 
Basic prefill — curated (batch, qo, kv) combos covering tile boundaries +# --------------------------------------------------------------------------- + +BASIC_SHAPE_PARAMS = [ + # (batch, qo_len, kv_len) + (1, 64, 64), # single batch, sub-tile + (1, 128, 128), # single batch, exact tile + (1, 177, 977), # single batch, multi-tile + (2, 256, 256), # small batch, multi-tile + (9, 1, 1), # many batches, minimal sizes + (9, 64, 1), # many batches, sub-tile Q, minimal KV + (9, 177, 1), # multi-tile Q, minimal KV (regression for accumulate bug) + (9, 177, 64), # multi-tile Q, sub-tile KV + (9, 177, 128), # multi-tile Q, exact-tile KV boundary + (9, 177, 129), # multi-tile Q, just-over-tile KV + (9, 256, 17), # multi-tile Q, small non-aligned KV + (9, 256, 256), # multi-tile both, aligned + (9, 177, 544), # multi-tile, causal_offset=367 not tile-aligned (regression) + (9, 256, 544), # multi-tile, causal_offset=288 not tile-aligned + (9, 177, 1999), # multi-tile both, large non-aligned KV + (17, 177, 977), # large batch, multi-tile +] + + +@pytest.mark.parametrize("batch_size,qo_len,kv_len", BASIC_SHAPE_PARAMS) +@pytest.mark.parametrize("num_kv_heads", [8, 32]) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill( + batch_size, + qo_len, + kv_len, + num_kv_heads, + causal, +): + _skip_if_unsupported(qo_len, kv_len, causal) + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + 
wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=SM_SCALE, + q_data_type=DTYPE, + kv_data_type=DTYPE, + ) + o = wrapper.run(q, k, v) + o_ref, _ = attention_ref(batch_size, q, k, v, causal, SM_SCALE) + + torch.testing.assert_close(o, o_ref, rtol=RTOL, atol=ATOL) + + # Verify the out= in-place contract: the user's pre-allocated tensor + # must be populated with the results and returned as-is. + o_buffer = torch.empty_like(o) + o_ret = wrapper.run(q, k, v, out=o_buffer) + assert o_ret.data_ptr() == o_buffer.data_ptr(), ( + "run(out=...) must return the same tensor object" + ) + torch.testing.assert_close(o_buffer, o_ref, rtol=RTOL, atol=ATOL) + + +# --------------------------------------------------------------------------- +# 2. Variable-length sequences — a few representative indptr patterns +# --------------------------------------------------------------------------- + +VARLEN_INDPTR_PARAMS_COMPACT = [ + [0, 7], # single very short seq + [0, 1284], # single long seq + [0, 1298, 2638], # 2 seqs + [0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524], # 8 seqs + [0, 1300, 2614, 3924], # 3 seqs, short + [ + 0, + 1536, + 3061, + 4578, + 6177, + 7774, + 9378, + 10958, + 12636, + 14292, + 15954, + ], # 10 seqs, varied +] + + +@pytest.mark.parametrize("indptr", VARLEN_INDPTR_PARAMS_COMPACT) +@pytest.mark.parametrize("num_kv_heads", [8, 32]) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill_varlen( + indptr, + num_kv_heads, + causal, +): + if not is_sm100a_supported(torch.device("cuda")): + pytest.skip("SM100A is not supported on this device") + + torch.manual_seed(42) + sm_scale = SM_SCALE + + q = torch.randn(indptr[-1], NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda") + k = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda") + v = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda") + + qo_indptr = 
torch.tensor(indptr, device="cuda", dtype=torch.int32) + kv_indptr = qo_indptr + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(1, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=sm_scale, + q_data_type=DTYPE, + kv_data_type=DTYPE, + ) + o = wrapper.run(q, k, v) + + gqa_group_ratio = NUM_QO_HEADS // num_kv_heads + k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1) + v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1) + + o_ref, _ = attention_varlen_ref( + q, k_repeated, v_repeated, qo_indptr, kv_indptr, causal, sm_scale + ) + + torch.testing.assert_close(o, o_ref, rtol=RTOL, atol=ATOL) + + +# --------------------------------------------------------------------------- +# 3. Output transform +# --------------------------------------------------------------------------- + +FUSION_SHAPE_PARAMS = [ + # (batch, qo_len, kv_len) — smaller matrix, covers key boundaries + (1, 177, 977), + (9, 177, 64), + (9, 256, 256), + (9, 177, 128), +] + + +class _ScaleBy2TestVariant(AttentionVariant): + """Test-only: multiply output by 2x to verify output_transform hook.""" + + has_output_transform = True + + @cute.jit + def transform_output(self, output, batch_idx, qo_idx, qo_head_idx, m, rcp_d, scale): + return output * scale * 2.0 * rcp_d + + +@pytest.mark.parametrize("batch_size,qo_len,kv_len", FUSION_SHAPE_PARAMS) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill_output_transform( + batch_size, + qo_len, + kv_len, + causal, +): + _skip_if_unsupported(qo_len, kv_len, causal) + num_kv_heads = 8 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, 
device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=1.0, + q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=_ScaleBy2TestVariant(), + ) + o = wrapper.run(q, k, v) + + o_ref, _ = attention_ref(batch_size, q, k, v, causal, 1.0) + o_ref_transform = o_ref * 2.0 + + torch.testing.assert_close(o, o_ref_transform, rtol=RTOL, atol=ATOL) + + +# --------------------------------------------------------------------------- +# 4. Logits transform (sigmoid) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size,qo_len,kv_len", FUSION_SHAPE_PARAMS) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill_logits_transform( + batch_size, + qo_len, + kv_len, + causal, +): + _skip_if_unsupported(qo_len, kv_len, causal) + num_kv_heads = 8 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=1.0, + 
q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=SigmoidAttention(scale=1.0, bias=0.0), + ) + o = wrapper.run(q, k, v) + + o_ref = attention_sigmoid_ref(batch_size, q, k, v, causal, 1.0, 0.0) + + # Sigmoid logits transform has known ~1% element accuracy limitations + # (documented in PR #1549) + torch.testing.assert_close(o, o_ref, rtol=0.15, atol=0.15) + + +@pytest.mark.parametrize( + "batch_size,qo_len,kv_len", + [ + (1, 128, 128), + (9, 256, 256), + ], +) +@pytest.mark.parametrize("bias", [-5.0, -2.0, 1.0]) +def test_attention_prefill_sigmoid_bias(batch_size, qo_len, kv_len, bias): + """Regression test: SigmoidAttention bias must match torch.sigmoid semantics. + + The bias parameter should produce σ(score * scale + bias), matching + the C++ FlashSigmoid which converts both scale and bias to log-base-2 + via multiplication by log2(e). A previous implementation only converted + scale, effectively attenuating the bias by ln(2) ≈ 0.693. + """ + _skip_if_unsupported(qo_len, kv_len, causal=False) + num_kv_heads = 8 + scale = 1.0 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=False, + sm_scale=1.0, + q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=SigmoidAttention(scale=scale, bias=bias), + ) + o = wrapper.run(q, k, v) + + o_ref = attention_sigmoid_ref(batch_size, q, k, v, False, 
scale, bias) + + torch.testing.assert_close(o, o_ref, rtol=0.15, atol=0.15) + + +# --------------------------------------------------------------------------- +# 4b. Logits transform (sigmoid via tanh) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size,qo_len,kv_len", FUSION_SHAPE_PARAMS) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill_sigmoid_tanh( + batch_size, + qo_len, + kv_len, + causal, +): + _skip_if_unsupported(qo_len, kv_len, causal) + num_kv_heads = 8 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=1.0, + q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=SigmoidTanhAttention(scale=1.0, bias=0.0), + ) + o = wrapper.run(q, k, v) + + o_ref = attention_sigmoid_ref(batch_size, q, k, v, causal, 1.0, 0.0) + + torch.testing.assert_close(o, o_ref, rtol=0.15, atol=0.15) + + +@pytest.mark.parametrize( + "batch_size,qo_len,kv_len", + [ + (1, 128, 128), + (9, 256, 256), + ], +) +@pytest.mark.parametrize("bias", [-5.0, -2.0, 1.0]) +def test_attention_prefill_sigmoid_tanh_bias(batch_size, qo_len, kv_len, bias): + """Regression test: SigmoidTanhAttention bias must match torch.sigmoid semantics.""" + _skip_if_unsupported(qo_len, kv_len, causal=False) + num_kv_heads = 8 + 
scale = 1.0 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=False, + sm_scale=1.0, + q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=SigmoidTanhAttention(scale=scale, bias=bias), + ) + o = wrapper.run(q, k, v) + + o_ref = attention_sigmoid_ref(batch_size, q, k, v, False, scale, bias) + + torch.testing.assert_close(o, o_ref, rtol=0.15, atol=0.15) + + +# --------------------------------------------------------------------------- +# 5. 
Attention sink +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("batch_size,qo_len,kv_len", FUSION_SHAPE_PARAMS) +@pytest.mark.parametrize("causal", [False, True]) +def test_attention_prefill_attention_sink( + batch_size, + qo_len, + kv_len, + causal, +): + _skip_if_unsupported(qo_len, kv_len, causal) + num_kv_heads = 8 + + torch.manual_seed(42) + q = torch.randn( + batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + k = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + v = torch.randn( + batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda" + ) + qo_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * kv_len + ) + sink = torch.randn((NUM_QO_HEADS,), dtype=DTYPE, device="cuda") + + wrapper = BatchPrefillCuteDSLWrapper( + torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8), + ) + wrapper.plan( + qo_indptr, + kv_indptr, + NUM_QO_HEADS, + num_kv_heads, + HEAD_DIM, + head_dim_vo=HEAD_DIM, + causal=causal, + sm_scale=SM_SCALE, + q_data_type=DTYPE, + kv_data_type=DTYPE, + variant=AttentionWithSink(sink), + ) + o = wrapper.run(q, k, v) + + o_ref, _ = attention_ref(batch_size, q, k, v, causal, SM_SCALE, sink=sink) + + torch.testing.assert_close(o, o_ref, rtol=RTOL, atol=ATOL) + + +# --------------------------------------------------------------------------- +# 6. 
# ---------------------------------------------------------------------------
# 6. Float16 dtype
# ---------------------------------------------------------------------------

FP16_SHAPE_PARAMS = [
    (1, 128, 128),
    (9, 256, 256),
    (1, 177, 977),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", FP16_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False, True])
def test_attention_prefill_fp16(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """Same pipeline as the default test, but with float16 Q/K/V."""
    _skip_if_unsupported(qo_len, kv_len, causal)
    num_kv_heads = 8
    dtype = torch.float16

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=dtype, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=dtype, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=dtype, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=causal,
        sm_scale=SM_SCALE,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    out = wrapper.run(q, k, v)
    out_ref, _ = attention_ref(batch_size, q, k, v, causal, SM_SCALE)

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 7. Sliding window mask
# ---------------------------------------------------------------------------


def attention_sliding_window_ref(
    batch_size,
    q,
    k,
    v,
    window_left,
    sm_scale,
):
    """Reference for symmetric sliding window: |kv_idx - (q_idx + offset)| <= window_left.

    When qo_len != kv_len, Q positions are right-aligned to KV: q_idx maps
    to kv position q_idx + (kv_len - qo_len).
    """
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]

    # Expand grouped KV heads so Q and K/V head counts line up for the einsum.
    if num_qo_heads > num_kv_heads:
        gqa_ratio = num_qo_heads // num_kv_heads
        k = torch.repeat_interleave(k, gqa_ratio, dim=1)
        v = torch.repeat_interleave(v, gqa_ratio, dim=1)

    scores = (
        torch.einsum(
            "bmhd,bnhd->bhmn",
            q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
            k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
        )
        * sm_scale
    )

    # Right-align Q to KV, then keep only positions within the window.
    qk_offset = kv_len - qo_len
    q_idx = torch.arange(qo_len, device=q.device).unsqueeze(1)
    k_idx = torch.arange(kv_len, device=q.device).unsqueeze(0)
    keep = torch.abs(k_idx - (q_idx + qk_offset)) <= window_left
    scores = scores.masked_fill(keep.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    out = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            probs,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )
    return out


SLIDING_WINDOW_PARAMS = [
    # (batch, qo_len, kv_len, window_left)
    (1, 256, 256, 64),
    (1, 256, 256, 128),
    (9, 256, 256, 100),
    (1, 512, 512, 200),
    # qo_len != kv_len (Q right-aligned to KV, as in append/prefill-with-cache)
    (1, 128, 256, 64),
    (1, 128, 512, 100),
    (1, 256, 512, 128),
    (3, 128, 384, 80),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len,window_left", SLIDING_WINDOW_PARAMS)
def test_attention_prefill_sliding_window(
    batch_size,
    qo_len,
    kv_len,
    window_left,
):
    """Sliding-window path must match the symmetric-window reference."""
    if not is_sm100a_supported(torch.device("cuda")):
        pytest.skip("SM100A is not supported on this device")
    num_kv_heads = 8

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=False,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        window_left=window_left,
    )
    out = wrapper.run(q, k, v)
    out_ref = attention_sliding_window_ref(batch_size, q, k, v, window_left, SM_SCALE)

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 8. Head dimension 64
# ---------------------------------------------------------------------------

HEAD64_SHAPE_PARAMS = [
    (1, 128, 128),
    (9, 256, 256),
    (1, 177, 977),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", HEAD64_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False, True])
def test_attention_prefill_head_dim_64(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """Exercises the head_dim=64 code path (default shape tests use HEAD_DIM)."""
    _skip_if_unsupported(qo_len, kv_len, causal)
    head_dim = 64
    num_kv_heads = 8
    sm_scale = 1.0 / math.sqrt(head_dim)

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, head_dim, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, head_dim, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, head_dim, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        sm_scale=sm_scale,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
    )
    out = wrapper.run(q, k, v)
    out_ref, _ = attention_ref(batch_size, q, k, v, causal, sm_scale)

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 9. Variable-length + logits transform (sigmoid)
# ---------------------------------------------------------------------------

VARLEN_FUSION_INDPTRS = [
    [0, 1298, 2638],
    [0, 1350, 2667, 4003, 5347, 6631, 7919, 9208, 10524],
]


def attention_varlen_sigmoid_ref(
    q,
    k,
    v,
    qo_indptr,
    kv_indptr,
    causal,
    sigmoid_scale,
    sigmoid_bias=0.0,
):
    """Per-segment sigmoid-attention reference over ragged (indptr) batches."""
    batch_size = qo_indptr.shape[0] - 1
    nnz_qo = qo_indptr[-1].item()
    out = torch.empty(
        nnz_qo, *q.shape[1:-1], v.shape[-1], device=q.device, dtype=q.dtype
    )
    # Run the dense reference independently on each ragged segment.
    for i in range(batch_size):
        seg_out = attention_sigmoid_ref(
            1,
            q[qo_indptr[i] : qo_indptr[i + 1]],
            k[kv_indptr[i] : kv_indptr[i + 1]],
            v[kv_indptr[i] : kv_indptr[i + 1]],
            causal,
            sigmoid_scale,
            sigmoid_bias,
        )
        out[qo_indptr[i] : qo_indptr[i + 1]] = seg_out
    return out


@pytest.mark.parametrize("indptr", VARLEN_FUSION_INDPTRS)
def test_attention_prefill_varlen_logits_transform(indptr):
    """Sigmoid logits transform combined with ragged (variable-length) batches."""
    if not is_sm100a_supported(torch.device("cuda")):
        pytest.skip("SM100A is not supported on this device")
    num_kv_heads = 8

    torch.manual_seed(42)
    q = torch.randn(indptr[-1], NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda")
    k = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")
    v = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")

    qo_indptr = torch.tensor(indptr, device="cuda", dtype=torch.int32)
    kv_indptr = qo_indptr

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=False,
        sm_scale=1.0,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=SigmoidAttention(scale=1.0, bias=0.0),
    )
    out = wrapper.run(q, k, v)

    # The ragged reference has no GQA handling, so expand KV heads first.
    gqa_group_ratio = NUM_QO_HEADS // num_kv_heads
    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
    out_ref = attention_varlen_sigmoid_ref(
        q,
        k_repeated,
        v_repeated,
        qo_indptr,
        kv_indptr,
        False,
        1.0,
        0.0,
    )

    torch.testing.assert_close(out, out_ref, rtol=0.15, atol=0.15)


# ---------------------------------------------------------------------------
# 9b. Variable-length + logits transform (sigmoid via tanh)
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("indptr", VARLEN_FUSION_INDPTRS)
def test_attention_prefill_varlen_sigmoid_tanh(indptr):
    """Tanh-based sigmoid variant combined with ragged batches."""
    if not is_sm100a_supported(torch.device("cuda")):
        pytest.skip("SM100A is not supported on this device")
    num_kv_heads = 8

    torch.manual_seed(42)
    q = torch.randn(indptr[-1], NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda")
    k = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")
    v = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")

    qo_indptr = torch.tensor(indptr, device="cuda", dtype=torch.int32)
    kv_indptr = qo_indptr

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=False,
        sm_scale=1.0,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=SigmoidTanhAttention(scale=1.0, bias=0.0),
    )
    out = wrapper.run(q, k, v)

    gqa_group_ratio = NUM_QO_HEADS // num_kv_heads
    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
    out_ref = attention_varlen_sigmoid_ref(
        q,
        k_repeated,
        v_repeated,
        qo_indptr,
        kv_indptr,
        False,
        1.0,
        0.0,
    )

    torch.testing.assert_close(out, out_ref, rtol=0.15, atol=0.15)


# ---------------------------------------------------------------------------
# 10. Variable-length + attention sink
# ---------------------------------------------------------------------------


def attention_varlen_sink_ref(
    q,
    k,
    v,
    qo_indptr,
    kv_indptr,
    causal,
    sm_scale,
    sink,
):
    """Per-segment attention-sink reference over ragged (indptr) batches."""
    batch_size = qo_indptr.shape[0] - 1
    nnz_qo = qo_indptr[-1].item()
    out = torch.empty(
        nnz_qo, *q.shape[1:-1], v.shape[-1], device=q.device, dtype=q.dtype
    )
    for i in range(batch_size):
        seg_out, _ = attention_ref(
            1,
            q[qo_indptr[i] : qo_indptr[i + 1]],
            k[kv_indptr[i] : kv_indptr[i + 1]],
            v[kv_indptr[i] : kv_indptr[i + 1]],
            causal,
            sm_scale,
            sink=sink,
        )
        out[qo_indptr[i] : qo_indptr[i + 1]] = seg_out
    return out


@pytest.mark.parametrize("indptr", VARLEN_FUSION_INDPTRS)
def test_attention_prefill_varlen_attention_sink(indptr):
    """Attention sink combined with ragged (variable-length) batches."""
    if not is_sm100a_supported(torch.device("cuda")):
        pytest.skip("SM100A is not supported on this device")
    num_kv_heads = 8

    torch.manual_seed(42)
    q = torch.randn(indptr[-1], NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda")
    k = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")
    v = torch.randn(indptr[-1], num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda")
    sink = torch.randn((NUM_QO_HEADS,), dtype=DTYPE, device="cuda")

    qo_indptr = torch.tensor(indptr, device="cuda", dtype=torch.int32)
    kv_indptr = qo_indptr

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=False,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=AttentionWithSink(sink),
    )
    out = wrapper.run(q, k, v)

    gqa_group_ratio = NUM_QO_HEADS // num_kv_heads
    k_repeated = torch.repeat_interleave(k, gqa_group_ratio, dim=1)
    v_repeated = torch.repeat_interleave(v, gqa_group_ratio, dim=1)
    out_ref = attention_varlen_sink_ref(
        q,
        k_repeated,
        v_repeated,
        qo_indptr,
        kv_indptr,
        False,
        SM_SCALE,
        sink,
    )

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 11. Attention sink with MHA (num_kv_heads == NUM_QO_HEADS)
# ---------------------------------------------------------------------------

SINK_MHA_SHAPE_PARAMS = [
    (9, 256, 256),
    (1, 177, 977),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", SINK_MHA_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False, True])
def test_attention_prefill_attention_sink_mha(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """Attention sink on the non-GQA (MHA) path: KV head count equals QO head count."""
    _skip_if_unsupported(qo_len, kv_len, causal)
    num_kv_heads = NUM_QO_HEADS

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len
    sink = torch.randn((NUM_QO_HEADS,), dtype=DTYPE, device="cuda")

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=causal,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=AttentionWithSink(sink),
    )
    out = wrapper.run(q, k, v)

    out_ref, _ = attention_ref(batch_size, q, k, v, causal, SM_SCALE, sink=sink)

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 12. ALiBi (score_mod hook)
# ---------------------------------------------------------------------------


def attention_alibi_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
    alibi_slopes: torch.Tensor,
) -> torch.Tensor:
    """Reference ALiBi attention: adds slope * (kv_pos - qo_pos) before softmax."""
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]

    if num_qo_heads > num_kv_heads:
        gqa_ratio = num_qo_heads // num_kv_heads
        k = torch.repeat_interleave(k, gqa_ratio, dim=1)
        v = torch.repeat_interleave(v, gqa_ratio, dim=1)

    scores = torch.einsum(
        "bmhd,bnhd->bhmn",
        q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
        k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
    )

    # Per-head linear positional bias; sm_scale is applied after the bias,
    # matching the kernel-side score_mod ordering.
    qo_pos = torch.arange(qo_len, device=q.device).view(1, 1, -1, 1)
    kv_pos = torch.arange(kv_len, device=q.device).view(1, 1, 1, -1)
    alibi_bias = alibi_slopes.view(1, -1, 1, 1) * (kv_pos - qo_pos)
    scores = (scores + alibi_bias) * sm_scale

    if causal:
        # Q rows are right-aligned to the KV sequence.
        keep = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
            1
        ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
    else:
        keep = torch.ones(qo_len, kv_len, device=q.device)

    scores = scores.masked_fill(keep.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    out = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            probs,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return out


ALIBI_SHAPE_PARAMS = [
    (1, 128, 128),
    (1, 177, 977),
    (9, 256, 256),
    (9, 177, 544),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", ALIBI_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False, True])
def test_attention_prefill_alibi(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """ALiBi score_mod variant must match the linear-bias reference."""
    _skip_if_unsupported(qo_len, kv_len, causal)
    num_kv_heads = 8

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    alibi_slopes = ALiBiAttention.get_slopes(NUM_QO_HEADS).cuda()

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=causal,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=ALiBiAttention(alibi_slopes),
    )
    out = wrapper.run(q, k, v)

    out_ref = attention_alibi_ref(
        batch_size,
        q,
        k,
        v,
        causal,
        SM_SCALE,
        alibi_slopes,
    )

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 13. RPE (Relative Positional Encoding — 2-D params)
# ---------------------------------------------------------------------------
# 13. RPE (Relative Positional Encoding — 2-D params)
# ---------------------------------------------------------------------------


def attention_rpe_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
    rpe_table: torch.Tensor,
    max_rel_dist: int,
) -> torch.Tensor:
    """Reference RPE attention: adds rpe_table[head, clamp(kv-qo+offset)] before softmax."""
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]

    if num_qo_heads > num_kv_heads:
        gqa_ratio = num_qo_heads // num_kv_heads
        k = torch.repeat_interleave(k, gqa_ratio, dim=1)
        v = torch.repeat_interleave(v, gqa_ratio, dim=1)

    scores = torch.einsum(
        "bmhd,bnhd->bhmn",
        q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
        k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
    )

    # Relative distance, shifted into [0, 2*max_rel_dist] to index the table.
    qo_pos = torch.arange(qo_len, device=q.device).view(1, 1, -1, 1)
    kv_pos = torch.arange(kv_len, device=q.device).view(1, 1, 1, -1)
    rel_pos = (kv_pos - qo_pos + max_rel_dist).clamp(0, 2 * max_rel_dist)
    rpe_bias = rpe_table[:, rel_pos.squeeze(0).squeeze(0).long()].unsqueeze(0)
    # Bias is added before scaling, matching the kernel-side score_mod ordering.
    scores = (scores + rpe_bias) * sm_scale

    if causal:
        keep = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
            1
        ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
    else:
        keep = torch.ones(qo_len, kv_len, device=q.device)

    scores = scores.masked_fill(keep.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    out = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            probs,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return out


RPE_SHAPE_PARAMS = [
    (1, 128, 128),
    (1, 177, 977),
    (9, 256, 256),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", RPE_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False, True])
def test_attention_prefill_rpe(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """RPE score_mod variant must match the table-lookup reference."""
    _skip_if_unsupported(qo_len, kv_len, causal)
    num_kv_heads = 8
    max_rel_dist = 64

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    # Small-magnitude bias table so the softmax stays well-conditioned.
    rpe_table = (
        torch.randn(
            NUM_QO_HEADS, 2 * max_rel_dist + 1, dtype=torch.float32, device="cuda"
        )
        * 0.1
    )

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=causal,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=RPEAttention(rpe_table, max_rel_dist),
    )
    out = wrapper.run(q, k, v)

    out_ref = attention_rpe_ref(
        batch_size,
        q,
        k,
        v,
        causal,
        SM_SCALE,
        rpe_table,
        max_rel_dist,
    )

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


# ---------------------------------------------------------------------------
# 14. SoftCapping regression: non-tile-aligned kv_len
# ---------------------------------------------------------------------------
# 14. SoftCapping regression: non-tile-aligned kv_len
# ---------------------------------------------------------------------------


def attention_soft_capping_ref(
    batch_size,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool,
    sm_scale: float,
    cap: float,
) -> torch.Tensor:
    """Reference SoftCapping attention: cap * tanh(score / cap) before softmax."""
    qo_len = q.shape[0] // batch_size
    kv_len = k.shape[0] // batch_size
    num_qo_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    head_dim_qk = q.shape[2]
    head_dim_vo = v.shape[2]

    if num_qo_heads > num_kv_heads:
        gqa_ratio = num_qo_heads // num_kv_heads
        k = torch.repeat_interleave(k, gqa_ratio, dim=1)
        v = torch.repeat_interleave(v, gqa_ratio, dim=1)

    scores = torch.einsum(
        "bmhd,bnhd->bhmn",
        q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(),
        k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(),
    )

    # Cap, then scale — matching the kernel-side score_mod ordering.
    scores = cap * torch.tanh(scores / cap)
    scores = scores * sm_scale

    if causal:
        keep = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze(
            1
        ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
        scores = scores.masked_fill(keep.unsqueeze(0).unsqueeze(0) == 0, float("-inf"))

    probs = torch.softmax(scores, dim=-1)
    out = (
        torch.einsum(
            "bhmn,bnhd->bmhd",
            probs,
            v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(),
        )
        .contiguous()
        .view(batch_size * qo_len, num_qo_heads, head_dim_vo)
        .to(q)
    )

    return out


SOFTCAP_SHAPE_PARAMS = [
    (1, 128, 200),
    (1, 200, 200),
    (2, 128, 300),
]


@pytest.mark.parametrize("batch_size,qo_len,kv_len", SOFTCAP_SHAPE_PARAMS)
@pytest.mark.parametrize("causal", [False])
def test_attention_prefill_soft_capping_small_cap(
    batch_size,
    qo_len,
    kv_len,
    causal,
):
    """SoftCapping with small cap and non-tile-aligned kv_len.

    Regression test: score_mod transforms masked -inf to -cap. With cap=1.0
    and kv_len not divisible by 128, masked positions leak into softmax.
    """
    _skip_if_unsupported(qo_len, kv_len, causal)
    num_kv_heads = 8
    cap = 1.0

    torch.manual_seed(42)
    q = torch.randn(
        batch_size * qo_len, NUM_QO_HEADS, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    k = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    v = torch.randn(
        batch_size * kv_len, num_kv_heads, HEAD_DIM, dtype=DTYPE, device="cuda"
    )
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qo_len
    kv_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * kv_len

    from flashinfer.cute_dsl.attention import SoftCappingAttention

    workspace = torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    wrapper = BatchPrefillCuteDSLWrapper(workspace)
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        NUM_QO_HEADS,
        num_kv_heads,
        HEAD_DIM,
        head_dim_vo=HEAD_DIM,
        causal=causal,
        sm_scale=SM_SCALE,
        q_data_type=DTYPE,
        kv_data_type=DTYPE,
        variant=SoftCappingAttention(cap=cap),
    )
    out = wrapper.run(q, k, v)

    out_ref = attention_soft_capping_ref(
        batch_size,
        q,
        k,
        v,
        causal,
        SM_SCALE,
        cap,
    )

    torch.testing.assert_close(out, out_ref, rtol=RTOL, atol=ATOL)


if __name__ == "__main__":
    # Quick manual smoke run (assumes the base tests defined earlier in this file).
    test_attention_prefill(4, 1024, 1024, 8, True)
    test_attention_prefill_varlen(
        [0, 256, 1024, 2048, 2560],
        32,
        True,
    )