diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 035a79bdfa..dff642451d 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -140,7 +140,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "auto"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "cute-dsl", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -1023,7 +1023,7 @@ def testMmFp4(args): run_refcheck = args.refcheck use_128x4_sf_layout = args.use_128x4_sf_layout use_nvfp4 = args.use_nvfp4 - autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"] + autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "cute-dsl", "auto"] res = [] res_dtype = dtype_str_to_torch_dtype(args.out_dtype) @@ -1123,7 +1123,7 @@ def run_backend( mat2_inv_s, mat2_inv_s_trtllm, ): - if backend in ["cudnn", "trtllm", "cutlass", "auto"]: + if backend in ["cudnn", "trtllm", "cutlass", "cute-dsl", "auto"]: return flashinfer.gemm.mm_fp4( a=input_fp4, b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T, diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 4630583982..85fb25f625 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2968,8 +2968,11 @@ def _check_mm_fp4_problem_size( out: Optional[torch.Tensor] = None, # unused block_size: int = 16, use_8x4_sf_layout: bool = False, # unused - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + backend: Literal[ + "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" + ] = "auto", # unused use_nvfp4: bool = True, + enable_pdl: bool = True, # unused ): # Generic checks ## pre-check the input tensor, block scale tensor and alpha tensor @@ -3025,8 +3028,11 @@ def _cudnn_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, # unused block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + backend: Literal[ + "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" + ] = "auto", # unused use_nvfp4: bool = True, + enable_pdl: bool = True, # unused ): if use_8x4_sf_layout: raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") @@ -3086,8 +3092,11 @@ def _trtllm_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, # unused block_size: int = 16, # unused use_8x4_sf_layout: bool = False, # unused - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + backend: Literal[ + "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" + ] = "auto", # unused use_nvfp4: bool = True, + enable_pdl: bool = True, # unused ): if not use_nvfp4: raise ValueError("Only cudnn and auto FP4 GEMM supports mxfp4 quantization.") @@ -3110,8 +3119,11 @@ def _cutlass_gemm_fp4_requirement( out: Optional[torch.Tensor] = None, # unused block_size: int = 16, # unused use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", # unused + backend: Literal[ + "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" + ] = "auto", # unused use_nvfp4: bool = True, + enable_pdl: bool = True, # unused ): if use_8x4_sf_layout: raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") @@ -3120,6 +3132,486 @@ def _cutlass_gemm_fp4_requirement( return True +@supported_compute_capability([100, 103]) +def _cute_dsl_gemm_fp4_requirement( + a: torch.Tensor, # unused + b: torch.Tensor, # unused + a_descale: torch.Tensor, # unused + 
b_descale: torch.Tensor, # unused + alpha: Optional[torch.Tensor] = None, # unused + out_dtype: torch.dtype = torch.bfloat16, # unused + out: Optional[torch.Tensor] = None, # unused + block_size: int = 16, # unused + use_8x4_sf_layout: bool = False, + backend: Literal[ + "cudnn", "trtllm", "cutlass", "cute-dsl", "auto" + ] = "auto", # unused + use_nvfp4: bool = True, + enable_pdl: bool = True, # unused +): + # cute_dsl backend requires 128x4 scale factor layout (same as cudnn/cutlass). + # The kernel internally uses CUTLASS BlockScaledBasicChunk which expects + # M/N padded to 128, K padded to 4 -- matching FlashInfer's nvfp4_quantize + # with sfLayout=SfLayout.layout_128x4 and do_shuffle=False. + if use_8x4_sf_layout: + raise ValueError("cute_dsl FP4 GEMM only supports 128x4 scale factor layout.") + if not use_nvfp4: + raise ValueError("cute_dsl FP4 GEMM only supports nvfp4 quantization.") + try: + from flashinfer.cute_dsl.utils import is_cute_dsl_available + + if not is_cute_dsl_available(): + raise RuntimeError("CuTe DSL is not available.") + except ImportError as err: + raise RuntimeError("CuTe DSL is not available.") from err + return True + + +# Module-level kernel cache for CuTe DSL GEMM, shared across runner instances. +# Keyed by (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch, +# kernel_type, use_tma_store, enable_pdl, out_dtype). +_CUTE_DSL_MM_FP4_KERNEL_CACHE = {} + + +def _cute_dsl_gemm_fp4_runner( + sm_major: int, + sm_minor: int, + enable_pdl: bool, + out_dtype: torch.dtype, +): + """Create a CuTe DSL FP4 GEMM runner for the cute_dsl backend. + + On SM100: uses the SM100 kernel only. + On SM103: uses both SM100 kernel and the SM103-specific 3xFP4 kernel. + The autotuner selects the best (kernel_type, tile, cluster, swap_ab, prefetch, + use_tma_store) combination. + """ + import cutlass + import cutlass.cute as cute + + from cutlass.cute.runtime import make_ptr + from .kernels.dense_blockscaled_gemm_sm100 import ( + Sm100BlockScaledPersistentDenseGemmKernel, + ) + + sm_version = sm_major * 10 + sm_minor + + # TODO(yunzheq): Re-enable SM103 kernel once cutlass-dsl package includes + # SM103MmaMXF4Op and compatible PersistentTileSchedulerParams. + # To re-enable, remove the `Sm103Kernel = None` line below. + Sm103Kernel = None + # if sm_version == 103: + # try: + # from .kernels.dense_blockscaled_gemm_sm103 import ( + # Sm103BlockScaledPersistentDenseGemmKernel, + # ) + # + # Sm103Kernel = Sm103BlockScaledPersistentDenseGemmKernel + # except ImportError: + # pass + + # Map torch output dtype to cutlass dtype + _torch_to_cutlass_dtype = { + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + } + c_cutlass_dtype = _torch_to_cutlass_dtype.get(out_dtype) + if c_cutlass_dtype is None: + raise ValueError( + f"cute_dsl backend does not support output dtype {out_dtype}. " + f"Supported: torch.bfloat16, torch.float16." + ) + + class CuteDSLFp4GemmRunner(TunableRunner): + """TunableRunner for CuTe DSL block-scaled FP4 dense GEMM. 
+ + Tactics are tuples: + (mma_tiler_mn, cluster_shape_mn, swap_ab, use_prefetch, kernel_type, use_tma_store) + where: + - kernel_type: "sm100" or "sm103" + - use_tma_store: None for sm100, True/False for sm103 + """ + + def __init__(self): + pass + + def _get_approximate_cta_nums(self, m, n, tile_mn, cluster_shape_mn): + tile_m, tile_n = tile_mn + cluster_m, cluster_n = cluster_shape_mn + ctas_m = ( + ((m + tile_m - 1) // tile_m + cluster_m - 1) // cluster_m * cluster_m + ) + ctas_n = ( + ((n + tile_n - 1) // tile_n + cluster_n - 1) // cluster_n * cluster_n + ) + return ctas_m * ctas_n + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> list: + (a, b, a_descale, b_descale, alpha, _, out, _, _, _) = inputs + m = a.shape[0] + k_packed = a.shape[1] + n = b.shape[1] + real_k = k_packed * 2 # FP4 packed as uint8 + + sf_vec_size = 16 # NVF4 + ab_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E4M3FN + batch_size = 1 + + # SM100 tactic candidates + mma_tiler_mn_candidates = [ + (128, 64), + (256, 64), + (128, 128), + (256, 128), + (128, 192), + (256, 192), + (128, 256), + (256, 256), + ] + cluster_shape_mn_candidates = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + swap_ab_candidates = [False, True] + use_prefetch_candidates = [False, True] + + # Alignment checks for swap_ab + m_aligned = m % 8 == 0 + n_aligned = n % 8 == 0 + + valid_tactics = [] + + # --- SM100 tactics --- + for mma_tiler_mn in mma_tiler_mn_candidates: + for cluster_shape_mn in cluster_shape_mn_candidates: + for swap_ab in swap_ab_candidates: + # Check alignment for C layout + if not swap_ab and not n_aligned: + continue + if swap_ab and not m_aligned: + continue + + if swap_ab: + c_major = "m" + kernel_m, kernel_n = n, m + else: + c_major = "n" + kernel_m, kernel_n = m, n + + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_cutlass_dtype, + mma_tiler_mn, + cluster_shape_mn, + kernel_m, + kernel_n, + real_k, + batch_size, + "k", + "k", + c_major, + ): + continue + + for use_prefetch in use_prefetch_candidates: + # Prefetch pruning heuristic + if use_prefetch: + cta_nums = self._get_approximate_cta_nums( + kernel_m, kernel_n, mma_tiler_mn, cluster_shape_mn + ) + sm_count = torch.cuda.get_device_properties( + a.device + ).multi_processor_count + cta_wave_ratio = cta_nums / sm_count + if not (0.5 < cta_wave_ratio < 1.0 or real_k >= 8192): + continue + + valid_tactics.append( + ( + mma_tiler_mn, + cluster_shape_mn, + swap_ab, + use_prefetch, + "sm100", + None, + ) + ) + + # --- SM103 tactics (only on SM103) --- + if sm_version == 103 and Sm103Kernel is not None: + sm103_mma_tiler_candidates = [ + (128, 128), + (256, 128), + (128, 256), + (256, 256), + ] + use_tma_store_candidates = [True, False] + + for mma_tiler_mn in sm103_mma_tiler_candidates: + for cluster_shape_mn in cluster_shape_mn_candidates: + for swap_ab in swap_ab_candidates: + if not swap_ab and not n_aligned: + continue + if swap_ab and not m_aligned: + continue + + if swap_ab: + c_major = "m" + kernel_m, kernel_n = n, m + else: + c_major = "n" + kernel_m, kernel_n = m, n + + for use_tma_store in use_tma_store_candidates: + if not Sm103Kernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_cutlass_dtype, + mma_tiler_mn, + cluster_shape_mn, + kernel_m, + kernel_n, + real_k, + batch_size, + "k", + "k", + c_major, + use_tma_store, + ): + continue + + for use_prefetch in [False]: + # SM103 kernel 
does not have prefetch support + valid_tactics.append( + ( # type: ignore[arg-type] + mma_tiler_mn, + cluster_shape_mn, + swap_ab, + use_prefetch, + "sm103", + use_tma_store, + ) + ) + + return valid_tactics + + def forward( + self, + inputs: List[torch.Tensor], + tactic=None, + do_preparation: bool = False, + **kwargs, + ): + (a, b, a_descale, b_descale, alpha_tensor, _, out, _, _, _) = inputs + m = a.shape[0] + k_packed = a.shape[1] + n = b.shape[1] + real_k = k_packed * 2 + + sf_vec_size = 16 + batch_size = 1 + + if tactic is None or tactic == -1: + # Fallback tactic + tactic = ((128, 128), (1, 1), False, False, "sm100", None) + + ( + mma_tiler_mn, + cluster_shape_mn, + swap_ab, + use_prefetch, + kernel_type, + use_tma_store, + ) = tactic + + if swap_ab: + kernel_m, kernel_n = n, m + # Swap A/B tensors and their scale factors + kernel_a, kernel_b = b.T, a.T + kernel_a_sf, kernel_b_sf = b_descale.T, a_descale.T + else: + kernel_m, kernel_n = m, n + # b comes in as (k_packed, n), need (n, k_packed) for the kernel + kernel_a, kernel_b = a, b.T + kernel_a_sf, kernel_b_sf = a_descale, b_descale.T + + # Compute scale factor dimensions (128x4 padded) + sf_m = (kernel_m + 127) // 128 + sf_n = (kernel_n + 127) // 128 + sf_k = (real_k // sf_vec_size + 3) // 4 + + # Cache key for compiled kernel + cache_key = ( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + swap_ab, + use_prefetch, + kernel_type, + use_tma_store, + enable_pdl, + out_dtype, + ) + + if cache_key not in _CUTE_DSL_MM_FP4_KERNEL_CACHE: + # Create kernel instance + if kernel_type == "sm103" and Sm103Kernel is not None: + gemm = Sm103Kernel( # type: ignore[assignment] + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + enable_pdl, + ) + else: + gemm = Sm100BlockScaledPersistentDenseGemmKernel( # type: ignore[assignment] + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + use_prefetch, + enable_pdl, + ) + + # TVM-FFI compilation pattern (commit edb37cd): + # - A, B, C, alpha: make_fake_compact_tensor → torch tensors + # passed directly at runtime via TVM-FFI C-level dlpack + # - SF tensors: make_ptr (complex 6D BlockScaledBasicChunk + # layout can't be expressed as torch tensor) → data_ptr() at runtime + # - Stream: make_fake_stream → automatic env stream at runtime + sym_m = cute.sym_int() + sym_k = cute.sym_int() # k_packed (FP4 stored as uint8) + sym_n = cute.sym_int() + + # A/B: FP4 data stored as uint8 in torch (2 FP4 values per byte). + # Use Uint8 to match torch.uint8 dtype at runtime. The kernel + # wrapper recasts from Uint8 to Float4E2M1FN internally. 
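
As a side note on the uint8 packing convention the comment above relies on (an illustrative sketch, not part of this diff): each byte of the torch operand holds two FP4 values along K, so a logical (m, k) FP4 matrix is stored as an (m, k // 2) uint8 tensor, and the runner recovers the logical reduction length as `real_k = k_packed * 2`. The names below are chosen purely for illustration.

```python
import torch

m, k = 256, 512                     # logical GEMM extents; k must be even
a_packed = torch.empty(m, k // 2, dtype=torch.uint8)  # two FP4 values per byte

k_packed = a_packed.shape[1]        # 256
real_k = k_packed * 2               # 512, the K the kernel actually reduces over
sf_vec_size = 16                    # NVFP4: one FP8 scale per 16 K elements
scales_per_row = real_k // sf_vec_size

assert real_k == k and scales_per_row == 32
```
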
+ a_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, + (sym_m, sym_k), + stride_order=(1, 0), + assumed_align=32, + ) + b_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, + (sym_n, sym_k), + stride_order=(1, 0), + assumed_align=32, + ) + # C: (m, n) layout depends on swap_ab, torch tensor at runtime + if swap_ab: + c_fake = cute.runtime.make_fake_compact_tensor( + c_cutlass_dtype, + (sym_n, sym_m), + stride_order=(0, 1), + assumed_align=16, + ) + else: + c_fake = cute.runtime.make_fake_compact_tensor( + c_cutlass_dtype, + (sym_m, sym_n), + stride_order=(1, 0), + assumed_align=16, + ) + # SF tensors: pointers (complex 6D layout, not expressible as torch tensor) + a_sf_ptr = make_ptr( + cutlass.Float8E4M3FN, 16, cute.AddressSpace.gmem, 16 + ) + b_sf_ptr = make_ptr( + cutlass.Float8E4M3FN, 16, cute.AddressSpace.gmem, 16 + ) + # Alpha: 1-dim tensor, torch tensor at runtime + alpha_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, (1,), assumed_align=4 + ) + + from flashinfer.cute_dsl.utils import get_max_active_clusters + + max_active_clusters = get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Fake stream: auto uses current CUDA stream at runtime + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_gemm = cute.compile( + gemm.wrapper, + a_fake, + b_fake, + c_fake, + sf_m, + sf_n, + sf_k, + batch_size, + a_sf_ptr, + b_sf_ptr, + alpha_fake, + max_active_clusters, + stream_fake, + swap_ab, + options="--opt-level 2 --enable-tvm-ffi", + ) + + _CUTE_DSL_MM_FP4_KERNEL_CACHE[cache_key] = ( + compiled_gemm, + max_active_clusters, + ) + + compiled_gemm, max_active_clusters = _CUTE_DSL_MM_FP4_KERNEL_CACHE[ + cache_key + ] + + # Handle output tensor for swap_ab + if swap_ab: + launch_out = out.T + else: + launch_out = out + + # Prepare alpha: ensure it is always a 1-dim tensor with shape [1]. + if alpha_tensor is None: + alpha_for_launch = torch.tensor( + [1.0], dtype=torch.float32, device=a.device + ) + elif alpha_tensor.dim() == 0: + alpha_for_launch = alpha_tensor.unsqueeze(0) + else: + alpha_for_launch = alpha_tensor.reshape(1) + + # Launch via TVM-FFI: + # - A, B, C, alpha: torch.Tensor directly (C-level dlpack, negligible cost) + # - SF pointers: data_ptr() ints (complex 6D layout) + # - Stream: automatic (fake_stream) + compiled_gemm( + kernel_a, + kernel_b, + launch_out, + sf_m, + sf_n, + sf_k, + kernel_a_sf.data_ptr(), + kernel_b_sf.data_ptr(), + alpha_for_launch, + ) + + return out + + return CuteDSLFp4GemmRunner() + + def _heuristic_func_mm_fp4( suitable_backends: List[str], a: torch.Tensor, @@ -3131,8 +3623,9 @@ def _heuristic_func_mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "cudnn", + backend: Literal["cudnn", "trtllm", "cutlass", "cute-dsl", "auto"] = "cudnn", use_nvfp4: bool = True, + enable_pdl: bool = True, # unused ): r""" Heuristic function for mm_fp4 backend selection. Routes to either cudnn or cutlass. 
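
For orientation, a minimal call sketch for the new `cute-dsl` backend choice added by this diff. Keyword names mirror the requirement-function signatures shown earlier; producing the packed FP4 operands and 128x4-layout FP8 scales is assumed to use FlashInfer's nvfp4_quantize with `do_shuffle=False`, as the comments above state, and is therefore only indicated in the docstring rather than spelled out.

```python
import torch
import flashinfer

def mm_fp4_cute_dsl(a_fp4, b_fp4, a_sf, b_sf, alpha):
    """Sketch: a_fp4 (m, k//2) and b_fp4 (k//2, n) are packed FP4 stored as uint8,
    a_sf/b_sf are FP8 block scales in the 128x4 layout (do_shuffle=False),
    e.g. produced with nvfp4_quantize per the comments in this diff."""
    return flashinfer.gemm.mm_fp4(
        a=a_fp4,
        b=b_fp4,                 # passed as in the cudnn/cutlass path (not the trtllm layout)
        a_descale=a_sf,
        b_descale=b_sf,
        alpha=alpha,             # optional scalar FP32 tensor
        out_dtype=torch.bfloat16,
        backend="cute-dsl",      # new backend choice added by this diff
        use_nvfp4=True,          # cute-dsl supports nvfp4 only
        enable_pdl=True,         # only consumed by the cute-dsl backend
    )
```
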
@@ -3261,6 +3754,7 @@ def _mxfp8_swizzled_scale_len(m: int, k: int) -> int: "cudnn": _cudnn_gemm_fp4_requirement, "trtllm": _trtllm_gemm_fp4_requirement, "cutlass": _cutlass_gemm_fp4_requirement, + "cute-dsl": _cute_dsl_gemm_fp4_requirement, }, common_check=_check_mm_fp4_problem_size, heuristic_func=_heuristic_func_mm_fp4, # result stored in mm_fp4.suitable_auto_backends @@ -3276,8 +3770,9 @@ def mm_fp4( out: Optional[torch.Tensor] = None, block_size: int = 16, use_8x4_sf_layout: bool = False, - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto", + backend: Literal["cudnn", "trtllm", "cutlass", "cute-dsl", "auto"] = "auto", use_nvfp4: bool = True, + enable_pdl: bool = True, ) -> torch.Tensor: r"""MM FP4 @@ -3310,20 +3805,27 @@ def mm_fp4( use_8x4_sf_layout: bool Whether to use 8x4 scale factor layout or 128x4 scale factor layout, defaults to False. - backend: Literal["cudnn", "trtllm", "cutlass", "auto"] + backend: Literal["cudnn", "trtllm", "cutlass", "cute-dsl", "auto"] Backend to use, defaults to ``"auto"``, which automatically selects the best backend between ``"cudnn"`` and ``"cutlass"`` based on the current CUDA and - cuDNN versions. The ``"trtllm"`` backend is never selected when - ``backend="auto"`` because it requires different weight preparation. + cuDNN versions. The ``"trtllm"`` and ``"cute-dsl"`` backends are never selected + when ``backend="auto"`` because they require different weight preparation. use_nvfp4: bool Whether to use nvfp4 quantization or mxfp4 quantization, defaults to ``True``. See the ``block_size`` parameter for related constraints. + enable_pdl: bool + Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl`` + backend, defaults to ``True``. PDL allows overlapping the tail of one kernel + with the start of the next for reduced launch latency. This parameter is + only used by the ``cute_dsl`` backend and is ignored by other backends. + Notes ----- When cudnn/cutlass backend is used, both a and b should quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False. When trtllm backend is used, b must be quantized with 128x4 layout and `do_shuffle=True`. a can be quantized with either 128x4 or 8x4 layout (controlled by `use_8x4_sf_layout`) and `do_shuffle=False`. + When cute_dsl backend is used, both a and b should be quantized with nvfp4_quantize using the 128x4 scale factor layout and do_shuffle=False (same as cudnn/cutlass). Returns ------- @@ -3375,6 +3877,9 @@ def mm_fp4( "cutlass": lambda: get_cutlass_fp4_gemm_module( major, minor ).cutlass_fp4_gemm_runner(), + "cute-dsl": lambda: _cute_dsl_gemm_fp4_runner( + major, minor, enable_pdl, out_dtype + ), } runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py new file mode 100644 index 0000000000..87c1325b69 --- /dev/null +++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm100.py @@ -0,0 +1,2191 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This file is ported from TensorRT-LLM's dense_blockscaled_gemm_persistent.py +# with modifications for FlashInfer integration. +# Original: https://github.com/NVIDIA/TensorRT-LLM + +from typing import Optional, Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 + +from cutlass.cute.arch import griddepcontrol_launch_dependents, griddepcontrol_wait +from cutlass.pipeline import PipelineTmaUmma, PipelineUmmaAsync + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """Implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and Blackwell GPU architectural features, including persistent tile scheduling and warp specialization. + + Notes: + - In the current version, tensors A and B must have the same data type. + For example, using Float8E4M3FN for A and Float8E5M2 for B is not supported. + - Supported combinations of A/B data types, scale factor (SF) data types, and SF vector size: + * MXF8: A/B: Float8E5M2/Float8E4M3FN, SF: Float8E8M0FNU, sf_vec_size: 32 + * MXF4: A/B: Float4E2M1FN, SF: Float8E8M0FNU, sf_vec_size: 32 + * NVF4: A/B: Float4E2M1FN, SF: Float8E8M0FNU/Float8E4M3FN, sf_vec_size: 16 + - Supported accumulator data types: + * Float32 + - Supported C data types: + * Float32 + * Float16/BFloat16 + * Float8E4M3FN/Float8E5M2 + - Constraints: + * MMA tiler M must be 128 or 256 (use_2cta_instrs). + * MMA tiler N must be 64/128/192/256 + * Cluster shape M must be a multiple of 2 if MMA tiler M is 256. + * Cluster shape M/N must be positive and a power of 2, with total cluster size <= 16. + * Cluster shape M/N must be <= 4 for scale factor multicasts due to limited scale factor size. + + Example: + >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 128), + ... cluster_shape_mn=(2, 1) + ... 
) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_prefetch: bool = False, + enable_pdl: bool = True, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + Args: + sf_vec_size (int): Scalefactor vector size. + mma_tiler_mn (Tuple[int, int]): Shape of the Matrix Multiply-Accumulate (MMA) tile (M, N). + cluster_shape_mn (Tuple[int, int]): Cluster dimensions (M, N) for parallel processing. + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_prefetch = use_prefetch + self.enable_pdl = enable_pdl + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Computing tensor memory allocation columns + """ + # Compute mma instruction shapes + mma_inst_bits_k = 256 + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mnk = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_bits_k // self.a_dtype.width, + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mnk_sfb = ( + self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mnk[1], 128), + self.mma_inst_shape_mnk[2], + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_tile_k = 4 + self.mma_tiler = ( + 
self.mma_inst_shape_mnk[0], + self.mma_inst_shape_mnk[1], + self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mnk_sfb[0], + self.mma_inst_shape_mnk_sfb[1], + self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + self.epi_tile_n = cute.size(self.epi_tile[1]) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.a_major_mode, + self.b_dtype, + self.b_major_mode, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + self.overlapping_accum = self.num_acc_stage == 1 + sf_atom_mn = 32 + self.num_sfa_tmem_cols = ( + self.cta_tile_shape_mnk[0] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sfb_tmem_cols = ( + self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage + if not self.overlapping_accum + else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + # Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ( + self.num_sf_tmem_cols // self.epi_tile_n + ) + + # TODO: [alel] Currently set prefetch dist to num_ab_stage, we may have more options for prefetch dist auto tuning + self.prefetch_dist = self.num_ab_stage + + @cute.jit + def __call__( + self, + 
a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + alpha: cute.Tensor, # Single-element tensor containing alpha value + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + Args: + a_tensor (cute.Tensor): Input tensor A + b_tensor (cute.Tensor): Input tensor B + sfa_tensor (cute.Tensor): Scale factor tensor A + sfb_tensor (cute.Tensor): Scale factor tensor B + c_tensor (cute.Tensor): Output tensor C + max_active_clusters (cutlass.Constexpr): Maximum number of active clusters + stream (cuda.CUstream): CUDA stream for asynchronous execution + epilogue_op (cutlass.Constexpr): Optional elementwise lambda function to apply to the output tensor + + Raises: + TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type + self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, self.sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, self.sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, 
tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor( + tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: 
cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + alpha, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), # type: ignore[attr-defined] + min_blocks_per_mp=1, + stream=stream, + use_pdl=self.enable_pdl, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + alpha: cute.Tensor, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + # Keep alpha in FP32 for precision: the accumulator is in FP32 and alpha + # may be a very small scaling factor. Converting to c_dtype (e.g., FP16) + # before multiplication could cause overflow when acc values are large. 
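
A small numeric illustration of the ordering argument made in the comment above (standalone torch, not part of the kernel): FP16 tops out at 65504, so converting a large FP32 accumulator value to the output dtype before applying a small alpha overflows, while scaling in FP32 first stays in range.

```python
import torch

acc = torch.tensor(1.0e5, dtype=torch.float32)    # large FP32 accumulator value
alpha = torch.tensor(1.0e-3, dtype=torch.float32)

bad = acc.to(torch.float16) * alpha.to(torch.float16)  # inf: 1e5 > 65504 overflows on conversion
good = (alpha * acc).to(torch.float16)                 # 100.0: scaled in FP32, then converted

print(bad.item(), good.item())
```
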
+ alpha_value = alpha[0].to(cutlass.Float32) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed(aligned=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) 
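
For readers following the warp-specialized sections below, this is the thread-to-role split configured in `__init__` (192 threads per CTA: warps 0-3 epilogue, warp 4 MMA, warp 5 TMA load). The helper is a standalone illustration only, not part of the kernel.

```python
EPILOG_WARP_IDS = (0, 1, 2, 3)
MMA_WARP_ID = 4
TMA_WARP_ID = 5
THREADS_PER_CTA = 32 * 6   # 192

def warp_role(tidx: int) -> str:
    """Return the role of the warp owning thread index `tidx` (illustrative helper)."""
    warp_idx = tidx // 32
    if warp_idx in EPILOG_WARP_IDS:
        return "epilogue"
    if warp_idx == MMA_WARP_ID:
        return "mma"
    return "tma_load"

assert [warp_role(w * 32) for w in range(6)] == ["epilogue"] * 4 + ["mma", "tma_load"]
```
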
+ + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_block_cnt = cutlass.Int32(cute.size(gA_mkl, mode=[3])) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), 
STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, num_acc_stage_overlapped) + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # PDL bookend: griddepcontrol_wait is always emitted; the actual PDL + # behavior is controlled by use_pdl= in .launch(). The griddepcontrol + # instructions are effectively no-ops when PDL is not enabled at + # the driver level. + griddepcontrol_wait() + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + if cutlass.const_expr(self.use_prefetch): + # Prefetch both A and B (default behavior) + for pf_k_block in cutlass.range( + 0, min(self.prefetch_dist, k_block_cnt), unroll=1 + ): + cute.prefetch( + tma_atom_a, + tAgA_slice[(None, pf_k_block)], + ) + cute.prefetch( + tma_atom_b, + tBgB_slice[(None, pf_k_block)], + ) + cute.prefetch( + tma_atom_sfa, + tAgSFA_slice[(None, pf_k_block)], + ) + cute.prefetch( + tma_atom_sfb, + tBgSFB_slice[(None, pf_k_block)], + ) + + # Peek (try_wait) AB buffer empty for k_block = 
prefetch_k_block_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_block in cutlass.range(0, k_block_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Prefetch: Rolling prefetch for next tiles + if cutlass.const_expr(self.use_prefetch): + if k_block < k_block_cnt - self.prefetch_dist: + future_k_block = ( + ab_producer_state.count + self.prefetch_dist + ) + # Prefetch both A and B (default behavior) + cute.prefetch( + tma_atom_a, + tAgA_slice[(None, future_k_block)], + ) + cute.prefetch( + tma_atom_b, + tBgB_slice[(None, future_k_block)], + ) + cute.prefetch( + tma_atom_sfa, + tAgSFA_slice[(None, future_k_block)], + ) + cute.prefetch( + tma_atom_sfb, + tBgSFB_slice[(None, future_k_block)], + ) + + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_block_cnt + k_block + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_block_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + 
self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_producer_state.phase ^ 1 + else: + acc_stage_index = acc_producer_state.index + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # Peek (try_wait) AB buffer full for k_block = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = ( + cutlass.Int32(2) + if mma_tile_coord_mnl[1] % 2 == 1 + else cutlass.Int32(0) + ) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for _k_block in range(k_block_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + 
cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kphases = cute.size(tCrA, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord = ( + None, + None, + kphase_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kphase_coord = (None, None, kphase_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kphase_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kphase_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kphase_coord], + tCrB[kphase_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kphase + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_block = k_block + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_block_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile 
scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = ( + cutlass.Boolean(True) + if acc_stage_index == 0 + else cutlass.Boolean(False) + ) + else: + acc_stage_index = acc_consumer_state.index + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_stage_index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in sub-tiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + + for subtile_idx in cutlass.range(subtile_cnt): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = ( + self.cta_tile_shape_mnk[1] // self.epi_tile_n + - 1 + - subtile_idx + ) + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Async arrive accumulator buffer empty earlier when overlapping_accum is enabled + # + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + # Multiply alpha in FP32 before converting to c_dtype to + # avoid overflow when c_dtype is FP16 and acc values are large. 
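+                    # epilogue_op (identity by default) is applied after the cast to c_dtype.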
+ acc_vec = epilogue_op((alpha_value * acc_vec).to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, real_subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # Async arrive accumulator buffer empty + # + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + griddepcontrol_launch_dependents() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). 
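+        The copy uses the tcgen05 Cp4x32x128bOp atom for the configured CTA group.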
+ + Args: + sSF (cute.Tensor): The scale factor tensor in smem + tSF (cute.Tensor): The scale factor tensor in tmem + + Returns: + A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + Args: + tidx (cutlass.Int32): The thread index in epilogue warp groups + tAcc (cute.Tensor): The accumulator tensor to be copied and partitioned + gC_mnl (cute.Tensor): The global tensor C + epi_tile (cute.Tile): The epilogue tiler + use_2cta_instrs (bool): Whether use_2cta_instrs is enabled + + Returns: + A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to 
partition register array (source) and shared memory (destination). + + Args: + tiled_copy_t2r (cute.TiledCopy): The tiled copy operation for tmem to register copy(t2r) + tTR_rC (cute.Tensor): The partitioned accumulator tensor + tidx (cutlass.Int32): The thread index in epilogue warp groups + sC (cute.Tensor): The shared memory tensor to be copied and partitioned + sepi (cute.Tensor): + + Returns: + A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + Args: + tidx (cutlass.Int32): The thread index in epilogue warp groups + atom (Union[cute.CopyAtom, cute.TiledCopy]): The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + gC_mnl (cute.Tensor): The global tensor C + epi_tile (cute.Tile): The epilogue tiler + sC (cute.Tensor): The shared memory tensor to be copied and partitioned + + Returns: + A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + a_major_mode: tcgen05.OperandMajorMode, + b_dtype: Type[cutlass.Numeric], + b_major_mode: tcgen05.OperandMajorMode, + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the optimal number of pipeline stages for operands. + + This method uses heuristics to determine the number of stages for the + accumulator (ACC), A/B operands, and C operand based on the kernel + configuration and available shared memory. + + Args: + tiled_mma (cute.TiledMma): The tiled MMA object. + mma_tiler_mnk (Tuple[int, int, int]): The shape (M, N, K) of the + MMA tiler. + a_dtype (Type[cutlass.Numeric]): Data type of operand A. 
+ a_major_mode (tcgen05.OperandMajorMode): Layout of operand A. + b_dtype (Type[cutlass.Numeric]): Data type of operand B. + b_major_mode (tcgen05.OperandMajorMode): Layout of operand B. + epi_tile (cute.Tile): The epilogue tile shape. + c_dtype (Type[cutlass.Numeric]): Data type of operand C. + c_layout (utils.LayoutEnum): Layout of operand C. + sf_dtype (Type[cutlass.Numeric]): Data type of the scale factors. + sf_vec_size (int): Vector size of the scale factors. + smem_capacity (int): Total available shared memory in bytes. + occupancy (int): Target number of CTAs per SM. + + Returns: + Tuple[int, int, int]: A tuple containing the number of stages for + (ACC, A/B, C). + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Computes the grid size for the kernel launch. + + This method uses a persistent tile scheduler to determine the grid + dimensions based on the output tensor shape and hardware constraints. + + Args: + c (cute.Tensor): The output tensor C. + cta_tile_shape_mnk (Tuple[int, int, int]): The shape (M, N, K) of the + CTA tile. + cluster_shape_mn (Tuple[int, int]): The shape of a CTA cluster. + max_active_clusters (cutlass.Constexpr): The maximum number of + active clusters supported by the hardware. + + Returns: + A tuple containing the tile scheduler parameters and the computed + grid shape. 
+ """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """Checks if the data types and scale factor vector size are a valid combination. + + Args: + ab_dtype (Type[cutlass.Numeric]): Data type of operands A and B. + sf_dtype (Type[cutlass.Numeric]): Data type of the scale factors. + sf_vec_size (int): Vector size of the scale factors. + c_dtype (Type[cutlass.Numeric]): Data type of the output tensor C. + + Returns: + bool: True if the combination is valid, False otherwise. + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """Checks if the tensor layouts are valid for the given data types. + + Args: + ab_dtype (Type[cutlass.Numeric]): Data type of operands A and B. + c_dtype (Type[cutlass.Numeric]): Data type of the output tensor C. + a_major (str): The major layout of tensor A ('k' or 'm'). + b_major (str): The major layout of tensor B ('k' or 'n'). + c_major (str): The major layout of tensor C ('n' or 'm'). + + Returns: + bool: True if the layouts are valid, False otherwise. + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """Checks if the MMA tiler and cluster shape are a valid combination. + + Args: + mma_tiler_mn (Tuple[int, int]): The (M, N) shape of the MMA tiler. + cluster_shape_mn (Tuple[int, int]): The (M, N) shape of the CTA cluster. + + Returns: + bool: True if the combination is valid, False otherwise. 
+ """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [64, 128, 192, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + _is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not _is_power_of_2(cluster_shape_mn[0]) + or not _is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: cutlass.Int64, + n: cutlass.Int64, + k: cutlass.Int64, + l: cutlass.Int64, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """Checks if the tensor dimensions are valid for memory alignment. + + Args: + m (cutlass.Int64): The M dimension of the GEMM problem. + n (cutlass.Int64): The N dimension of the GEMM problem. + k (cutlass.Int64): The K dimension of the GEMM problem. + l (cutlass.Int64): The batch dimension (L) of the GEMM problem. + ab_dtype (Type[cutlass.Numeric]): Data type of operands A and B. + c_dtype (Type[cutlass.Numeric]): Data type of the output tensor C. + a_major (str): The major layout of tensor A ('k' or 'm'). + b_major (str): The major layout of tensor B ('k' or 'n'). + c_major (str): The major layout of tensor C ('n' or 'm'). + + Returns: + bool: True if the tensor alignment is valid, False otherwise. + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @classmethod + def can_implement( + cls, + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: cutlass.Int64, + n: cutlass.Int64, + k: cutlass.Int64, + l: cutlass.Int64, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """Checks if the kernel can implement the given GEMM problem. + + This method aggregates checks for data types, layouts, tile shapes, + and alignment to determine if the configuration is supported. + + Args: + ab_dtype (Type[cutlass.Numeric]): Data type of operands A and B. + sf_dtype (Type[cutlass.Numeric]): Data type of the scale factors. + sf_vec_size (int): Vector size of the scale factors. + c_dtype (Type[cutlass.Numeric]): Data type of the output tensor C. + mma_tiler_mn (Tuple[int, int]): The (M, N) shape of the MMA tiler. + cluster_shape_mn (Tuple[int, int]): The (M, N) shape of the CTA + cluster. + m (cutlass.Int64): The M dimension of the GEMM problem. + n (cutlass.Int64): The N dimension of the GEMM problem. 
+ k (cutlass.Int64): The K dimension of the GEMM problem. + l (cutlass.Int64): The batch dimension (L) of the GEMM problem. + a_major (str): The major layout of tensor A ('k' or 'm'). + b_major (str): The major layout of tensor B ('k' or 'n'). + c_major (str): The major layout of tensor C ('n' or 'm'). + + Returns: + bool: True if the configuration is supported, False otherwise. + """ + can_implement = True + # Skip unsupported types + if not cls.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn, cluster_shape_mn): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not cls.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + # fully dynamic shape + @cute.jit + def wrapper( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sf_m: cutlass.Int64, + sf_n: cutlass.Int64, + sf_k: cutlass.Int64, + l: cutlass.Constexpr, + a_sf_ptr: cute.Pointer, + b_sf_ptr: cute.Pointer, + alpha_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + current_stream, + swap_ab: cutlass.Constexpr = False, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Executes the wrapped GEMM kernel with dynamically shaped tensors. + + Uses TVM-FFI for efficient tensor passing: A, B, C, and alpha are passed + as cute.Tensor directly (torch tensors at runtime via TVM-FFI's C-level + dlpack, with negligible conversion cost). Scale factor tensors remain as + pointers because their 6D BlockScaledBasicChunk layout cannot be expressed + as a torch tensor. + + Args: + mA (cute.Tensor): Input tensor A, shape (m, k_packed), K-major. + mB (cute.Tensor): Input tensor B, shape (n, k_packed), K-major. + mC (cute.Tensor): Output tensor C, shape (m, n). + sf_m (cutlass.Int64): Scale factor M dim (ceil(m/128)). + sf_n (cutlass.Int64): Scale factor N dim (ceil(n/128)). + sf_k (cutlass.Int64): Scale factor K dim (ceil(k/sf_vec_size/4)). + l (cutlass.Constexpr): Batch dimension (L). + a_sf_ptr (cute.Pointer): Pointer to scale factor tensor for A. + b_sf_ptr (cute.Pointer): Pointer to scale factor tensor for B. + alpha_tensor (cute.Tensor): Alpha scaling factor, shape (1,), float32. + max_active_clusters (cutlass.Constexpr): Max active clusters. + current_stream: CUDA stream (managed by TVM-FFI fake stream). + swap_ab (cutlass.Constexpr): Whether A/B are swapped (controls C layout). + epilogue_op (cutlass.Constexpr): Elementwise epilogue function. + """ + # A, B, C are passed as cute.Tensor via TVM-FFI. + # A/B come in as Uint8 (FP4 packed as uint8 in torch). Recast to FP4. 
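+        # The scale-factor extents sf_m/sf_n/sf_k are precomputed on the host and
+        # passed in directly rather than re-derived from mA/mB here.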
+ m = cute.size(mA, mode=[0]) + k_packed = cute.size(mA, mode=[1]) + n = cute.size(mB, mode=[0]) + # k in FP4 elements = k_packed * 2 (2 FP4 values per uint8 byte) + k = k_packed * 2 + + # Recast Uint8 → Float4E2M1FN and reshape to (m, k, l) with K-major order + a_fp4_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN) + a_tensor = cute.make_tensor( + a_fp4_ptr, + layout=cute.make_ordered_layout((m, k, l), order=(1, 0, 2)), + ) + # Recast B and reshape to (n, k, l) with K-major order + b_fp4_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN) + b_tensor = cute.make_tensor( + b_fp4_ptr, + layout=cute.make_ordered_layout( + (n, k, l), + order=(1, 0, 2), + ), + ) + # Reshape C to (m, n, l) -- swap_ab is constexpr, determines layout at compile time + if cutlass.const_expr(swap_ab): + c_tensor = cute.make_tensor( + mC.iterator, + layout=cute.make_ordered_layout((m, n, l), order=(0, 1, 2)), + ) + else: + c_tensor = cute.make_tensor( + mC.iterator, + layout=cute.make_ordered_layout((m, n, l), order=(1, 0, 2)), + ) + # Scale factor tensors: 6D BlockScaledBasicChunk layout from pointers + # (32, 4, sf_m, 4, sf_k, l) with order (2, 1, 4, 0, 3, 5) + sfa_tensor = cute.make_tensor( + a_sf_ptr, + layout=cute.make_ordered_layout( + (32, 4, sf_m, 4, sf_k, l), + order=(2, 1, 4, 0, 3, 5), + ), + ) + sfb_tensor = cute.make_tensor( + b_sf_ptr, + layout=cute.make_ordered_layout( + (32, 4, sf_n, 4, sf_k, l), + order=(2, 1, 4, 0, 3, 5), + ), + ) + + self( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + alpha_tensor, + max_active_clusters, + current_stream, + epilogue_op, + ) diff --git a/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py new file mode 100644 index 0000000000..638aa1922c --- /dev/null +++ b/flashinfer/gemm/kernels/dense_blockscaled_gemm_sm103.py @@ -0,0 +1,2568 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# This file is ported from CUTLASS's sm103_dense_blockscaled_gemm_persistent.py +# with modifications for FlashInfer integration (alpha scaling, PDL support, wrapper method). +# Original: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py + +from typing import Type, Tuple, Union +from dataclasses import dataclass, field + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm103_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.arch import griddepcontrol_launch_dependents, griddepcontrol_wait + + +class Sm103BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for FP4 data types + and architectural features specific to Blackwell SM103 GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + + :note: In current version, A and B tensor must have the same data type + - i.e., Float4E2M1FN for A and Float4E2M1FN for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> gemm = Sm103BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 256), + ... cluster_shape_mn=(2, 4) + ... ) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + enable_pdl: bool = True, + ): + """Initializes the configuration for a Blackwell SM103 3xFP4 GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. 
+ :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + :param enable_pdl: Whether Programmatic Dependent Launch is enabled. + :type enable_pdl: bool + """ + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + self.enable_pdl = enable_pdl + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilogue_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_ab_warp_id = 5 + self.tma_sf_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_ab_warp_id, + self.tma_sf_warp_id, + *self.epilogue_warp_id, + ) + ) + # Set barrier id for epilogue sync and tmem ptr sync + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_103") + # SM103 TMEM capacity is 512 columns (same as SM100). + # This replaces cute.arch.get_max_tmem_alloc_cols("sm_103") which + # may not be available in older cutlass-dsl versions. + SM103_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM103_TMEM_CAPACITY_COLUMNS + self.sf_buffers_per_tile_k = 4 if self.sf_vec_size == 16 else 2 + + def _setup_attributes(self): + """Set up kernel attributes that depend on runtime tensor inputs. + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = self.sm103_make_blockscaled_trivial_tiled_mma( + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + dummy_tiled_mma_sfb = self.sm103_make_blockscaled_trivial_tiled_mma( + self.sf_dtype, + self.sf_vec_size, + tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + 768, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_layout_vmnk.shape[0]), + self.mma_tiler[1], + self.mma_tiler[2], + ) + blk_mn = 128 + self.cta_n_sf = cute.round_up(cute.size(self.cta_tile_shape_mnk[1]), blk_mn) + self.mma_sf_tiler = ( + self.cta_tile_shape_mnk[0], + self.cta_n_sf, + self.cta_tile_shape_mnk[2] // self.sf_buffers_per_tile_k, + ) + + self.sf_atom = self.Sm103BlockScaledBasicChunk( + self.sf_vec_size, tiled_mma.op.a_major_mode + ).layout + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = 
cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (dummy_tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = (self.cta_tile_shape_mnk[0], 64) + + self.num_acc_stage, self.num_ab_stage, self.num_sf_stage, self.num_c_stage = ( + self._compute_stages( + tiled_mma, + self.mma_tiler, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + ) + + # Compute A/B/SFA/SFB/C shared memory layout + # ((CTA_MMA_M,16bytes),1,8,num_ab_stage) + self.a_smem_layout_staged = self.sm103_make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.num_ab_stage, + ) + + # ((CTA_MMA_M,16bytes),1,8,3) + self.a_smem_layout_staged_tma = self.sm103_make_smem_layout_a( + tiled_mma, + self.mma_tiler, + 3, + ) + + # ((CTA_MMA_N,16bytes),1,8,num_ab_stage) + self.b_smem_layout_staged = self.sm103_make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.num_ab_stage, + ) + + # ((CTA_MMA_N,16bytes),1,8,3) + self.b_smem_layout_staged_tma = self.sm103_make_smem_layout_b( + tiled_mma, + self.mma_tiler, + 3, + ) + + # (((8,4,4),(sf_vec_size,4)),1,3,num_sf_stage) + self.sfa_smem_layout_staged = self.sm103_make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_sf_stage, + ) + + # (((32,4,2),(sf_vec_size,4)),1,3,num_sf_stage) + self.sfb_smem_layout_staged = self.sm103_make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_sf_stage, + ) + self.c_smem_layout_staged = None + if self.use_tma_store: + self.c_smem_layout_staged = sm103_utils.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage + ) + + @cute.jit + def __call__( + self, + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + alpha: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a_tensor: Input tensor A + :type a_tensor: cute.Tensor + :param b_tensor: Input tensor B + :type b_tensor: cute.Tensor + :param sfa_tensor: Scale factor tensor A + :type sfa_tensor: cute.Tensor + :param sfb_tensor: Scale factor tensor B + :type sfb_tensor: cute.Tensor + :param c_tensor: Output tensor C + :type c_tensor: cute.Tensor + :param alpha: Single-element tensor containing alpha scaling value + :type alpha: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA 
instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type + self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + sfa_layout = cute.tile_to_shape(self.sf_atom, a_tensor.shape, (2, 1, 3)) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + + sfb_layout = cute.tile_to_shape(self.sf_atom, b_tensor.shape, (2, 1, 3)) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + + tiled_mma = self.sm103_make_blockscaled_trivial_tiled_mma( + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + dummy_tiled_mma_sfb = self.sm103_make_blockscaled_trivial_tiled_mma( + self.sf_dtype, + self.sf_vec_size, + tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm103_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + # casting layout as uint8 for multicast + a_smem_layout_tma_ready = self.adapt_layout_for_tma_ab( + self.a_smem_layout_staged_tma + ) + a_tensor_uint8 = cute.recast_tensor(a_tensor, cutlass.Uint8) + tma_atom_a, tma_tensor_a = cute.nvgpu.cpasync.make_tiled_tma_atom( + a_op, + a_tensor_uint8, + a_smem_layout_tma_ready, + # 384 corresponds to the number of uint8 elements along the K dimension processed in a single MMA mainloop iteration. 
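+            # (384 bytes = 768 FP4 values = the K extent of mma_tiler, at 2 FP4 per byte)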
+ (cute.size(tiled_mma.tv_layout_A[1][0]), 384), + self.cluster_shape_mn[1], + internal_type=cutlass.Uint8, + ) + + # Setup TMA load for B + b_op = sm103_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + # casting layout as uint8 for multicast + b_smem_layout_tma_ready = self.adapt_layout_for_tma_ab( + self.b_smem_layout_staged_tma + ) + b_tensor_uint8 = cute.recast_tensor(b_tensor, cutlass.Uint8) + tma_atom_b, tma_tensor_b = cute.nvgpu.cpasync.make_tiled_tma_atom( + b_op, + b_tensor_uint8, + b_smem_layout_tma_ready, + (cute.size(tiled_mma.tv_layout_B[1][0]), 384), + self.cluster_shape_mn[0] // cute.size(tiled_mma.thr_id.shape), + internal_type=cutlass.Uint8, + ) + + # Setup TMA load for SFA + sfa_op = sm103_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + sfa_smem_layout_tma_ready = self.adapt_layout_for_tma_sf(sfa_smem_layout) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.cpasync.make_tiled_tma_atom( + sfa_op, + sfa_tensor, + sfa_smem_layout_tma_ready, + (self.mma_sf_tiler[0], self.mma_sf_tiler[2]), + self.cluster_shape_mn[1], + internal_type=cutlass.Uint8, + ) + + # Setup TMA load for SFB + sfb_op = sm103_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + sfb_smem_layout_tma_ready = self.adapt_layout_for_tma_sf(sfb_smem_layout) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.cpasync.make_tiled_tma_atom( + sfb_op, + sfb_tensor, + sfb_smem_layout_tma_ready, + (self.mma_sf_tiler[1], self.mma_sf_tiler[2]), + self.cluster_shape_mn[0] // cute.size(dummy_tiled_mma_sfb.thr_id), + internal_type=cutlass.Uint8, + ) + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + a_copy_size = cute.size_in_bytes( + cutlass.Uint8, + cute.slice_(self.a_smem_layout_staged_tma, (None, None, None, 0)), + ) + b_copy_size = cute.size_in_bytes( + cutlass.Uint8, + cute.slice_(self.b_smem_layout_staged_tma, (None, None, None, 0)), + ) + sfa_copy_size = cute.size_in_bytes( + cutlass.Uint8, + cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)), + ) + sfb_copy_size = cute.size_in_bytes( + cutlass.Uint8, + cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)), + ) + self.num_tma_load_bytes_ab = (a_copy_size + b_copy_size) * atom_thr_size + self.num_tma_load_bytes_sf = (sfa_copy_size + sfb_copy_size) * atom_thr_size + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + sf_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_sf_stage] + sf_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_sf_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + 
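+        # TMEM bookkeeping: mbarrier for the two-CTA dealloc handshake and a slot
+        # holding the allocated TMEM base address.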
tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + cutlass.Uint8, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c_tensor, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + alpha, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.enable_pdl, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + alpha: cute.Tensor, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + # Keep alpha in FP32 for precision: the accumulator is in FP32 and alpha + # may be a very small scaling factor. Converting to c_dtype (e.g., FP16) + # before multiplication could cause overflow when acc values are large. 
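+        # alpha is a 1-element runtime tensor (not a constexpr), so its value can
+        # change between launches without recompiling the kernel.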
+ alpha_value = alpha[0].to(cutlass.Float32) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_ab_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + if warp_idx == self.tma_sf_warp_id: + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, sfa+sfb full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_producer and ab_consumer + ab_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_producer_group, + consumer_group=ab_consumer_group, + tx_count=self.num_tma_load_bytes_ab, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # Initialize mainloop sf_producer and sf_consumer + sf_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_sf_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + sf_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_sf_tma_producer + ) + sf_producer, sf_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.sf_full_mbar_ptr.data_ptr(), + num_stages=self.num_sf_stage, + producer_group=sf_producer_group, + consumer_group=sf_consumer_group, + tx_count=self.num_tma_load_bytes_sf, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilogue_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + tmem_dealloc_barrier = None + if cutlass.const_expr(not self.use_tma_store): + tmem_dealloc_barrier = pipeline.NamedBarrier( + 
barrier_id=self.tmem_dealloc_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (BLK_M, BLK_K, m, k, l) + gA_mkl = cute.local_tile( + mA_mkl, + cute.slice_((self.mma_tiler[0], self.mma_tiler[1], 384), (None, 0, None)), + (None, None, None), + ) + # (BLK_N, BLK_K, n, k, l) + gB_nkl = cute.local_tile( + mB_nkl, + cute.slice_((self.mma_tiler[0], self.mma_tiler[1], 384), (0, None, None)), + (None, None, None), + ) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.mma_sf_tiler, (None, 0, None)), + (None, None, None), + ) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_sf_tiler, (0, None, None)), + (None, None, None), + ) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + + # create tCgA_tmp + tCgA_mkl_tmp = thr_mma.partition_A(gA_mkl) + tCgA_layout = self.append_coalesce_layout(tCgA_mkl_tmp.layout) + cta_tCgA = cute.make_tensor(tCgA_mkl_tmp.iterator, tCgA_layout) + # ((CTA_MMA_M,256),Rest_MMA_M,Rest_MMA_K, m, k, l) + tCgA = cute.make_tensor( + cta_tCgA.iterator, + cute.tiled_divide( + cta_tCgA.layout, (cute.size(tiled_mma.tv_layout_A[1][0]), 128) + ), + ) + + tCgB_nkl_tmp = thr_mma.partition_B(gB_nkl) + tCgB_layout = self.append_coalesce_layout(tCgB_nkl_tmp.layout) + cta_tCgB = cute.make_tensor(tCgB_nkl_tmp.iterator, tCgB_layout) + # ((CTA_MMA_N,256),Rest_MMA_N, Rest_MMA_K, n, k, l) + tCgB = cute.make_tensor( + cta_tCgB.iterator, + cute.tiled_divide( + cta_tCgB.layout, (cute.size(tiled_mma.tv_layout_B[1][0]), 128) + ), + ) + + tCgSFA = cute.make_tensor( + gSFA_mkl.iterator, + cute.tiled_divide( + gSFA_mkl.layout, (self.mma_sf_tiler[0], self.mma_sf_tiler[2]) + ), + ) + + tCgSFB = cute.make_tensor( + gSFB_nkl.iterator, + cute.tiled_divide( + gSFB_nkl.layout, (self.mma_sf_tiler[1], 
self.mma_sf_tiler[2]) + ), + ) + tCgC = thr_mma.partition_C(gC_mnl) + + # Create identity tensor for C to use in epilogue predication + idC = cute.make_identity_tensor(mC_mnl.shape) + cC_mnl = cute.local_tile( + idC, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCcC = thr_mma.partition_C(cC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 1), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 1), + ) + + # TMA partition for scale factor A + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA_compact = cute.filter_zeros(tAsSFA) + + # TMA partition for scale factor B + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB_compact = cute.filter_zeros(tBsSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # PDL bookend: always emitted, actual behavior controlled by use_pdl= in .launch() + griddepcontrol_wait() + + # + # Construct the scheduler + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # + # Specialized TMA load warp for A/B tensors + # + if warp_idx == self.tma_ab_warp_id: + # + # Persistent tile scheduling loop for AB loads + # + buffers_per_k_tile = 3 + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + tAgA_slice = tAgA[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + None, + mma_tile_coord_mnl[2], + ) + ] + tBgB_slice = tBgB[ + ( + None, + None, + None, + mma_tile_coord_mnl[1], + None, + mma_tile_coord_mnl[2], + ) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer.reset() + peek_ab_empty_status = cutlass.Boolean(1) + peek_ab_empty_status = ab_producer.try_acquire() + + # + # TMA load loop for A/B tensors + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Load buffers_per_k_tile buffers + for buffer in cutlass.range(buffers_per_k_tile, unroll_full=True): + # Acquire next empty AB buffer + 
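+                        # Blocks only when the try_acquire peek above did not
+                        # already observe an empty buffer.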
ab_empty = ab_producer.acquire_and_advance(peek_ab_empty_status) + + # TMA load A/B + cute.copy( + tma_atom_a, + cute.group_modes( + tAgA_slice[(None, None, buffer, k_tile)], 0, 2 + ), + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + cute.group_modes( + tBgB_slice[(None, None, buffer, k_tile)], 0, 2 + ), + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for next buffer + peek_ab_empty_status = cutlass.Boolean(1) + # Check if we're not at the last buffer of the last k_tile + if not ( + (k_tile == k_tile_cnt - 1) + and (buffer == buffers_per_k_tile - 1) + ): + peek_ab_empty_status = ab_producer.try_acquire() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Signal end of AB loads + ab_producer.tail() + + # + # Specialized TMA load warp for scale factor tensors + # + if warp_idx == self.tma_sf_warp_id: + # + # Persistent tile scheduling loop for SF loads + # + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0], + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) SF buffer empty + sf_producer.reset() + peek_sf_empty_status = cutlass.Boolean(1) + peek_sf_empty_status = sf_producer.try_acquire() + + # + # TMA load loop for scale factors + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Load SF stages based on sf_buffers_per_tile_k + for sf_stage in cutlass.range( + self.sf_buffers_per_tile_k, unroll_full=True + ): + # Acquire next empty SF buffer + sf_empty = sf_producer.acquire_and_advance(peek_sf_empty_status) + + tAgSFA_compact = cute.filter_zeros( + tAgSFA_slice[ + (None, k_tile * self.sf_buffers_per_tile_k + sf_stage) + ] + ) + tBgSFB_compact = cute.filter_zeros( + tBgSFB_slice[ + (None, k_tile * self.sf_buffers_per_tile_k + sf_stage) + ] + ) + + # TMA load SFA/SFB for this SF stage + cute.copy( + tma_atom_sfa, + tAgSFA_compact, + tAsSFA_compact[(None, sf_empty.index)], + tma_bar_ptr=sf_empty.barrier, + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_compact, + tBsSFB_compact[(None, sf_empty.index)], + tma_bar_ptr=sf_empty.barrier, + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) SF buffer empty for next stage + peek_sf_empty_status = cutlass.Boolean(1) + # Check if we're not at the last stage of the last k_tile + if not ( + k_tile == k_tile_cnt - 1 + and sf_stage == self.sf_buffers_per_tile_k - 1 + ): + peek_sf_empty_status = sf_producer.try_acquire() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Signal end of SF loads + sf_producer.tail() + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # Make accumulator tmem tensor + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = 
cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + + MMA_M = self.cta_tile_shape_mnk[0] + MMA_N_SF = self.cta_n_sf + MMA_K_SF = self.cta_tile_shape_mnk[2] // 2 + mnBasicBlockShape = (32, 4) + kBasicBlockShape_single = (self.sf_vec_size, 1) + mma_iter_SFA_shape = ( + (mnBasicBlockShape, MMA_M // 128), + kBasicBlockShape_single, + ) + sSFA_iter_shape = (mma_iter_SFA_shape, 1, MMA_K_SF // self.sf_vec_size) + sSFA_iter_layout = cute.make_layout(sSFA_iter_shape) + mma_iter_SFB_shape = ( + (mnBasicBlockShape, MMA_N_SF // 128), + kBasicBlockShape_single, + ) + sSFB_iter_shape = (mma_iter_SFB_shape, 1, MMA_K_SF // self.sf_vec_size) + sSFB_iter_layout = cute.make_layout(sSFB_iter_shape) + + tCtSFA_layout_mma = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, self.mma_tiler, self.sf_vec_size, sSFA_iter_layout + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + tCtSFA_mma = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout_mma) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB_layout_mma = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, self.mma_tiler, self.sf_vec_size, sSFB_iter_layout + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + tCtSFB_mma = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout_mma) + + # + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + MmasPerSfBuffer = 8 // self.sf_buffers_per_tile_k + sf_stride = 6 if self.sf_vec_size == 16 else 3 + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + tCtAcc = tCtAcc_base[(None, 0, 0, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if is_leader_cta: + peek_ab_full_status = ab_consumer.try_wait() + + # Peek (try_wait) SF buffer full + sf_consumer.reset() + peek_sf_full_status = cutlass.Boolean(1) + if is_leader_cta: + peek_sf_full_status = sf_consumer.try_wait() + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + is_first_iteration = True + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + if is_leader_cta: + # Conditionally load SFA/SFB for MMA0/MMA1 depending on sf_vec_size + if 0 % MmasPerSfBuffer == 0: + sf_full = sf_consumer.wait_and_advance(peek_sf_full_status) + s2t_stage_coord = ( + None, + None, + None, + None, + sf_full.index, + ) + cute.copy( + 
tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + sf_full.release() + peek_sf_full_status = cutlass.Boolean(1) + peek_sf_full_status = sf_consumer.try_wait() + + # Wait for A/B data to be ready(MMA0, MMA1, part of MMA2) + ab_full0 = ab_consumer.wait_and_advance(peek_ab_full_status) + + # peek for next stage (MMA2, MMA3, MMA4, part of MMA5) + peek_ab_full_status = cutlass.Boolean(1) + peek_ab_full_status = ab_consumer.try_wait() + + # delay the acc acquire to ublock tmem + if is_first_iteration: + acc_pipeline.producer_acquire(acc_producer_state) + is_first_iteration = False + + # MMA0 + k_block_coord_cur = (None, 0, 0, ab_full0.index) + k_block_coord_next = (None, 0, 0, ab_full0.index) + sf_kblock_coord = (None, None, 0 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # MMA1 + k_block_coord_cur = (None, 0, 3, ab_full0.index) + k_block_coord_next = (None, 0, 0, ab_full0.index) + sf_kblock_coord = (None, None, 1 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # Conditionally load SFA/SFB for MMA2/MMA3 + if 2 % MmasPerSfBuffer == 0: + sf_full = sf_consumer.wait_and_advance(peek_sf_full_status) + s2t_stage_coord = ( + None, + None, + None, + None, + sf_full.index, + ) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + sf_full.release() + peek_sf_full_status = cutlass.Boolean(1) + peek_sf_full_status = sf_consumer.try_wait() + + # Wait for A/B data to be ready(MMA2, MMA3, MMA4, part of MMA5) + ab_full1 = ab_consumer.wait_and_advance(peek_ab_full_status) + + # peek for next stage (part of MMA5, MMA6, MMA7) + peek_ab_full_status = cutlass.Boolean(1) + peek_ab_full_status = ab_consumer.try_wait() + + # MMA2 + k_block_coord_cur = (None, 0, 6, ab_full0.index) + k_block_coord_next = (None, 0, 0, ab_full1.index) + sf_kblock_coord = (None, None, 2 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # Release stage_ab_0 as it is no longer needed + ab_full0.release() + + # MMA3 + k_block_coord_cur = (None, 0, 1, ab_full1.index) + k_block_coord_next = (None, 0, 0, ab_full1.index) + sf_kblock_coord = (None, None, 3 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + 
sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # Conditionally load SFA/SFB for MMA4/MMA5 + if 4 % MmasPerSfBuffer == 0: + sf_full = sf_consumer.wait_and_advance(peek_sf_full_status) + s2t_stage_coord = ( + None, + None, + None, + None, + sf_full.index, + ) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + sf_full.release() + peek_sf_full_status = cutlass.Boolean(1) + peek_sf_full_status = sf_consumer.try_wait() + + # MMA4 + k_block_coord_cur = (None, 0, 4, ab_full1.index) + k_block_coord_next = (None, 0, 0, ab_full1.index) + sf_kblock_coord = (None, None, 4 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # Wait for A/B data to be ready(part of MMA5, MMA6, MMA7) + ab_full2 = ab_consumer.wait_and_advance(peek_ab_full_status) + + # peek for next loop's first stage (MMA0, MMA1, part of MMA2) + peek_ab_full_status = cutlass.Boolean(1) + if k_tile + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + + # MMA5 + k_block_coord_cur = (None, 0, 7, ab_full1.index) + k_block_coord_next = (None, 0, 0, ab_full2.index) + sf_kblock_coord = (None, None, 5 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # Conditionally load SFA/SFB for MMA6/MMA7 + if 6 % MmasPerSfBuffer == 0: + sf_full = sf_consumer.wait_and_advance(peek_sf_full_status) + s2t_stage_coord = ( + None, + None, + None, + None, + sf_full.index, + ) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + sf_full.release() + peek_sf_full_status = cutlass.Boolean(1) + if k_tile + 1 < k_tile_cnt: + peek_sf_full_status = sf_consumer.try_wait() + + ab_full1.release() + + # MMA6 + k_block_coord_cur = (None, 0, 2, ab_full2.index) + k_block_coord_next = (None, 0, 0, ab_full2.index) + sf_kblock_coord = (None, None, 6 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + tCtAcc, + ) + + # MMA7 + k_block_coord_cur = (None, 0, 5, ab_full2.index) + k_block_coord_next = (None, 0, 0, ab_full2.index) + sf_kblock_coord = (None, None, 7 % MmasPerSfBuffer * sf_stride) + tiled_mma.set( + tcgen05.Field.SFA, tCtSFA_mma[sf_kblock_coord].iterator + ) + tiled_mma.set( + tcgen05.Field.SFB, tCtSFB_mma[sf_kblock_coord].iterator + ) + self.make_desc_and_call_mma( + tiled_mma, + tCtAcc, + sA[k_block_coord_cur], + sA[k_block_coord_next], + sB[k_block_coord_cur], + sB[k_block_coord_next], + 
tCtAcc, + ) + + ab_full2.release() + + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + sC = None + if cutlass.const_expr(self.use_tma_store): + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + if cutlass.const_expr(self.use_tma_store): + assert tma_atom_c is not None and sC is not None + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + # Wrap epilogue_op with alpha scaling + alpha_epilogue_op = lambda x: epilogue_op(alpha_value * x) + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + # + # Pre-advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + num_tiles_executed = tile_sched.num_tiles_executed + if cutlass.const_expr(self.use_tma_store): + acc_consumer_state = utils.gemm.sm100.epilogue_tma_store( + self, + tidx, + warp_idx, + tma_atom_c, + tCtAcc_base, + sC, + tCgC, + epi_tile, + num_tiles_executed, + alpha_epilogue_op, + mma_tile_coord_mnl, + acc_consumer_state, + acc_pipeline, + c_pipeline, + ) + else: + acc_consumer_state = utils.gemm.sm100.epilogue( + self, + tidx, + tCtAcc_base, + tCgC, + epi_tile, + alpha_epilogue_op, + mma_tile_coord_mnl, + acc_consumer_state, + acc_pipeline, + tCcC_base=tCcC, + mC_mnl=mC_mnl, + ) + + if cutlass.const_expr(self.use_tma_store): + # Wait for C store complete + c_pipeline.producer_tail() + else: + # Synchronize before TMEM dealloc (done by the caller) + tmem_dealloc_barrier.arrive_and_wait() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + tmem.free(acc_tmem_ptr) + + cute.arch.mbarrier_init_fence() + + griddepcontrol_launch_dependents() + + @staticmethod + def make_desc_and_call_mma( + tiled_mma: cute.TiledMma, + d: cute.Tensor, + sA_cur: cute.Tensor, + sA_next: cute.Tensor, + sB_cur: cute.Tensor, + sB_next: cute.Tensor, + c: cute.Tensor, + ) -> None: + """Specialized GEMM for circular-buffered A/B from SMEM. + + Performs D <- A * B + C where A and B are described by circular SMEM + descriptors constructed from the (current, next) buffers. C and D may alias. 
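+ + In the mainloop above, sA_cur/sB_cur and sA_next/sB_next are the SMEM slices selected by k_block_coord_cur and k_block_coord_next for each of the eight unrolled MMAs per k tile, and each descriptor carries next_src pointing at the following stage's buffer.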
+ + Some tcgen05 MMAs require explicitly toggling an accumulate field outside of + this routine; the caller is responsible for that. + + All tensors must already be partitioned for the provided tiled MMA. + + For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread + election internally. Manual thread selection is not required in such cases. + + :param tiled_mma: The tiled MMA used to issue the call + :type tiled_mma: cute.TiledMma + :param d: Destination tensor + :type d: cute.Tensor + :param sA_cur: Current shared memory tensor for operand A + :type sA_cur: cute.Tensor + :param sA_next: Next shared memory tensor for operand A, used for circular buffering + :type sA_next: cute.Tensor + :param sB_cur: Current shared memory tensor for operand B + :type sB_cur: cute.Tensor + :param sB_next: Next shared memory tensor for operand B, used for circular buffering + :type sB_next: cute.Tensor + :param c: Third source tensor + :type c: cute.Tensor + :return: None + :rtype: None + """ + a_desc = tcgen05.make_umma_smem_desc( + sA_cur.iterator, + sA_cur.layout, + "k" if tiled_mma.op.a_major_mode.name == "K" else "mn", + next_src=sA_next.iterator, + ) + b_desc = tcgen05.make_umma_smem_desc( + sB_cur.iterator, + sB_cur.layout, + "k" if tiled_mma.op.b_major_mode.name == "K" else "mn", + next_src=sB_next.iterator, + ) + + view_layout = cute.make_layout(1, stride=0) + a_tensor = cute.make_tensor(a_desc, view_layout) + b_tensor = cute.make_tensor(b_desc, view_layout) + return cute.mma_atom_call(tiled_mma, d, a_tensor, b_tensor, c) + + @staticmethod + def sm103_make_blockscaled_trivial_tiled_mma( + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + cta_group: tcgen05.CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: tcgen05.OperandSource = tcgen05.OperandSource.SMEM, + ) -> cute.TiledMma: + """Create a blockscaled trivial tiled MMA for SM103 (3xFP4), K fixed to 96. + + Returns a tcgen05 MMA configured for the given (M, N) tiler and CTA group. + + :param sf_dtype: Data type of the scale factor (typically 8-bit) + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param cta_group: The CTA group configuration + :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The MMA tiler dimensions (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param a_source: Source location for operand A (SMEM by default) + :type a_source: tcgen05.OperandSource + + :return: A tiled MMA atom configured for SM103 blockscaled operations + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + :raises ValueError: If the sf_vec_size is not supported. + """ + if sf_vec_size == 32: + mma_op = tcgen05.SM103MmaMXF4Op( + (*mma_tiler_mn, 96), + cta_group, + a_source, + ) + elif sf_vec_size == 16: + mma_op = tcgen05.SM103MmaMXF4NVF4Op( + sf_dtype, + (*mma_tiler_mn, 96), + cta_group, + a_source, + ) + else: + raise ValueError( + f"Unsupported sf_vec_size: {sf_vec_size}. Expected 16 or 32." + ) + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + # Utils + @staticmethod + def sm103_make_smem_layout_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + num_stages: int, + ) -> Union[cute.Layout, cute.ComposedLayout]: + """ + Create the SMEM layout for operand A using K_SW128 and Uint8. + + This function creates a SMEM layout for operand A using the make_smem_layout_atom function with K_SW128 kind and Uint8 element type.
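+ + Note that the FP4 A/B operands are packed two values per Uint8 byte (see the host-side wrapper below), so the layout here is built over Uint8 elements rather than logical FP4 elements.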
+ + :param tiled_mma: The tiled MMA atom + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape (M, N, K) + :type mma_tiler_mnk: cute.Tile + :param num_stages: The number of stages + :type num_stages: int + + :return: SMEM layout for operand A + :rtype: cute.Layout + """ + is_k_major = tiled_mma.op.a_major_mode == tcgen05.OperandMajorMode.K + a_smem_layout_staged = tcgen05.tile_to_mma_shape( + tcgen05.make_smem_layout_atom( + tcgen05.SmemLayoutAtomKind.K_SW128, cutlass.Uint8 + ), + cute.append( + ( + ( + mma_tiler_mnk[0] + // cute.size(tiled_mma.thr_layout_vmnk.shape[0]), + 16, + ), + 1, + 8, + ), + num_stages, + ), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + ) + + return a_smem_layout_staged + + @staticmethod + def sm103_make_smem_layout_b( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + num_stages: int, + ) -> Union[cute.Layout, cute.ComposedLayout]: + """ + Create the SMEM layout for operand B using K_SW128 and Uint8. + + This function creates a SMEM layout for operand B using the make_smem_layout_atom function with K_SW128 kind and Uint8 element type. + + :param tiled_mma: The tiled MMA atom + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape (M, N, K) + :type mma_tiler_mnk: cute.Tile + :param num_stages: The number of stages + :type num_stages: int + + :return: SMEM layout for operand B + :rtype: cute.Layout + """ + is_k_major = tiled_mma.op.b_major_mode == tcgen05.OperandMajorMode.K + b_smem_layout_staged = tcgen05.tile_to_mma_shape( + tcgen05.make_smem_layout_atom( + tcgen05.SmemLayoutAtomKind.K_SW128, cutlass.Uint8 + ), + cute.append( + ((mma_tiler_mnk[1] // cute.size(tiled_mma.thr_id.shape), 16), 1, 8), + num_stages, + ), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + ) + return b_smem_layout_staged + + @dataclass(frozen=True) + class Sm103BlockScaledBasicChunk: + """ + Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103. + + Represents the fixed layout pattern for scale factors used by tcgen05 + BlockScaled MMA Ops on SM103. The layout is determined by the instruction + specification and is not configurable. + """ + + sf_vec_size: int + major_mode: tcgen05.OperandMajorMode = tcgen05.OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + if self.major_mode == tcgen05.OperandMajorMode.K: + atom_shape = ((8, 4, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 128, 4), (0, 1)) + else: + atom_shape = ((self.sf_vec_size, 4), (8, 4, 4)) # type: ignore[assignment] + atom_stride = ((0, 1), (16, 128, 4)) # type: ignore[assignment] + + object.__setattr__( + self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + return self._layout + + @staticmethod + def sm103_make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + ) -> cute.Layout: + """ + Make SMEM layout for SFA based on: + 1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages. 
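+ + The per-stage K extent of the scale-factor tile is divided by the k_divisor used in the body (4 when sf_vec_size == 16, 2 when sf_vec_size == 32).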
+ + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2])) + sf_atom = Sm103BlockScaledPersistentDenseGemmKernel.Sm103BlockScaledBasicChunk( + sf_vec_size, tiled_mma.op.a_major_mode + ).layout + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfa_tiler = ( + mma_shape_mk[0][0] * mma_shape_mk[1], + mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor, + ) + sfa_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape)) + ), + ) + sfa_smem_layout_staged = cute.make_layout( + shape=cute.append(sfa_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfa_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfa_smem_atom_layout)), + ), + ) + return sfa_smem_layout_staged + + @staticmethod + def sm103_make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + ) -> cute.Layout: + """ + Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages. + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFB + :rtype: cute.Layout + """ + sf_atom = Sm103BlockScaledPersistentDenseGemmKernel.Sm103BlockScaledBasicChunk( + sf_vec_size, tiled_mma.op.a_major_mode + ).layout + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor) + if mma_sfb_tiler[0] == 128: + sfb_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape)) + ), + ) + else: + sf_k_major_atom256 = cute.make_layout( + shape=( + (32, 4, 2), + (sf_vec_size, 4), + ), + stride=( + (16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512), + (0, 1), + ), + ) + sfb_smem_atom_layout = cute.tiled_product( + sf_k_major_atom256, + cute.make_layout( + cute.shape_div( + mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape) + ) + ), + ) + + sfb_smem_layout_staged = cute.make_layout( + shape=cute.append(sfb_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfb_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfb_smem_atom_layout)), + ), + ) + return sfb_smem_layout_staged + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). 
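+ + The copy is built from the tcgen05.Cp4x32x128bOp S2T copy atom for this CTA group, as constructed in the body below.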
+ + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + tCtSF_compact_copy = cute.make_tensor( + tCtSF_compact.iterator, + cute.append( + cute.append(tCtSF_compact[(None, 0, 0)].layout, cute.make_layout((1))), + cute.make_layout(1), + ), + ) + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact_copy) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int, int]: + """Computes the number of stages for A/B and SF operands based on heuristics. + + SM103 requires separate stage counts for AB and SF pipelines. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tiler. + :type mma_tiler: tuple[int, int, int] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. 
+ :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, SF stages, C stages) + :rtype: tuple[int, int, int, int] + """ + # ACC stages - same as SM100 dense blockscaled gemm + num_acc_stage = 1 if mma_tiler[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB + a_smem_layout_stage_one = ( + Sm103BlockScaledPersistentDenseGemmKernel.sm103_make_smem_layout_a( + tiled_mma, + mma_tiler, + 1, + ) + ) + b_smem_layout_staged_one = ( + Sm103BlockScaledPersistentDenseGemmKernel.sm103_make_smem_layout_b( + tiled_mma, + mma_tiler, + 1, + ) + ) + sfa_smem_layout_staged_one = ( + Sm103BlockScaledPersistentDenseGemmKernel.sm103_make_smem_layout_sfa( + tiled_mma, + mma_tiler, + sf_vec_size, + 1, + ) + ) + sfb_smem_layout_staged_one = ( + Sm103BlockScaledPersistentDenseGemmKernel.sm103_make_smem_layout_sfb( + tiled_mma, + mma_tiler, + sf_vec_size, + 1, + ) + ) + + c_smem_layout_staged_one = sm103_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + ab_bytes_per_stage = cute.size_in_bytes( + cutlass.Uint8, a_smem_layout_stage_one + ) + cute.size_in_bytes(cutlass.Uint8, b_smem_layout_staged_one) + sf_bytes_per_stage = cute.size_in_bytes( + sf_dtype, sfa_smem_layout_staged_one + ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + + mbar_helpers_bytes = 1024 + + num_ab_stage = ( + smem_capacity // occupancy + - (mbar_helpers_bytes + sf_bytes_per_stage + c_bytes) + ) // ab_bytes_per_stage + + num_sf_stage = ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * mbar_helpers_bytes + - occupancy * c_bytes + ) // (occupancy * sf_bytes_per_stage) + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + # TODO(xinyu): verify this matches the C++ implementation + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * sf_bytes_per_stage * num_sf_stage + - occupancy * mbar_helpers_bytes + - occupancy * c_bytes + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_sf_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch.
+ :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype != cutlass.Float4E2M1FN: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 
0: + is_valid = False + # Skip invalid cluster shape + _is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not _is_power_of_2(cluster_shape_mn[0]) + or not _is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The batch dimension + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the tensor alignment is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contiguous_alignment( + dtype, is_mode0_major, tensor_shape, alignment_bytes + ): + """Check if tensor satisfies the required byte alignment. + + :param dtype: Data type of the tensor + :param is_mode0_major: Whether mode 0 is the major (contiguous) mode + :param tensor_shape: Shape of the tensor (mode0, mode1, batch) + :param alignment_bytes: Required alignment in bytes (e.g., 16 or 32) + :return: True if alignment is satisfied + """ + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + # Calculate number of contiguous elements needed for alignment + # alignment_bytes * 8 (bits per byte) / dtype.width (bits per element) + num_contiguous_elements = alignment_bytes * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + # Check A/B tensors for 16B alignment + # Check C tensor for 32B alignment + if ( + not check_contiguous_alignment(ab_dtype, a_major == "m", (m, k, l), 16) + or not check_contiguous_alignment(ab_dtype, b_major == "n", (n, k, l), 16) + or not check_contiguous_alignment(c_dtype, c_major == "m", (m, n, l), 32) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + use_tma_store: bool, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of
the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The batch dimension + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + :param use_tma_store: Whether TMA store is enabled + :type use_tma_store: bool + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm103BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm103BlockScaledPersistentDenseGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm103BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm103BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + # Helper function for append and coalesce layout + @staticmethod + def append_coalesce_layout(layout): + # coalesce is like: cutlass/python/pycute/layout.py:coalesce + part1 = cute.coalesce(cute.append(layout[0][0], layout[1])) + part2 = cute.coalesce(cute.append(layout[0][1], layout[2])) + result = cute.append(part1, part2) + result = cute.append(result, layout[3]) + result = cute.append(result, layout[4]) + result = cute.append(result, layout[5]) + return result + + @staticmethod + def adapt_layout_for_tma_ab(composed_layout): + # input: S<3,4,3> o 0 o ((128,16),1,8,3):((128,1),0,16,16384) + # output: S<3,4,3> o 0 o (128,(128,3)):(128,(1,16384)) + # for ctaValueMap: (128,384):(1@0,1@1) + layout = composed_layout.outer + part1 = cute.coalesce(cute.append(layout[0][0], layout[1])) + part2 = cute.coalesce(cute.append(layout[0][1], layout[2])) + part3 = cute.append(part2, layout[3]) + result = cute.append(part1, part3) + return cute.make_composed_layout( + composed_layout.inner, composed_layout.offset, result + ) + + @staticmethod + def adapt_layout_for_tma_sf(layout): + # TODO(ethan): verify this layout transformation + # input: (((8,4,4),(16,4)),1,3):(((16,128,4),(0,1)),0,512) + # output: ((32,4),(16,4,3)):((16,4),(0,1,512)) + # for ctaValueMap: ((8,4,4),(16,4,3)):((1@0@0@0,1@1@0@0,1@2@0@0),(1@0@0@1,1@1@0@1,1@1@1)) + part1 = cute.coalesce(cute.append(layout[0][0], layout[1])) + part2 = cute.coalesce(cute.append(layout[0][1], layout[2])) + result = cute.append(cute.group_modes(part1, 0, cute.rank(part1)), part2) + return result + + @cute.jit + def wrapper( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sf_m: cutlass.Int64, + sf_n: cutlass.Int64, + sf_k: cutlass.Int64, + l: cutlass.Constexpr, + a_sf_ptr: cute.Pointer, + b_sf_ptr: cute.Pointer, + alpha_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + current_stream, + swap_ab: cutlass.Constexpr = False, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the wrapped
GEMM kernel with dynamically shaped tensors. + + Uses TVM-FFI for efficient tensor passing: A, B, C, and alpha are passed + as cute.Tensor directly (torch tensors at runtime via TVM-FFI's C-level + dlpack). Scale factor tensors remain as pointers (complex 6D layout). + + Args: + mA (cute.Tensor): Input A, shape (m, k_packed), Uint8 (FP4 packed). + mB (cute.Tensor): Input B, shape (n, k_packed), Uint8 (FP4 packed). + mC (cute.Tensor): Output C, shape (m, n). + sf_m/sf_n/sf_k: Scale factor dimensions. + l: Batch dimension. + a_sf_ptr/b_sf_ptr: Scale factor pointers (6D layout). + alpha_tensor: Alpha scaling factor, shape (1,), float32. + max_active_clusters: Max active clusters. + current_stream: CUDA stream (TVM-FFI fake stream). + swap_ab: Whether A/B are swapped (controls C layout). + epilogue_op: Elementwise epilogue function. + """ + # A/B come in as Uint8 (FP4 packed as uint8 in torch). Recast to FP4. + m = cute.size(mA, mode=[0]) + k_packed = cute.size(mA, mode=[1]) + n = cute.size(mB, mode=[0]) + k = k_packed * 2 # 2 FP4 values per uint8 byte + + # Recast Uint8 → Float4E2M1FN and reshape to (m, k, l) K-major + a_fp4_ptr = cute.recast_ptr(mA.iterator, dtype=cutlass.Float4E2M1FN) + a_tensor = cute.make_tensor( + a_fp4_ptr, + layout=cute.make_ordered_layout((m, k, l), order=(1, 0, 2)), + ) + b_fp4_ptr = cute.recast_ptr(mB.iterator, dtype=cutlass.Float4E2M1FN) + b_tensor = cute.make_tensor( + b_fp4_ptr, + layout=cute.make_ordered_layout( + (n, k, l), + order=(1, 0, 2), + ), + ) + # C: swap_ab is constexpr, determines layout at compile time + if cutlass.const_expr(swap_ab): + c_tensor = cute.make_tensor( + mC.iterator, + layout=cute.make_ordered_layout((m, n, l), order=(0, 1, 2)), + ) + else: + c_tensor = cute.make_tensor( + mC.iterator, + layout=cute.make_ordered_layout((m, n, l), order=(1, 0, 2)), + ) + # Scale factor tensors: 6D BlockScaledBasicChunk layout from pointers + sfa_tensor = cute.make_tensor( + a_sf_ptr, + layout=cute.make_ordered_layout( + (32, 4, sf_m, 4, sf_k, l), + order=(2, 1, 4, 0, 3, 5), + ), + ) + sfb_tensor = cute.make_tensor( + b_sf_ptr, + layout=cute.make_ordered_layout( + (32, 4, sf_n, 4, sf_k, l), + order=(2, 1, 4, 0, 3, 5), + ), + ) + + self( + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + alpha_tensor, + max_active_clusters, + current_stream, + epilogue_op, + ) diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index cc85f6126a..7753a3efbe 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -29,6 +29,13 @@ def _test_mm_fp4( pytest.skip("Skipping test for trtllm fp4 with float16") if compute_capability[0] in [11, 12]: pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.") + if backend == "cute-dsl": + if not use_nvfp4: + pytest.skip("cute_dsl backend only supports nvfp4") + if not use_128x4_sf_layout: + pytest.skip("cute_dsl backend only supports 128x4 SF layout") + if compute_capability[0] not in [10]: + pytest.skip("cute_dsl backend only supports SM100/SM103 GPUs.") if not use_128x4_sf_layout and backend != "trtllm": pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False") if not use_nvfp4 and backend not in ["cudnn", "auto"]: @@ -99,7 +106,7 @@ def _test_mm_fp4( @pytest.mark.parametrize("n", [128, 256, 512]) @pytest.mark.parametrize("k", [128, 256, 512]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"]) +@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass", 
"cute-dsl"]) @pytest.mark.parametrize("use_128x4_sf_layout", [False, True]) @pytest.mark.parametrize("auto_tuning", [False, True]) @pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])