
Commit bb873d2

limin2021, claude, and yzh119 authored
feat: Integrate CuTe DSL FMHA prefill kernels by loading cubin (#3039)
## 📌 Description

feat: Integrate CuTe DSL FMHA cubin kernels into FlashInfer prefill backend

**Summary**

- Integrate pre-compiled CuTe DSL FMHA kernels (Blackwell SM100/SM103/SM110) into FlashInfer's prefill attention backend.
- Load AOT-compiled .so cubins from the NVIDIA artifactory at runtime; no JIT compilation is needed.
- Route through the trtllm_ragged_attention_deepseek() API with backend="cute-dsl".

**Key features**

- Dtype support: FP16, BF16, and FP8 (E4M3) input, with mixed-precision output (E4M3→BF16).
- Head dimensions: 32, 64, 128, 192 (192 for FP8 only).
- Varlen ragged prefill: variable-length sequences via cumulative-seqlen tensors.
- TVM-FFI ABI: all variants use TVM-FFI for kernel invocation.
- Skip-softmax sparsity: optional skip-softmax optimization for sparse attention.
- LSE output: optional log-sum-exp output for numerically stable multi-pass attention.
- Causal & non-causal masking: both modes supported (all varlen variants use non-persistent scheduling).
- Multi-arch cubin loading: per-CPU-arch (x86_64/aarch64) and per-SM-arch artifact paths.
- Checksum verification: SHA256 integrity check on downloaded .so files.

**Files changed**

- flashinfer/attention_dsl/cute_dsl/fmha.py — kernel loading, variant selection, ragged prefill entry point
- flashinfer/artifacts.py — artifact paths and checksums for DSL FMHA (x86_64 + aarch64 layout)
- flashinfer/prefill.py — trtllm_ragged_attention_deepseek() cute-dsl backend integration

**Test plan**

- `test_trtllm_gen_attention.py::test_trtllm_gen_prefill -k "cute-dsl"` passes.
- Benchmark via `bench_cute_dsl_ragged.sh` on target hardware.
- Verify cubin download and checksum verification on a clean install.

**Performance**

Setup: B200 (sm_100a), causal, H_q=H_k=128, measured with the FlashInfer benchmark harness (CUDA Graph, CUPTI).

FP8 e4m3 (D=192):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
|--------------------|---------------|--------------------|---------------------|---------|
| 1×8K×8K | 1.521 | 1.619 | 1808 / 1698 | **+6.4%** |
| 1×8K×32K | 8.466 | 9.451 | 2273 / 2036 | **+11.6%** |
| 1×8K×64K | 17.796 | 19.869 | 2317 / 2075 | **+11.7%** |
| 4×512×82K | 6.397 | 7.286 | 2142 / 1880 | **+13.9%** |
| 4×1K×82K | 12.285 | 13.834 | 2224 / 1975 | **+12.6%** |

FP8 e4m3 (D=128):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
|--------------------|---------------|--------------------|---------------------|---------|
| 1×8K×8K | 1.484 | 1.560 | 1481 / 1410 | **+5.1%** |
| 1×8K×32K | 7.666 | 8.998 | 2008 / 1711 | **+17.4%** |
| 1×8K×64K | 16.074 | 18.606 | 2052 / 1773 | **+15.8%** |
| 4×512×82K | 5.735 | 6.460 | 1911 / 1697 | **+12.6%** |
| 4×1K×82K | 11.066 | 12.451 | 1975 / 1755 | **+12.5%** |

BF16 (D=128):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
|--------------------|---------------|--------------------|---------------------|---------|
| 1×8K×8K | 1.737 | 1.764 | 1266 / 1247 | **+1.6%** |
| 1×8K×32K | 10.094 | 10.992 | 1525 / 1400 | **+8.9%** |
| 1×8K×64K | 21.745 | 23.000 | 1517 / 1434 | **+5.8%** |
| 4×512×82K | 8.457 | 8.513 | 1296 / 1288 | **+0.7%** |
| 4×1K×82K | 15.773 | 16.052 | 1385 / 1361 | **+1.8%** |

**TODO**

1. Support scalar as tensor dtype.
2. Support PDL (programmatic dependent launch).
3. Remove front-padding for q/k/v/o tensors.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent a265b4e commit bb873d2

10 files changed

Lines changed: 1096 additions & 66 deletions
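Before the per-file diffs, here is a minimal end-to-end sketch of driving the new backend through `trtllm_ragged_attention_deepseek`. The argument list mirrors the benchmark integration in `benchmarks/routines/attention.py` below; the concrete shapes, the workspace size, and the `result[0]` unpacking are illustrative assumptions, not a documented contract.

```python
import torch

import flashinfer

device = "cuda"  # requires a Blackwell GPU (SM100/SM103/SM110)
dtype = torch.bfloat16

# Hypothetical sizes for illustration; supported head dims are 32/64/128/192.
batch_size, max_q_len, max_kv_len = 2, 512, 2048
num_heads, head_dim = 16, 128

# Per-request lengths and the cumulative-seqlen (indptr) tensors that
# delimit each request's rows in the ragged (packed) layout.
seq_lens_q = torch.full((batch_size,), max_q_len, dtype=torch.int32, device=device)
seq_lens_kv = torch.full((batch_size,), max_kv_len, dtype=torch.int32, device=device)
zero = torch.zeros(1, dtype=torch.int32, device=device)
cum_seq_lens_q = torch.cat([zero, torch.cumsum(seq_lens_q, 0).int()])
cum_seq_lens_kv = torch.cat([zero, torch.cumsum(seq_lens_kv, 0).int()])
total_q, total_kv = int(cum_seq_lens_q[-1]), int(cum_seq_lens_kv[-1])

# Front-pad Q/K/V/O by max_s rows: the varlen kernel applies a negative
# pointer offset (-max_s * H * D), so valid memory must exist before the
# data start (TODO item 3 above aims to remove this requirement).
q = torch.randn(max_q_len + total_q, num_heads, head_dim, dtype=dtype, device=device)[max_q_len:]
k = torch.randn(max_kv_len + total_kv, num_heads, head_dim, dtype=dtype, device=device)[max_kv_len:]
v = torch.randn(max_kv_len + total_kv, num_heads, head_dim, dtype=dtype, device=device)[max_kv_len:]
out = torch.empty(max_q_len + total_q, num_heads, head_dim, dtype=dtype, device=device)[max_q_len:]

workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)  # size is a guess

result = flashinfer.prefill.trtllm_ragged_attention_deepseek(
    query=q,
    key=k,
    value=v,
    workspace_buffer=workspace,
    seq_lens=seq_lens_kv,
    max_q_len=max_q_len,
    max_kv_len=max_kv_len,
    bmm1_scale=head_dim**-0.5,  # q_scale * k_scale * softmax scale
    bmm2_scale=1.0,  # v_scale
    o_sf_scale=-1,
    batch_size=batch_size,
    window_left=-1,
    cum_seq_lens_q=cum_seq_lens_q,
    cum_seq_lens_kv=cum_seq_lens_kv,
    enable_pdl=False,
    is_causal=True,
    return_lse=True,
    out=out,
    backend="cute-dsl",
)
attn_out = result[0]  # the benchmark also indexes [0] for the attention output
```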


benchmarks/routines/attention.py

Lines changed: 60 additions & 9 deletions
```diff
@@ -1646,18 +1646,40 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):

     cumsum_s_qo = torch.sum(actual_seq_lens_q)
     cumsum_s_kv = torch.sum(actual_seq_lens_kv)
-    q = torch.randn(
-        cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype
+
+    # Front-padding for cute-dsl varlen kernel: the persistent varlen kernel
+    # applies a negative pointer offset (-max_s * H * D), so there must be
+    # valid GPU memory before the data start.
+    front_pad_q = s_qo if "cute-dsl" in backends else 0
+    front_pad_kv = s_kv if "cute-dsl" in backends else 0
+
+    q_full = torch.randn(
+        front_pad_q + cumsum_s_qo,
+        num_qo_heads,
+        head_dim_qk,
+        device=device,
+        dtype=q_init_dtype,
     )
+    q = q_full[front_pad_q:]
     if args.verbose >= 2:
         print(f"[VVERBOSE] {q.shape = }")

-    k = torch.randn(
-        cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype
+    k_full = torch.randn(
+        front_pad_kv + cumsum_s_kv,
+        num_kv_heads,
+        head_dim_qk,
+        device=device,
+        dtype=kv_init_dtype,
     )
-    v = torch.randn(
-        cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype
+    k = k_full[front_pad_kv:]
+    v_full = torch.randn(
+        front_pad_kv + cumsum_s_kv,
+        num_kv_heads,
+        head_dim_vo,
+        device=device,
+        dtype=kv_init_dtype,
     )
+    v = v_full[front_pad_kv:]

     block_tables = None

@@ -1815,14 +1837,18 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         v = (v / v_scale).to(kv_dtype)

     trtllm_out = None
-    if "trtllm-native" in backends:
-        trtllm_out = torch.empty(
-            q.shape[0],
+    if "trtllm-native" in backends or "cute-dsl" in backends:
+        # cute-dsl varlen kernel uses negative pointer offsets on output,
+        # so front-pad like Q/K/V.
+        out_pad = front_pad_q if "cute-dsl" in backends else 0
+        trtllm_out_full = torch.empty(
+            out_pad + q.shape[0],
             q.shape[1],
             v.shape[2],
             device=q.device,
             dtype=out_dtype,
         )
+        trtllm_out = trtllm_out_full[out_pad:]

    def run_backend_wrapper(
        backend,
@@ -1843,6 +1869,31 @@ def run_backend_wrapper(
    ):
        if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
            return backend_wrappers[backend].run_return_lse(q, k, v)[0]
+        elif backend == "cute-dsl":
+            _q_scale = q_scale if q_scale is not None else 1.0
+            _k_scale = k_scale if k_scale is not None else 1.0
+            _v_scale = v_scale if v_scale is not None else 1.0
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=workspace_buffer,
+                seq_lens=actual_seq_lens_kv_device,
+                max_q_len=s_qo,
+                max_kv_len=s_kv,
+                bmm1_scale=_q_scale * _k_scale * scale,
+                bmm2_scale=_v_scale,
+                o_sf_scale=-1,
+                batch_size=batch_size,
+                window_left=-1,
+                cum_seq_lens_q=qo_indptr,
+                cum_seq_lens_kv=kv_indptr,
+                enable_pdl=False,
+                is_causal=causal,
+                return_lse=True,
+                out=trtllm_out,
+                backend="cute-dsl",
+            )[0]
        elif backend == "cudnn":
            # cuDNN uses wrapper API
            return backend_wrappers[backend].run(q, k, v)
```
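The `q_full[front_pad_q:]` slices above work because a PyTorch slice is a view: it shares the parent's storage, so the padding rows stay allocated even though only the tail is handed to the kernel. A small CPU-only sketch of the invariant (names are illustrative):

```python
import torch

max_s, num_heads, head_dim = 16, 8, 64
full = torch.randn(max_s + 100, num_heads, head_dim)
view = full[max_s:]  # a view, not a copy: rows 0..max_s-1 remain valid memory

offset = view.data_ptr() - full.data_ptr()
assert offset == max_s * num_heads * head_dim * full.element_size()
# A kernel can therefore subtract up to max_s * H * D elements from the data
# pointer without reading unallocated memory.
```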

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -335,8 +335,22 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": ["fa2", "cudnn", "cudnn-native"],
         "8.9": ["fa2", "cudnn", "cudnn-native"],
         "9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
-        "10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
-        "10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
+        "10.0": [
+            "fa2",
+            "cudnn",
+            "cudnn-native",
+            "cutlass",
+            "cute-dsl",
+            "trtllm-native",
+        ],
+        "10.3": [
+            "fa2",
+            "cudnn",
+            "cudnn-native",
+            "cutlass",
+            "cute-dsl",
+            "trtllm-native",
+        ],
         "12.0": ["fa2", "cudnn", "cudnn-native"],
         "12.1": ["fa2", "cudnn", "cudnn-native"],
     },
```
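The table above keys supported-backend lists by compute capability strings such as "10.0". Below is a sketch of how such a table can be consulted at runtime; the helper name is hypothetical and the harness's actual lookup may differ:

```python
import torch

def supported_backends(table: dict[str, list[str]]) -> list[str]:
    # torch.cuda.get_device_capability() returns e.g. (10, 0) on B200,
    # matching the "10.0" key in the table above.
    major, minor = torch.cuda.get_device_capability()
    return table.get(f"{major}.{minor}", [])
```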

flashinfer/artifacts.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -145,6 +145,8 @@ class ArtifactPath:
     CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
     # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py
     DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"
+    DSL_FMHA: str = "c770c91cb0d991b7828fc85d2253a62f0d356b6c/fmha/cute-dsl/"
+    DSL_FMHA_ARCHS: tuple[str, ...] = ("sm_100a", "sm_103a", "sm_110a")


 class CheckSumHash:
@@ -164,11 +166,32 @@ class CheckSumHash:
     TRTLLM_GEN_GEMM: str = (
         "64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7"
     )
+    # SHA256 of the checksums.txt manifest file per cpu-arch/sm-arch,
+    # NOT hashes of individual kernel .so files.
+    DSL_FMHA_CHECKSUMS: dict[str, dict[str, str]] = {
+        "x86_64": {
+            "sm_100a": "9533536698cdc256d897fffb3114de317076654ff8630ff283d850cc3dc96d86",
+            "sm_103a": "927e1954f1d45b0ee876f139084e4facdfcc87e86f4d30cb92d5c33698d4c2d6",
+            "sm_110a": "277b1dceaab2081e3def37cf997280a3f2c3ac515d22b80be141253c0278b8b5",
+        },
+        "aarch64": {
+            "sm_100a": "b48ed0bcc9bad4afd33e0784c8c9eb9e13e782afe197816b1d0747b11759493e",
+            "sm_103a": "bace619a560f3ce52ad6ba105fffb8ea8629fe57885a90892c9e15a7122467e1",
+            "sm_110a": "d8369bcfa443bfd791cd014e3b030d378f00a975db8278eebd5b2fb529e3257d",
+        },
+    }
     map_checksums: dict[str, str] = {
         safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
         safe_urljoin(ArtifactPath.TRTLLM_GEN_BMM, "checksums.txt"): TRTLLM_GEN_BMM,
         safe_urljoin(ArtifactPath.DEEPGEMM, "checksums.txt"): DEEPGEMM,
         safe_urljoin(ArtifactPath.TRTLLM_GEN_GEMM, "checksums.txt"): TRTLLM_GEN_GEMM,
+        **{
+            safe_urljoin(
+                ArtifactPath.DSL_FMHA, f"{cpu_arch}/{sm_arch}/checksums.txt"
+            ): sha
+            for cpu_arch, sm_checksums in DSL_FMHA_CHECKSUMS.items()
+            for sm_arch, sha in sm_checksums.items()
+        },
     }


@@ -191,14 +214,30 @@ def get_checksums(subdirs):
     return checksums


+def _get_host_cpu_arch() -> str:
+    """Return CPU architecture string matching artifactory layout."""
+    import platform
+
+    machine = platform.machine()
+    if machine in ("aarch64", "arm64"):
+        return "aarch64"
+    return "x86_64"
+
+
 def get_subdir_file_list() -> Generator[tuple[str, str], None, None]:
     base = FLASHINFER_CUBINS_REPOSITORY
+    cpu_arch = _get_host_cpu_arch()

     cubin_dirs = [
         ArtifactPath.TRTLLM_GEN_FMHA,
         ArtifactPath.TRTLLM_GEN_BMM,
         ArtifactPath.TRTLLM_GEN_GEMM,
         ArtifactPath.DEEPGEMM,
+        # DSL FMHA: per cpu-arch and sm-arch subdirectories
+        *(
+            safe_urljoin(ArtifactPath.DSL_FMHA, f"{cpu_arch}/{arch}/")
+            for arch in ArtifactPath.DSL_FMHA_ARCHS
+        ),
     ]

     # Get checksums of all files
```
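Note that the pinned hashes cover the per-arch `checksums.txt` manifests, not the individual kernel `.so` files. Here is a minimal sketch of the SHA256 check these hashes enable; the download logic itself is outside this diff, and the local `path` here is assumed to exist:

```python
import hashlib

def sha256_of(path: str) -> str:
    """Stream a file through SHA256 in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = CheckSumHash.DSL_FMHA_CHECKSUMS["x86_64"]["sm_100a"]
if sha256_of("checksums.txt") != expected:
    raise RuntimeError("DSL FMHA manifest failed integrity check")
```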

flashinfer/attention/__init__.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from ._core import *  # noqa: F401,F403
+from ._core import BatchAttention, BatchAttentionWithAttentionSinkWrapper
+
+__all__ = [
+    "BatchAttention",
+    "BatchAttentionWithAttentionSinkWrapper",
+]
```
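Re-exporting the two wrappers explicitly alongside the wildcard keeps the public import path stable after the move into a package; consumer code is unchanged:

```python
from flashinfer.attention import (
    BatchAttention,
    BatchAttentionWithAttentionSinkWrapper,
)
```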
Lines changed: 6 additions & 6 deletions
```diff
@@ -20,19 +20,19 @@

 import torch

-from .api_logging import flashinfer_api
-from .jit import gen_batch_attention_module
-from .utils import (
+from ..api_logging import flashinfer_api
+from ..jit import gen_batch_attention_module
+from ..utils import (
     MaskMode,
     PosEncodingMode,
     TensorLayout,
     _check_kv_layout,
     _unpack_paged_kv_cache,
     determine_attention_backend,
 )
-from .prefill import BatchPrefillWithPagedKVCacheWrapper
-from .jit.attention.variants import attention_sink_decl
-from .jit.utils import filename_safe_dtype_map
+from ..prefill import BatchPrefillWithPagedKVCacheWrapper
+from ..jit.attention.variants import attention_sink_decl
+from ..jit.utils import filename_safe_dtype_map


 @functools.cache
```
Lines changed: 37 additions & 0 deletions
```diff
@@ -0,0 +1,37 @@
+# Copyright (c) 2026 by FlashInfer team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+CuTe DSL Attention Kernels (Cubin Distribution)
+================================================
+
+Pre-compiled FMHA kernels loaded via ExternalBinaryModule.
+"""
+
+from flashinfer.cute_dsl.utils import is_cute_dsl_available
+
+if is_cute_dsl_available():
+    from .fmha import (
+        get_cute_dsl_fmha_kernel,
+        cute_dsl_fmha_ragged_prefill,
+    )
+
+    __all__ = [
+        "is_cute_dsl_available",
+        "get_cute_dsl_fmha_kernel",
+        "cute_dsl_fmha_ragged_prefill",
+    ]
+else:
+    __all__ = [
+        "is_cute_dsl_available",
+    ]
```
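Since `__all__` only exposes the kernel entry points when CuTe DSL is importable, consumers should guard on `is_cute_dsl_available()` before touching them. A sketch, assuming the module path `flashinfer.attention_dsl.cute_dsl` given in the Files changed list above:

```python
from flashinfer.attention_dsl import cute_dsl

if cute_dsl.is_cute_dsl_available():
    kernel_fn = cute_dsl.cute_dsl_fmha_ragged_prefill  # defined only when CuTe DSL is present
else:
    kernel_fn = None  # fall back to another prefill backend
```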
