Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3010812
feat: Integrate CuTe DSL FMHA cubin kernels into prefill backend
limin2021 Apr 12, 2026
6fd8c0b
feat: Add FP8 support and fix TMA padding for DSL FMHA prefill
limin2021 Apr 12, 2026
19107f6
feat: Add TVM-FFI support for DSL FMHA prefill kernels
limin2021 Apr 12, 2026
70a6186
refactor: Consolidate DSL FMHA prefill tests, add ragged shape coverage
limin2021 Apr 12, 2026
9147885
fix: Use front-padding for ragged varlen tensors (match DSL example)
limin2021 Apr 12, 2026
27cd0aa
feat: Add cute-dsl to benchmark framework, fix CUDA graph compatibility
limin2021 Apr 12, 2026
7162590
fix: Preload DSL kernel in plan(), fix output dtype and CUDA graph co…
limin2021 Apr 13, 2026
f574404
feat: Route cute-dsl through trtllm_ragged_attention_deepseek, add LS…
limin2021 Apr 13, 2026
5ddbd31
fix: Set is_persistent=not is_causal in cute_dsl_fmha_prefill to matc…
limin2021 Apr 13, 2026
8a49143
refactor: Remove cute-dsl from single_prefill and BatchPrefillWithRag…
limin2021 Apr 13, 2026
05fdbdf
refactor: Remove enable_tvm_ffi=False tests, fix to tvm_ffi=True only
limin2021 Apr 13, 2026
7deedef
feat: Support skip-softmax sparsity in cute-dsl FMHA backend
limin2021 Apr 13, 2026
df72fc9
fix: Misc cleanups — revert unnecessary prefill.py changes, fix copyr…
limin2021 Apr 13, 2026
f55b71b
fix: Use os.path.join for artifact path to avoid absolute path when p…
limin2021 Apr 13, 2026
718586e
feat: Wire up DSL FMHA cubin loading from artifactory
limin2021 Apr 15, 2026
dcebe3c
update: DSL FMHA artifact path and checksums for multi-arch CI build
limin2021 Apr 15, 2026
44db80c
update: Add aarch64 checksums and arch-aware artifact paths for DSL FMHA
limin2021 Apr 16, 2026
df6f13e
refactor: Move attention.py into attention/ package, rename attention…
limin2021 Apr 16, 2026
61a8665
update: Simplify DSL FMHA test — FP8 only, update docstring
limin2021 Apr 16, 2026
4f78691
fix: Review feedback — dedup _get_cpu_arch, document front-padding re…
limin2021 Apr 16, 2026
6191a7a
fix: Add dtype validation for cute-dsl backend, remove duplicate head…
limin2021 Apr 16, 2026
16e8253
feat: Add FP8 prefill test, consolidate DSL FMHA tests
limin2021 Apr 16, 2026
14af985
fix: Rename backend "trtllm-native" to "trtllm-gen" in trtllm_ragged_…
limin2021 Apr 16, 2026
4a671bc
Merge origin/main into integrate_dsl_cubin_fmha
limin2021 Apr 20, 2026
ae77982
update: DSL FMHA artifact path and checksums to latest cubin release
limin2021 Apr 20, 2026
2207c68
test: Mock DSL_FMHA checksums + dir index in test_get_subdir_file_list
limin2021 Apr 20, 2026
fdfefa3
Merge branch 'main' into integrate_dsl_cubin_fmha
yzh119 Apr 21, 2026
e452cec
fix: Include gpu_arch in get_cute_dsl_fmha_kernel cache key
limin2021 Apr 21, 2026
0de6f44
Merge branch 'main' into integrate_dsl_cubin_fmha
yzh119 Apr 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,18 +1646,40 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):

