diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index 0c2bc13f6b..2fdc9013ac 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -1646,18 +1646,40 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): cumsum_s_qo = torch.sum(actual_seq_lens_q) cumsum_s_kv = torch.sum(actual_seq_lens_kv) - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype + + # Front-padding for cute-dsl varlen kernel: the persistent varlen kernel + # applies a negative pointer offset (-max_s * H * D), so there must be + # valid GPU memory before the data start. + front_pad_q = s_qo if "cute-dsl" in backends else 0 + front_pad_kv = s_kv if "cute-dsl" in backends else 0 + + q_full = torch.randn( + front_pad_q + cumsum_s_qo, + num_qo_heads, + head_dim_qk, + device=device, + dtype=q_init_dtype, ) + q = q_full[front_pad_q:] if args.verbose >= 2: print(f"[VVERBOSE] {q.shape = }") - k = torch.randn( - cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype + k_full = torch.randn( + front_pad_kv + cumsum_s_kv, + num_kv_heads, + head_dim_qk, + device=device, + dtype=kv_init_dtype, ) - v = torch.randn( - cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype + k = k_full[front_pad_kv:] + v_full = torch.randn( + front_pad_kv + cumsum_s_kv, + num_kv_heads, + head_dim_vo, + device=device, + dtype=kv_init_dtype, ) + v = v_full[front_pad_kv:] block_tables = None @@ -1815,14 +1837,18 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): v = (v / v_scale).to(kv_dtype) trtllm_out = None - if "trtllm-native" in backends: - trtllm_out = torch.empty( - q.shape[0], + if "trtllm-native" in backends or "cute-dsl" in backends: + # cute-dsl varlen kernel uses negative pointer offsets on output, + # so front-pad like Q/K/V. 
+ out_pad = front_pad_q if "cute-dsl" in backends else 0 + trtllm_out_full = torch.empty( + out_pad + q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=out_dtype, ) + trtllm_out = trtllm_out_full[out_pad:] def run_backend_wrapper( backend, @@ -1843,6 +1869,31 @@ def run_backend_wrapper( ): if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]: return backend_wrappers[backend].run_return_lse(q, k, v)[0] + elif backend == "cute-dsl": + _q_scale = q_scale if q_scale is not None else 1.0 + _k_scale = k_scale if k_scale is not None else 1.0 + _v_scale = v_scale if v_scale is not None else 1.0 + return flashinfer.prefill.trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=workspace_buffer, + seq_lens=actual_seq_lens_kv_device, + max_q_len=s_qo, + max_kv_len=s_kv, + bmm1_scale=_q_scale * _k_scale * scale, + bmm2_scale=_v_scale, + o_sf_scale=-1, + batch_size=batch_size, + window_left=-1, + cum_seq_lens_q=qo_indptr, + cum_seq_lens_kv=kv_indptr, + enable_pdl=False, + is_causal=causal, + return_lse=True, + out=trtllm_out, + backend="cute-dsl", + )[0] elif backend == "cudnn": # cuDNN uses wrapper API return backend_wrappers[backend].run(q, k, v) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index ed81f364fe..2ff90e5cc1 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -322,8 +322,22 @@ def dtype_str_to_torch_dtype(dtype_str): "8.6": ["fa2", "cudnn", "cudnn-native"], "8.9": ["fa2", "cudnn", "cudnn-native"], "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"], - "10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"], - "10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"], + "10.0": [ + "fa2", + "cudnn", + "cudnn-native", + "cutlass", + "cute-dsl", + "trtllm-native", + ], + "10.3": [ + "fa2", + "cudnn", + "cudnn-native", + "cutlass", + "cute-dsl", + "trtllm-native", + ], 
"12.0": ["fa2", "cudnn", "cudnn-native"], "12.1": ["fa2", "cudnn", "cudnn-native"], }, diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 6794b91b4e..ef321bf281 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -145,6 +145,8 @@ class ArtifactPath: CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" + DSL_FMHA: str = "c770c91cb0d991b7828fc85d2253a62f0d356b6c/fmha/cute-dsl/" + DSL_FMHA_ARCHS: tuple[str, ...] = ("sm_100a", "sm_103a", "sm_110a") class CheckSumHash: @@ -164,11 +166,32 @@ class CheckSumHash: TRTLLM_GEN_GEMM: str = ( "64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7" ) + # SHA256 of the checksums.txt manifest file per cpu-arch/sm-arch, + # NOT hashes of individual kernel .so files. + DSL_FMHA_CHECKSUMS: dict[str, dict[str, str]] = { + "x86_64": { + "sm_100a": "9533536698cdc256d897fffb3114de317076654ff8630ff283d850cc3dc96d86", + "sm_103a": "927e1954f1d45b0ee876f139084e4facdfcc87e86f4d30cb92d5c33698d4c2d6", + "sm_110a": "277b1dceaab2081e3def37cf997280a3f2c3ac515d22b80be141253c0278b8b5", + }, + "aarch64": { + "sm_100a": "b48ed0bcc9bad4afd33e0784c8c9eb9e13e782afe197816b1d0747b11759493e", + "sm_103a": "bace619a560f3ce52ad6ba105fffb8ea8629fe57885a90892c9e15a7122467e1", + "sm_110a": "d8369bcfa443bfd791cd014e3b030d378f00a975db8278eebd5b2fb529e3257d", + }, + } map_checksums: dict[str, str] = { safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA, safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM, safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM, safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM, + **{ + safe_urljoin( + ArtifactPath.DSL_FMHA, f"{cpu_arch}/{sm_arch}/checksums.txt" + ): sha + for cpu_arch, sm_checksums in 
DSL_FMHA_CHECKSUMS.items() + for sm_arch, sha in sm_checksums.items() + }, } @@ -191,14 +214,30 @@ def get_checksums(subdirs): return checksums +def _get_host_cpu_arch() -> str: + """Return CPU architecture string matching artifactory layout.""" + import platform + + machine = platform.machine() + if machine in ("aarch64", "arm64"): + return "aarch64" + return "x86_64" + + def get_subdir_file_list() -> Generator[tuple[str, str], None, None]: base = FLASHINFER_CUBINS_REPOSITORY + cpu_arch = _get_host_cpu_arch() cubin_dirs = [ ArtifactPath.TRTLLM_GEN_FMHA, ArtifactPath.TRTLLM_GEN_BMM, ArtifactPath.TRTLLM_GEN_GEMM, ArtifactPath.DEEPGEMM, + # DSL FMHA: per cpu-arch and sm-arch subdirectories + *( + safe_urljoin(ArtifactPath.DSL_FMHA, f"{cpu_arch}/{arch}/") + for arch in ArtifactPath.DSL_FMHA_ARCHS + ), ] # Get checksums of all files diff --git a/flashinfer/attention/__init__.py b/flashinfer/attention/__init__.py new file mode 100644 index 0000000000..ef4df2e6b8 --- /dev/null +++ b/flashinfer/attention/__init__.py @@ -0,0 +1,23 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from ._core import * # noqa: F401,F403 +from ._core import BatchAttention, BatchAttentionWithAttentionSinkWrapper + +__all__ = [ + "BatchAttention", + "BatchAttentionWithAttentionSinkWrapper", +] diff --git a/flashinfer/attention.py b/flashinfer/attention/_core.py similarity index 97% rename from flashinfer/attention.py rename to flashinfer/attention/_core.py index c4bc4f27dc..7d9c377f33 100644 --- a/flashinfer/attention.py +++ b/flashinfer/attention/_core.py @@ -20,9 +20,9 @@ import torch -from .api_logging import flashinfer_api -from .jit import gen_batch_attention_module -from .utils import ( +from ..api_logging import flashinfer_api +from ..jit import gen_batch_attention_module +from ..utils import ( MaskMode, PosEncodingMode, TensorLayout, @@ -30,9 +30,9 @@ _unpack_paged_kv_cache, determine_attention_backend, ) -from .prefill import BatchPrefillWithPagedKVCacheWrapper -from .jit.attention.variants import attention_sink_decl -from .jit.utils import filename_safe_dtype_map +from ..prefill import BatchPrefillWithPagedKVCacheWrapper +from ..jit.attention.variants import attention_sink_decl +from ..jit.utils import filename_safe_dtype_map @functools.cache diff --git a/flashinfer/attention/cute_dsl/__init__.py b/flashinfer/attention/cute_dsl/__init__.py new file mode 100644 index 0000000000..3e02962790 --- /dev/null +++ b/flashinfer/attention/cute_dsl/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL Attention Kernels (Cubin Distribution) +================================================ + +Pre-compiled FMHA kernels loaded via ExternalBinaryModule. +""" + +from flashinfer.cute_dsl.utils import is_cute_dsl_available + +if is_cute_dsl_available(): + from .fmha import ( + get_cute_dsl_fmha_kernel, + cute_dsl_fmha_ragged_prefill, + ) + + __all__ = [ + "is_cute_dsl_available", + "get_cute_dsl_fmha_kernel", + "cute_dsl_fmha_ragged_prefill", + ] +else: + __all__ = [ + "is_cute_dsl_available", + ] diff --git a/flashinfer/attention/cute_dsl/fmha.py b/flashinfer/attention/cute_dsl/fmha.py new file mode 100644 index 0000000000..6da175b95a --- /dev/null +++ b/flashinfer/attention/cute_dsl/fmha.py @@ -0,0 +1,563 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL FMHA Kernel Loader (Cubin Distribution) +================================================= + +Loads pre-compiled FMHA kernel .so artifacts via ExternalBinaryModule. +The .so files are compiled offline from the proprietary DSL kernel source +and distributed through the cubin_publishing pipeline. + +Runtime flow: + 1. get_artifact() downloads .so from artifactory (or local cache) + 2. ExternalBinaryModule loads .so and extracts callable kernel + 3. 
cute_dsl_fmha_ragged_prefill() wraps the kernel with a PyTorch-friendly API +""" + +import functools +import logging +import math +import os +from typing import Optional + +import cutlass +import cutlass.cute as cute +import torch +from cuda.bindings import driver as cuda_driver +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Float32, Int32 + +logger = logging.getLogger("flashinfer.attention.cute_dsl.fmha") + + +# ============================================================================= +# Artifact configuration +# ============================================================================= + +from flashinfer.artifacts import ArtifactPath + +DSL_FMHA_ARTIFACT_PATH = ArtifactPath.DSL_FMHA + +from flashinfer.artifacts import _get_host_cpu_arch as _get_cpu_arch + + +def _get_gpu_arch(device: torch.device) -> str: + """Get the GPU architecture string for *device* (e.g. 'sm_100a'). + + Uses the same normalization as CompilationContext to match the arch + subdirectory names used by the cubin publishing pipeline. + """ + major, minor = torch.cuda.get_device_capability(device) + from flashinfer.compilation_context import CompilationContext + + _, minor_str = CompilationContext._normalize_cuda_arch(major, minor) + return f"sm_{major}{minor_str}" + + +# Cache: arch -> {variant_name: sha256} parsed from downloaded checksums.txt +_checksums_cache: dict[str, dict[str, str]] = {} + + +def _get_checksums(gpu_arch: str) -> dict[str, str]: + """Download and parse checksums.txt for the given arch. + + The checksums.txt is generated by compile_cute_dsl_fmha.py and published + alongside the .so files at <cpu_arch>/<gpu_arch>/checksums.txt.
+ Format per line: "sha256 filename.so" + """ + if gpu_arch in _checksums_cache: + return _checksums_cache[gpu_arch] + + from flashinfer.jit.cubin_loader import download_file, safe_urljoin + from flashinfer.jit.cubin_loader import FLASHINFER_CUBINS_REPOSITORY + from flashinfer.jit.env import FLASHINFER_CUBIN_DIR + + cpu_arch = _get_cpu_arch() + checksums_rel = os.path.join( + DSL_FMHA_ARTIFACT_PATH, cpu_arch, gpu_arch, "checksums.txt" + ) + local_path = FLASHINFER_CUBIN_DIR / checksums_rel + + if not local_path.exists(): + os.makedirs(local_path.parent, exist_ok=True) + uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, checksums_rel) + logger.info(f"Fetching checksums for {gpu_arch} from {uri}") + download_file(uri, str(local_path)) + + checksums = {} + with open(local_path) as f: + for line in f: + line = line.strip() + if not line: + continue + sha256, filename = line.split(None, 1) + variant_name = filename.rsplit(".", 1)[0] # strip .so extension + checksums[variant_name] = sha256 + + _checksums_cache[gpu_arch] = checksums + return checksums + + +def _dtype_to_str(dtype: torch.dtype) -> str: + """Convert torch dtype to short string matching compile_cute_dsl_fmha.py naming.""" + return { + torch.float16: "fp16", + torch.bfloat16: "bf16", + torch.float8_e4m3fn: "e4m3", + }[dtype] + + +def _get_variant_name( + in_dtype: torch.dtype, + out_dtype: torch.dtype, + head_dim: int, + is_causal: bool, + is_persistent: bool = True, + varlen: bool = False, + with_lse: bool = False, + enable_skip_softmax: bool = False, + enable_tvm_ffi: bool = False, +) -> str: + """Generate the variant name matching compile_cute_dsl_fmha.py naming convention.""" + in_str = _dtype_to_str(in_dtype) + out_str = _dtype_to_str(out_dtype) + # Only include out_dtype in name when it differs from in_dtype (mixed precision) + dtype_str = f"{in_str}_{out_str}" if in_dtype != out_dtype else in_str + causal_str = "causal" if is_causal else "nocausal" + persist_str = "persistent" if is_persistent else 
"nonpersistent" + varlen_str = "_varlen" if varlen else "" + lse_str = "_lse" if with_lse else "" + skip_str = "_skipsm" if enable_skip_softmax else "" + ffi_str = "_tvmffi" if enable_tvm_ffi else "" + return f"cute_dsl_fmha_{dtype_str}_h{head_dim}_{causal_str}_{persist_str}{varlen_str}{lse_str}{skip_str}{ffi_str}" + + +# ============================================================================= +# Loading: ExternalBinaryModule path +# ============================================================================= + + +def _load_from_artifact(variant_name: str, gpu_arch: str, enable_tvm_ffi: bool = False): + """Download .so from artifactory and load via ExternalBinaryModule. + + This is the production path used when cubins are published. + + Parameters + ---------- + variant_name : str + The kernel variant name (matches function_prefix used during export). + gpu_arch : str + GPU architecture string (e.g. 'sm_100a'). + enable_tvm_ffi : bool + If False (default), load with CuTe native ABI. + If True, load with TVM-FFI ABI (TODO: compile-side support pending). + """ + from flashinfer.jit.cubin_loader import get_artifact + from flashinfer.jit.env import FLASHINFER_CUBIN_DIR + + so_filename = f"{variant_name}.so" + cpu_arch = _get_cpu_arch() + artifact_path = os.path.join( + DSL_FMHA_ARTIFACT_PATH, cpu_arch, gpu_arch, so_filename + ) + + checksums = _get_checksums(gpu_arch) + sha256 = checksums.get(variant_name, "") + if not sha256: + raise RuntimeError( + f"No checksum found for DSL FMHA variant '{variant_name}' on {gpu_arch}. 
" + f"Available variants: {list(checksums.keys())}" + ) + + # get_artifact downloads to FLASHINFER_CUBIN_DIR / artifact_path and verifies sha256 + data = get_artifact(artifact_path, sha256) + if not data: + raise RuntimeError(f"Failed to download DSL FMHA artifact: {artifact_path}") + + local_path = FLASHINFER_CUBIN_DIR / artifact_path + module = cute.runtime.load_module(str(local_path), enable_tvm_ffi=enable_tvm_ffi) + return getattr(module, variant_name) + + +def _load_from_local(variant_name: str, local_dir: str, enable_tvm_ffi: bool = False): + """Load .so or .o from a local directory (for development/testing). + + Set FLASHINFER_DSL_FMHA_LOCAL_DIR to use this path. + + Parameters + ---------- + variant_name : str + The kernel variant name (matches function_prefix used during export). + local_dir : str + Directory containing the compiled .so/.o files. + enable_tvm_ffi : bool + If False (default), load with CuTe native ABI. + If True, load with TVM-FFI ABI (TODO: compile-side support pending). + """ + # Try .so first, then .o + so_path = os.path.join(local_dir, f"{variant_name}.so") + o_path = os.path.join(local_dir, f"{variant_name}.o") + if os.path.exists(so_path): + load_path = so_path + elif os.path.exists(o_path): + load_path = o_path + else: + raise FileNotFoundError( + f"DSL FMHA .so/.o not found at {local_dir}/{variant_name}.[so|o]. " + f"Run compile_cute_dsl_fmha.py to generate it." + ) + + module = cute.runtime.load_module(load_path, enable_tvm_ffi=enable_tvm_ffi) + return getattr(module, variant_name) + + +@functools.cache +def get_cute_dsl_fmha_kernel( + gpu_arch: str, + in_dtype: torch.dtype, + out_dtype: torch.dtype, + head_dim: int, + is_causal: bool, + is_persistent: bool = True, + enable_tvm_ffi: bool = False, + varlen: bool = False, + with_lse: bool = False, + enable_skip_softmax: bool = False, +): + """Get a compiled DSL FMHA kernel function. 
+ + Checks local directory first (FLASHINFER_DSL_FMHA_LOCAL_DIR env var), + then falls back to artifact download. + + Parameters + ---------- + gpu_arch : str + GPU architecture string (e.g. 'sm_100a'), used both for selecting the + correct artifact and as part of the cache key. + in_dtype : torch.dtype + Input data type (torch.float16, torch.bfloat16, or torch.float8_e4m3fn). + out_dtype : torch.dtype + Output data type. Same as in_dtype for non-mixed precision. + head_dim : int + Head dimension (e.g., 64, 128, 192). Note: 192 only supports FP8. + is_causal : bool + Whether to use causal masking. + is_persistent : bool + Whether to use persistent kernel mode. + enable_tvm_ffi : bool + If False (default), load with CuTe native ABI — kernel accepts + cute Pointer/Tensor args (same calling convention as JIT mode). + If True, load with TVM-FFI ABI — Pointer args accept data_ptr(), + Tensor args accept torch.Tensor directly, stream uses env stream. + enable_skip_softmax : bool + If True, load kernel compiled with skip-softmax support. + + Returns + ------- + callable + The compiled kernel function. 
+ """ + variant_name = _get_variant_name( + in_dtype, + out_dtype, + head_dim, + is_causal, + is_persistent, + varlen, + with_lse, + enable_skip_softmax, + enable_tvm_ffi, + ) + + # Check for local .so directory (development mode) + local_dir = os.environ.get("FLASHINFER_DSL_FMHA_LOCAL_DIR") + if local_dir: + logger.info( + f"Loading DSL FMHA kernel from local dir: {local_dir} (tvm_ffi={enable_tvm_ffi})" + ) + return _load_from_local(variant_name, local_dir, enable_tvm_ffi=enable_tvm_ffi) + + # Production path: download from artifactory + logger.info( + f"Loading DSL FMHA kernel variant: {variant_name} (tvm_ffi={enable_tvm_ffi})" + ) + return _load_from_artifact(variant_name, gpu_arch, enable_tvm_ffi=enable_tvm_ffi) + + +# ============================================================================= +# PyTorch API wrapper +# ============================================================================= + + +def cute_dsl_fmha_ragged_prefill( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + is_causal: bool = False, + sm_scale: Optional[float] = None, + window_left: int = -1, + window_right: int = -1, + lse: Optional[torch.Tensor] = None, + scale_q: float = 1.0, + scale_k: float = 1.0, + scale_v: float = 1.0, + scale_o: float = 1.0, + enable_tvm_ffi: bool = True, + max_qo_len: Optional[int] = None, + max_kv_len: Optional[int] = None, + kernel_fn=None, + skip_softmax_threshold_scale_factor: Optional[float] = None, +) -> None: + """Run DSL FMHA prefill kernel on ragged (variable-length) tensors. + + Note: The DSL FMHA kernel only supports per-tensor scalar scales, not + per-head scale tensors. + + **Front-padding requirement** (TODO: will be removed in the next MR): + The DSL kernel applies a negative pointer offset + (``-max_seq_len * H * D`` elements) internally. 
Callers must + allocate ``max_seq_len + total_tokens`` rows and pass the slice starting + at ``[max_seq_len:]`` as q/k/v/o so that the preceding memory is valid + GPU memory. For example:: + + q_full = torch.empty(max_s_q + total_q, H_q, D, ...) + q = q_full[max_s_q:] # pass this to the kernel + # (same for k, v, o with max_s_k / max_s_q respectively) + + Parameters + ---------- + q : torch.Tensor + Query tensor, shape (total_q_tokens, H_q, D). + Must have ``max_qo_len`` rows of valid GPU memory before index 0. + k : torch.Tensor + Key tensor, shape (total_kv_tokens, H_k, D). + Must have ``max_kv_len`` rows of valid GPU memory before index 0. + v : torch.Tensor + Value tensor, shape (total_kv_tokens, H_k, D_v). + Must have ``max_kv_len`` rows of valid GPU memory before index 0. + o : torch.Tensor + Output tensor, shape (total_q_tokens, H_q, D_v). Modified in-place. + Must have ``max_qo_len`` rows of valid GPU memory before index 0. + qo_indptr : torch.Tensor + Cumulative sequence lengths for Q/O, shape (batch_size + 1,). + Same as cum_seqlen_q in DSL FMHA kernel. + kv_indptr : torch.Tensor + Cumulative sequence lengths for K/V, shape (batch_size + 1,). + Same as cum_seqlen_k in DSL FMHA kernel. + is_causal : bool + Whether to apply causal masking. + sm_scale : float, optional + Softmax scale factor. Defaults to 1/sqrt(D). + window_left : int + Left sliding window size. -1 means no window. + window_right : int + Right sliding window size. -1 means no window. 0 for causal. + lse : torch.Tensor, optional + Log-sum-exp output tensor. None to skip. + scale_q : float + Per-tensor scale for query (FP8 calibration). Default 1.0. + scale_k : float + Per-tensor scale for key (FP8 calibration). Default 1.0. + scale_v : float + Per-tensor scale for value (FP8 calibration). Default 1.0. + scale_o : float + Per-tensor scale for output (FP8 calibration). Default 1.0. 
+ enable_tvm_ffi : bool + If True, use TVM-FFI ABI (pass data_ptr() for Pointer args, torch.Tensor + for Tensor args, no explicit stream). Default True (TVM-FFI ABI). + max_qo_len : int, optional + Maximum query sequence length. Computed from qo_indptr if not provided. + Pass this from plan() to avoid D2H copy during CUDA graph capture. + max_kv_len : int, optional + Maximum KV sequence length. Computed from kv_indptr if not provided. + skip_softmax_threshold_scale_factor : float, optional + Threshold scale factor for skip-softmax sparsity (https://arxiv.org/abs/2512.12087). + The actual threshold = scale_factor / max_kv_len, then converted to log2 domain. + None or 0 disables skip-softmax. + """ + total_q, H_q, D = q.shape + total_kv, H_k, _ = k.shape + D_v = v.shape[-1] + + batch_size = len(qo_indptr) - 1 + + use_skip_softmax = ( + skip_softmax_threshold_scale_factor is not None + and skip_softmax_threshold_scale_factor > 0 + ) + + if kernel_fn is None: + gpu_arch = _get_gpu_arch(q.device) + kernel_fn = get_cute_dsl_fmha_kernel( + gpu_arch, + q.dtype, + o.dtype, + D, + is_causal, + is_persistent=False, # varlen always uses non-persistent + varlen=True, + enable_tvm_ffi=enable_tvm_ffi, + with_lse=lse is not None, + enable_skip_softmax=use_skip_softmax, + ) + + # Compute scale factors + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(D) + log2_e = math.log2(math.exp(1.0)) + scale_softmax = scale_q * scale_k * sm_scale + scale_softmax_log2 = scale_softmax * log2_e + scale_output = scale_v / scale_o + + # Max seq lengths for problem_size (prefer pre-computed values for CUDA graph compat) + if max_qo_len is None: + max_s_q = int((qo_indptr.cpu()[1:] - qo_indptr.cpu()[:-1]).max().item()) + else: + max_s_q = max_qo_len + if max_kv_len is None: + max_s_k = int((kv_indptr.cpu()[1:] - kv_indptr.cpu()[:-1]).max().item()) + else: + max_s_k = max_kv_len + + # problem_size: (B, max_s_q, s_lse, max_s_k, H_q, H_k, D, D_v) + s_lse = total_q # for variable length, s_lse = 
total tokens + problem_size = (batch_size, max_s_q, s_lse, max_s_k, H_q, H_k, D, D_v) + + # Skip-softmax threshold (convert scale_factor to log2 domain) + skip_threshold_log2 = None + if ( + skip_softmax_threshold_scale_factor is not None + and skip_softmax_threshold_scale_factor > 0 + ): + threshold = skip_softmax_threshold_scale_factor / max_s_k + skip_threshold_log2 = Float32(math.log2(threshold)) + + # Window sizes + ws_left = None if window_left == -1 else Int32(window_left) + ws_right = None if window_right == -1 else Int32(window_right) + if is_causal and ws_right is None: + ws_right = Int32(0) + + if enable_tvm_ffi: + # TVM-FFI: Pointer args accept data_ptr(), Tensor args accept torch.Tensor, + # no explicit stream (env stream). + # Kernel expects 4D pointers; unsqueeze to (1, total, H, D). + q_4d = q.unsqueeze(0) + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + o_4d = o.unsqueeze(0) + + kernel_fn( + q_4d.data_ptr(), + k_4d.data_ptr(), + v_4d.data_ptr(), + o_4d.data_ptr(), + problem_size, + qo_indptr.to(torch.int32), # cum_seqlen_q: Tensor arg + kv_indptr.to(torch.int32), # cum_seqlen_k: Tensor arg + lse.data_ptr() if lse is not None else None, + Float32(scale_softmax_log2), + Float32(scale_softmax), + Float32(scale_output), + skip_threshold_log2, + ws_left, + ws_right, + None, # skip_softmax_count + None, # total_softmax_count + q_4d, # q_tensor for env stream device detection + ) + else: + # CuTe native ABI: convert to cute tensors, pass iterators + explicit stream. + + # DSL FMHA kernel expects 4D tensor (B, S, H, D). 
+ q_4d = q.unsqueeze(0) + k_4d = k.unsqueeze(0) + v_4d = v.unsqueeze(0) + o_4d = o.unsqueeze(0) + + is_fp8_in = q.dtype == torch.float8_e4m3fn + is_fp8_out = o.dtype == torch.float8_e4m3fn + if is_fp8_in: + q_cute = from_dlpack( + q_4d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=3) + q_cute.element_type = cutlass.Float8E4M3FN + k_cute = from_dlpack( + k_4d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=3) + k_cute.element_type = cutlass.Float8E4M3FN + v_cute = from_dlpack( + v_4d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=3) + v_cute.element_type = cutlass.Float8E4M3FN + else: + q_cute = from_dlpack(q_4d, assumed_align=16).mark_layout_dynamic( + leading_dim=3 + ) + k_cute = from_dlpack(k_4d, assumed_align=16).mark_layout_dynamic( + leading_dim=3 + ) + v_cute = from_dlpack(v_4d, assumed_align=16).mark_layout_dynamic( + leading_dim=3 + ) + if is_fp8_out: + o_cute = from_dlpack( + o_4d.view(torch.int8), assumed_align=16 + ).mark_layout_dynamic(leading_dim=3) + o_cute.element_type = cutlass.Float8E4M3FN + else: + o_cute = from_dlpack(o_4d, assumed_align=16).mark_layout_dynamic( + leading_dim=3 + ) + + cum_seqlen_q_cute = from_dlpack( + qo_indptr.to(torch.int32), assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) + cum_seqlen_k_cute = from_dlpack( + kv_indptr.to(torch.int32), assumed_align=16 + ).mark_layout_dynamic(leading_dim=0) + + lse_iter = None + if lse is not None: + # TODO: lse's shape? 
+ lse_cute = from_dlpack(lse, assumed_align=16).mark_layout_dynamic( + leading_dim=2 + ) + lse_iter = lse_cute.iterator + + stream = cuda_driver.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel_fn( + q_cute.iterator, + k_cute.iterator, + v_cute.iterator, + o_cute.iterator, + problem_size, + cum_seqlen_q_cute, + cum_seqlen_k_cute, + lse_iter, + Float32(scale_softmax_log2), + Float32(scale_softmax), + Float32(scale_output), + skip_threshold_log2, + ws_left, + ws_right, + None, # skip_softmax_count + None, # total_softmax_count + None, # q_tensor (unused, for TVM-FFI env stream) + stream, + ) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 4ec6a29e7d..27a0d34b93 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3692,6 +3692,7 @@ def trtllm_ragged_attention_deepseek( skip_softmax_threshold_scale_factor: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + backend: str = "trtllm-gen", ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters @@ -3742,6 +3743,12 @@ def trtllm_ragged_attention_deepseek( output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]] lse : Optional[torch.Tensor] lse tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1]] + backend : str + Attention backend to use. "trtllm-gen" (default) or "cute-dsl". + When backend="cute-dsl", query/key/value/out tensors must be + front-padded with max_seq_len rows of valid GPU memory before + index 0 (see ``cute_dsl_fmha_ragged_prefill`` for details). + This requirement will be removed in the next MR. 
Returns ------- @@ -3761,8 +3768,7 @@ def trtllm_ragged_attention_deepseek( if enable_pdl is None: enable_pdl = device_support_pdl(query.device) - run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention - sm_count = get_device_sm_count(query.device) + # --- Output allocation (shared by all backends) --- if out is None: # FP8 inputs produce bfloat16 output by default (TRT-LLM kernels # do not support FP8 output for ragged attention) @@ -3791,6 +3797,8 @@ def trtllm_ragged_attention_deepseek( "FP8 output is not supported for trtllm_ragged_attention_deepseek; " "use bfloat16 or float16 for out." ) + + # --- LSE allocation (shared by all backends) --- if return_lse and lse is None: lse = torch.empty( query.shape[0], @@ -3799,37 +3807,96 @@ def trtllm_ragged_attention_deepseek( dtype=torch.float32, ) - if isinstance(bmm1_scale, torch.Tensor): - assert bmm1_scale.dtype == torch.float32 - bmm1_scale = bmm1_scale * log2e - if isinstance(bmm2_scale, torch.Tensor): - assert bmm2_scale.dtype == torch.float32 + if backend == "cute-dsl": + from .attention.cute_dsl.fmha import cute_dsl_fmha_ragged_prefill + + import warnings + + # TODO: remove this warning when PDL support added + # TODO: support PDL for cute-dsl backend + if enable_pdl: + warnings.warn( + "cute-dsl backend does not support PDL yet (enable_pdl ignored)", + stacklevel=2, + ) + if attention_sinks is not None: + warnings.warn( + "cute-dsl backend does not support attention_sinks (ignored)", + stacklevel=2, + ) + _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float8_e4m3fn) + assert query.dtype in _SUPPORTED_DTYPES, ( + f"cute-dsl backend only supports {_SUPPORTED_DTYPES}, got {query.dtype}" + ) + # TODO: support device tensor scales to avoid D2H sync overhead + assert not isinstance(bmm1_scale, torch.Tensor), ( + "cute-dsl backend does not support device tensor bmm1_scale" + ) + assert not isinstance(bmm2_scale, torch.Tensor), ( + "cute-dsl backend does not support device tensor bmm2_scale" + ) + 
_bmm1 = bmm1_scale + _bmm2 = bmm2_scale + + # bmm1_scale = scale_q * scale_k * sm_scale (already fused by caller) + # bmm2_scale = scale_v + # Pass the fused value as sm_scale with scale_q=scale_k=1.0 + cute_dsl_fmha_ragged_prefill( + q=query, + k=key, + v=value, + o=out, + qo_indptr=cum_seq_lens_q, + kv_indptr=cum_seq_lens_kv, + is_causal=is_causal, + sm_scale=_bmm1, + window_left=window_left, + lse=lse if return_lse else None, + scale_q=1.0, + scale_k=1.0, + scale_v=_bmm2, + scale_o=1.0, + max_qo_len=max_q_len, + max_kv_len=max_kv_len, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + ) + else: + # --- trtllm-gen backend --- + run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention + sm_count = get_device_sm_count(query.device) + + if isinstance(bmm1_scale, torch.Tensor): + assert bmm1_scale.dtype == torch.float32 + bmm1_scale = bmm1_scale * log2e + if isinstance(bmm2_scale, torch.Tensor): + assert bmm2_scale.dtype == torch.float32 + + workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() + run_func( + out, + query, + key, + value, + workspace_buffer, + seq_lens, + max_q_len, + max_kv_len, + bmm1_scale, + bmm2_scale, + o_sf_scale, + batch_size, + window_left, + cum_seq_lens_q, + cum_seq_lens_kv, + sm_count, + enable_pdl, + is_causal, + workspace_size, + attention_sinks, + skip_softmax_threshold_scale_factor, + lse, + ) - workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() - run_func( - out, - query, - key, - value, - workspace_buffer, - seq_lens, - max_q_len, - max_kv_len, - bmm1_scale, - bmm2_scale, - o_sf_scale, - batch_size, - window_left, - cum_seq_lens_q, - cum_seq_lens_kv, - sm_count, - enable_pdl, - is_causal, - workspace_size, - attention_sinks, - skip_softmax_threshold_scale_factor, - lse, - ) if return_lse: assert lse is not None, ( "lse assumed not None beyond this point when return_lse is True" diff --git a/tests/attention/test_trtllm_gen_attention.py 
b/tests/attention/test_trtllm_gen_attention.py index 186f1e5072..43af25f485 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1552,6 +1552,7 @@ def test_trtllm_batch_decode_long_sequence_length( ) +@pytest.mark.parametrize("backend", ["trtllm-native", "cute-dsl"]) @pytest.mark.parametrize( "mla_dimensions", [deepseek_mla_dimensions, smaller_mla_dimensions] ) @@ -1563,6 +1564,7 @@ def test_trtllm_batch_decode_long_sequence_length( @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_gen_prefill( + backend: str, mla_dimensions: MLAHeadDimensions, batch_size: int, s_qo: int, @@ -1578,8 +1580,12 @@ def test_trtllm_gen_prefill( if s_qo > s_kv: pytest.skip("s_qo > s_kv, skipping test as causal") - num_qo_heads = num_kv_heads * head_grp_size head_dim_qk = mla_dimensions.qk_nope_head_dim + mla_dimensions.qk_rope_head_dim + if backend == "cute-dsl": + if head_dim_qk == 192: + pytest.skip("cute-dsl does not support bf16 with head_dim=192") + + num_qo_heads = num_kv_heads * head_grp_size head_dim_vo = mla_dimensions.v_head_dim seed = 0 @@ -1597,20 +1603,47 @@ def test_trtllm_gen_prefill( cumsum_s_qo = int(torch.sum(actual_seq_lens_q).item()) cumsum_s_kv = int(torch.sum(actual_seq_lens_kv).item()) - q = torch.randn( - cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 - ) - - k_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_qk), - device=device, - dtype=torch.bfloat16, - ) - v_cache = torch.randn( - (cumsum_s_kv, num_kv_heads, head_dim_vo), - device=device, - dtype=torch.bfloat16, - ) + # DSL FMHA varlen kernel uses negative pointer offsets, so tensors need + # front-padding of max_s elements to ensure valid GPU memory before data. 
+ if backend == "cute-dsl": + q_full = torch.randn( + s_qo + cumsum_s_qo, + num_qo_heads, + head_dim_qk, + device=device, + dtype=torch.bfloat16, + ) + q = q_full[s_qo:] + k_full = torch.randn( + s_kv + cumsum_s_kv, + num_kv_heads, + head_dim_qk, + device=device, + dtype=torch.bfloat16, + ) + k_cache = k_full[s_kv:] + v_full = torch.randn( + s_kv + cumsum_s_kv, + num_kv_heads, + head_dim_vo, + device=device, + dtype=torch.bfloat16, + ) + v_cache = v_full[s_kv:] + else: + q = torch.randn( + cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16 + ) + k_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_qk), + device=device, + dtype=torch.bfloat16, + ) + v_cache = torch.randn( + (cumsum_s_kv, num_kv_heads, head_dim_vo), + device=device, + dtype=torch.bfloat16, + ) # Initialize scale scale = float(1.0 / (head_dim_qk**0.5)) @@ -1655,7 +1688,17 @@ def test_trtllm_gen_prefill( kv_data_type=torch.bfloat16, ) output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True) - output = torch.empty_like(output_ref) + if backend == "cute-dsl": + output_full = torch.empty( + s_qo + cumsum_s_qo, + num_qo_heads, + head_dim_vo, + device=device, + dtype=output_ref.dtype, + ) + output = output_full[s_qo:] + else: + output = torch.empty_like(output_ref) bmm1_scale = scale bmm2_scale = 1.0 @@ -1683,6 +1726,7 @@ def test_trtllm_gen_prefill( True, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, out=output, + backend=backend, ) torch.testing.assert_close( output_trtllm, @@ -1698,9 +1742,177 @@ def test_trtllm_gen_prefill( ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + if backend == "trtllm-native": + assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() +@pytest.mark.parametrize("backend", 
["cute-dsl"]) +@pytest.mark.parametrize( + "mla_dimensions", [deepseek_mla_dimensions, smaller_mla_dimensions] +) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("s_qo", [8192]) +@pytest.mark.parametrize("s_kv", [8192]) +@pytest.mark.parametrize("num_kv_heads", [128]) +@pytest.mark.parametrize("head_grp_size", [1]) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("skips_softmax", [False, True]) +def test_trtllm_gen_prefill_fp8( + backend: str, + mla_dimensions: MLAHeadDimensions, + batch_size: int, + s_qo: int, + s_kv: int, + num_kv_heads: int, + head_grp_size: int, + causal: bool, + skips_softmax: bool, +) -> None: + """Test cute-dsl prefill with FP8 (e4m3) input, bf16 output.""" + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] != 10: + pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") + + head_dim_qk = mla_dimensions.qk_nope_head_dim + mla_dimensions.qk_rope_head_dim + head_dim_vo = mla_dimensions.v_head_dim + num_qo_heads = num_kv_heads * head_grp_size + + seed = 0 + torch.manual_seed(seed) + device = "cuda:0" + + actual_seq_lens_q = torch.randint( + 1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + actual_seq_lens_kv = torch.randint( + s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device + ) + cumsum_s_qo = int(torch.sum(actual_seq_lens_q).item()) + cumsum_s_kv = int(torch.sum(actual_seq_lens_kv).item()) + + # FP8 scales + scale_q, scale_k, scale_v = 0.05, 0.04, 0.06 + + # Generate in float32, quantize to FP8 with front-padding + q_f32 = ( + torch.randn( + s_qo + cumsum_s_qo, + num_qo_heads, + head_dim_qk, + dtype=torch.float32, + device=device, + ) + * 0.1 + ) + k_f32 = ( + torch.randn( + s_kv + cumsum_s_kv, + num_kv_heads, + head_dim_qk, + dtype=torch.float32, + device=device, + ) + * 0.1 + ) + v_f32 = ( + torch.randn( + s_kv + cumsum_s_kv, + num_kv_heads, + head_dim_vo, + 
dtype=torch.float32, + device=device, + ) + * 0.1 + ) + + q = (q_f32 / scale_q).to(torch.float8_e4m3fn)[s_qo:] + k_cache = (k_f32 / scale_k).to(torch.float8_e4m3fn)[s_kv:] + v_cache = (v_f32 / scale_v).to(torch.float8_e4m3fn)[s_kv:] + + # Reference: dequantize and run bf16 attention + q_bf16 = (q.float() * scale_q).to(torch.bfloat16) + k_bf16 = (k_cache.float() * scale_k).to(torch.bfloat16) + v_bf16 = (v_cache.float() * scale_v).to(torch.bfloat16) + + qo_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(actual_seq_lens_q.view(-1), dim=0), + ] + ).int() + kv_indptr = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), + ] + ).int() + + workspace_buffer, workspace_buffer_ref = create_workspace_buffers(device) + + # Reference via cutlass backend (bf16) + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer_ref, + kv_layout="NHD", + backend="cutlass", + ) + wrapper.plan( + qo_indptr, + kv_indptr, + num_qo_heads, + num_kv_heads, + head_dim_qk, + head_dim_vo=head_dim_vo, + causal=causal, + sm_scale=1.0 / (head_dim_qk**0.5), + q_data_type=torch.bfloat16, + kv_data_type=torch.bfloat16, + ) + output_ref, _ = wrapper.run(q_bf16, k_bf16, v_bf16, return_lse=True) + + # Output with front-padding + output_full = torch.empty( + s_qo + cumsum_s_qo, + num_qo_heads, + head_dim_vo, + device=device, + dtype=torch.bfloat16, + ) + output = output_full[s_qo:] + + scale = 1.0 / (head_dim_qk**0.5) + bmm1_scale = scale_q * scale_k * scale + bmm2_scale = scale_v + + # Using a tiny threshold should give the same result as normal attention. 
+ skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None + + output_fp8, _ = flashinfer.prefill.trtllm_ragged_attention_deepseek( + q, + k_cache, + v_cache, + workspace_buffer, + actual_seq_lens_kv, + s_qo, + s_kv, + bmm1_scale, + bmm2_scale, + -1, + batch_size, + -1, + qo_indptr, + kv_indptr, + False, + causal, + True, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + out=output, + backend=backend, + ) + + torch.testing.assert_close(output_fp8, output_ref, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize("backend", ["trtllm-native", "cute-dsl"]) @pytest.mark.parametrize( "mla_dimensions", [deepseek_mla_dimensions, smaller_mla_dimensions] ) @@ -1712,6 +1924,7 @@ def test_trtllm_gen_prefill( @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_gen_prefill_bs1( + backend: str, mla_dimensions: MLAHeadDimensions, batch_size: int, s_qo: int, @@ -1722,6 +1935,7 @@ def test_trtllm_gen_prefill_bs1( skips_softmax: bool, ): test_trtllm_gen_prefill( + backend, mla_dimensions, batch_size, s_qo, diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py index fe83c45cc9..6dabf1884b 100644 --- a/tests/test_artifacts.py +++ b/tests/test_artifacts.py @@ -344,6 +344,28 @@ def test_get_subdir_file_list(monkeypatch, tmp_path): responses.GET, deepgemm_checksums_url, body=checksums_deepgemm, status=200 ) + # Mock DSL_FMHA checksums + directory index for the host cpu_arch. + # Pin to x86_64 so the test is deterministic regardless of the runner arch. + monkeypatch.setattr(artifacts, "_get_host_cpu_arch", lambda: "x86_64") + checksums_dsl_fmha = "aabbccdd11223344 cute_dsl_fmha_bf16_h128_causal_nonpersistent_varlen_tvmffi.so\n" + # Minimal directory index: an empty HTML page with no cubin/header hrefs. + # This avoids 404 retry overhead while still exercising the code path. + empty_dir_index = '
<html><body><a href="../">../</a></body></html>
' + for sm_arch in artifact_paths.DSL_FMHA_ARCHS: + subdir = safe_urljoin(artifact_paths.DSL_FMHA, f"x86_64/{sm_arch}/") + responses.add( + responses.GET, + safe_urljoin(test_cubin_repository, safe_urljoin(subdir, "checksums.txt")), + body=checksums_dsl_fmha, + status=200, + ) + responses.add( + responses.GET, + safe_urljoin(test_cubin_repository, subdir), + body=empty_dir_index, + status=200, + ) + cubin_files = list(get_subdir_file_list()) # Extract just the file paths from the (path, checksum) tuples