cumsum_s_qo = torch.sum(actual_seq_lens_q)
cumsum_s_kv = torch.sum(actual_seq_lens_kv)
q = torch.randn(
cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=q_init_dtype

# Front-padding for cute-dsl varlen kernel: the persistent varlen kernel
# applies a negative pointer offset (-max_s * H * D), so there must be
# valid GPU memory before the data start.
front_pad_q = s_qo if "cute-dsl" in backends else 0
front_pad_kv = s_kv if "cute-dsl" in backends else 0

q_full = torch.randn(
front_pad_q + cumsum_s_qo,
num_qo_heads,
head_dim_qk,
device=device,
dtype=q_init_dtype,
)
q = q_full[front_pad_q:]
if args.verbose >= 2:
print(f"[VVERBOSE] {q.shape = }")

k = torch.randn(
cumsum_s_kv, num_kv_heads, head_dim_qk, device=device, dtype=kv_init_dtype
k_full = torch.randn(
front_pad_kv + cumsum_s_kv,
num_kv_heads,
head_dim_qk,
device=device,
dtype=kv_init_dtype,
)
v = torch.randn(
cumsum_s_kv, num_kv_heads, head_dim_vo, device=device, dtype=kv_init_dtype
k = k_full[front_pad_kv:]
v_full = torch.randn(
front_pad_kv + cumsum_s_kv,
num_kv_heads,
head_dim_vo,
device=device,
dtype=kv_init_dtype,
)
v = v_full[front_pad_kv:]

block_tables = None

Expand Down Expand Up @@ -1751,13 +1773,13 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
# Prepare wrappers
backend_wrappers = {}
for backend in backends:
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen", "cute-dsl"]:
backend_wrappers[backend] = (
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer,
"NHD",
use_cuda_graph=is_cuda_graph_compatible
if backend != "fa2"
if backend not in ["fa2"]
else False,
qo_indptr_buf=qo_indptr,
kv_indptr_buf=kv_indptr,
Expand Down Expand Up @@ -1843,6 +1865,8 @@ def run_backend_wrapper(
):
if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]:
return backend_wrappers[backend].run_return_lse(q, k, v)[0]
elif backend == "cute-dsl":
return backend_wrappers[backend].run(q, k, v)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
elif backend == "cudnn":
# cuDNN uses wrapper API
return backend_wrappers[backend].run(q, k, v)
Expand Down Expand Up @@ -1933,7 +1957,7 @@ def run_backend_wrapper(
repeat_iters=args.num_iters,
sleep_after_run=True,
enable_cupti=args.use_cupti,
use_cuda_graph=(is_cuda_graph_compatible and cur_backend != "fa2"),
use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2"]),
cold_l2_cache=True,
input_args=(
cur_backend,
Expand Down
20 changes: 18 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,24 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": ["fa2", "cudnn", "cudnn-native"],
"8.9": ["fa2", "cudnn", "cudnn-native"],
"9.0": ["fa2", "fa3", "cudnn", "cudnn-native"],
"10.0": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
"10.3": ["fa2", "cudnn", "cudnn-native", "cutlass", "trtllm-native"],
"10.0": [
"fa2",
"fa3",
"cudnn",
"cudnn-native",
"cutlass",
"cute-dsl",
"trtllm-native",
],
"10.3": [
"fa2",
"fa3",
"cudnn",
"cudnn-native",
"cutlass",
"cute-dsl",
"trtllm-native",
],
"12.0": ["fa2", "cudnn", "cudnn-native"],
"12.1": ["fa2", "cudnn", "cudnn-native"],
},
Expand Down
19 changes: 19 additions & 0 deletions flashinfer/attention_dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2025 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FlashInfer Attention DSL Module
================================

CuTe DSL attention kernel implementations (cubin distribution).

This package holds attention kernels authored in the CuTe DSL and shipped
as pre-compiled cubin artifacts; the FMHA entry points live in the
``cute_dsl`` subpackage.
"""
39 changes: 39 additions & 0 deletions flashinfer/attention_dsl/cute_dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2025 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CuTe DSL Attention Kernels (Cubin Distribution)
================================================

Pre-compiled FMHA kernels loaded via ExternalBinaryModule.
"""

from flashinfer.cute_dsl.utils import is_cute_dsl_available

# The availability probe is always exported; the FMHA entry points are
# exported only when the CuTe DSL runtime is installed, so that importing
# this package never hard-fails on the optional dependency.
__all__ = ["is_cute_dsl_available"]

if is_cute_dsl_available():
    from .fmha import (
        get_cute_dsl_fmha_kernel,
        cute_dsl_fmha_prefill,
        cute_dsl_fmha_ragged_prefill,
    )

    __all__ = __all__ + [
        "get_cute_dsl_fmha_kernel",
        "cute_dsl_fmha_prefill",
        "cute_dsl_fmha_ragged_prefill",
    ]
Loading
Loading