Skip to content
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a50f52e
[cutedsl] Add modular attention package with FMHA and MLA kernels
pgera Mar 16, 2026
a501c5e
Remove monolithic kernels from original PR, separate benchmarks
pgera Mar 17, 2026
24b92d9
Fix causal mask boundary + PV accumulate bugs in modular attention ke…
pgera Mar 17, 2026
4f37d5a
Fix attention sink precision bugs: wrapper dtype mismatch and M_D_upd…
pgera Mar 17, 2026
a66b7e4
Fix sliding window mask, sink M_D_update scaling, and add comprehensi…
pgera Mar 17, 2026
93111ab
Remove MLA/decode files from prefill PR
pgera Mar 17, 2026
b82b7d7
Replace patch/pipeline.py with upstream cutlass.pipeline
pgera Mar 18, 2026
e950529
Address review feedback: add @cute.jit to get_trip_count, remove unus…
pgera Mar 18, 2026
6b38a2d
Simplify RESIDUAL_MASK check and remove redundant qo_head_idx guard (…
pgera Mar 18, 2026
3848236
Unify variant runtime data into single params mechanism with score_mo…
pgera Mar 19, 2026
776eac9
Update copyright year to 2026 for new CuTe DSL attention files
pgera Mar 19, 2026
8e83012
Update copyright year to 2026 for benchmark file
pgera Mar 19, 2026
1296216
Standardize license headers to abbreviated SPDX BSD-3-Clause
pgera Mar 19, 2026
8892d98
Add validation and robustness improvements to CuTe DSL attention (AI-…
pgera Mar 19, 2026
6473af0
Fix SigmoidAttention bias not converted to log-base-2 domain (AI-assi…
pgera Mar 24, 2026
06300de
Add transform_logits support: coordinate-free API, composable with sc…
pgera Apr 1, 2026
16d31a5
Port monolithic MLA decode kernel to modular role-based architecture …
pgera Apr 3, 2026
315380c
Port FP8 MLA decode to modular attention framework (AI-assisted)
pgera Apr 6, 2026
84adb73
Move modular FMHA prefill test to tests/attention/ (AI-assisted)
pgera Apr 6, 2026
8892f81
Add FP8 tests to modular MLA decode test suite (AI-assisted)
pgera Apr 6, 2026
d1ea471
Fix FP16 MLA decode 2x attenuation regression from FP8 port (AI-assis…
pgera Apr 6, 2026
0fba944
Delete monolithic MLA decode kernels, wire modular in place (AI-assis…
pgera Apr 7, 2026
0cf787e
Remove flashinfer.mla.cute_dsl shim, import modular directly (AI-assi…
pgera Apr 7, 2026
471f5f0
Wire cute-dsl backend into BatchPrefillWithRaggedKVCacheWrapper (AI-a…
pgera Apr 7, 2026
f9d7727
Merge branch 'main' into cutedsl-fmha-prefill
yzh119 Apr 7, 2026
05b0361
Address PR review: TVM-FFI compile path, @flashinfer_api, benchmark g…
pgera Apr 7, 2026
eb1068d
Merge branch 'main' into cutedsl-fmha-prefill
yzh119 Apr 7, 2026
4a3ceaa
Fix pre-commit: ruff check/format, mypy barrier assertions (AI-assisted)
pgera Apr 8, 2026
f87b0b2
Fix mypy type errors across CuTe DSL attention package (AI-assisted)
pgera Apr 8, 2026
58537d1
Add KV cache shape validation and document sliding window limitation …
pgera Apr 8, 2026
aa48df0
Fix sliding window mask for qo_len != kv_len (AI-assisted)
pgera Apr 8, 2026
1f89f93
Merge remote-tracking branch 'origin/main' into cutedsl-fmha-prefill
pgera Apr 8, 2026
7f3a907
Add @flashinfer_api to BatchMLADecodeCuteDSLWrapper (AI-assisted)
pgera Apr 8, 2026
472c864
Align prefill and MLA wrapper patterns (AI-assisted)
pgera Apr 8, 2026
e697268
Standardize license headers to abbreviated SPDX BSD-3-Clause (2026)
pgera Apr 8, 2026
1d4e6c6
Add is_cute_dsl_available() guard to benchmark and FMHA test (AI-assi…
pgera Apr 8, 2026
e74065a
Merge branch 'main' into cutedsl-fmha-prefill
yzh119 Apr 8, 2026
57e45ce
Add attention variant support to MLA decode and unify caching (AI-ass…
pgera Apr 9, 2026
2972bb3
Fix cute-dsl prefill out= contract and address PR review comments (AI…
pgera Apr 9, 2026
876b6c2
Merge remote-tracking branch 'origin/main' into cutedsl-fmha-prefill
pgera Apr 9, 2026
b359788
Guard against cute-dsl backend in paged KV cache wrapper (AI-assisted)
pgera Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions benchmarks/bench_blackwell_attention_cutedsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import sys

import numpy as np
import torch

import flashinfer
from flashinfer.cute_dsl.utils import is_cute_dsl_available
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import is_sm100a_supported

# Bail out of the whole benchmark before touching any cute-dsl symbols when the
# optional nvidia-cutlass-dsl dependency is absent; exit code 0 so CI treats
# this as a skip, not a failure.
if not is_cute_dsl_available():
    print("Skipping: nvidia-cutlass-dsl package not installed")
    sys.exit(0)

# Deferred import: only valid once the availability guard above has passed.
from flashinfer.cute_dsl.attention import BatchPrefillCuteDSLWrapper


def bench_fmha_blackwell(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
):
    """Benchmark the CUTLASS FMHA path via BatchPrefillWithRaggedKVCacheWrapper.

    Builds a ragged batch of ``batch_size`` equal-length sequences, runs one
    warm-up call, times repeated runs, and prints achieved TFLOPs/s and GB/s.

    Args:
        batch_size: Number of sequences in the ragged batch.
        qkv_len: Length of every sequence (qo_len == kv_len here).
        num_heads: Number of query heads (== number of KV heads here).
        head_dim: Head dimension for Q/K and for V/O alike.
        causal: Whether a causal mask is applied.
        dtype: Torch dtype for Q/K/V (and the workspace buffer).
    """
    total_tokens = batch_size * qkv_len
    q = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")

    # Equal-length segments: offsets are just multiples of qkv_len.
    qo_segment_offsets = (
        torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    kv_segment_offsets = qo_segment_offsets.clone()

    # NOTE(review): the sibling cute-dsl benchmark allocates its workspace as
    # uint8; here the workspace inherits `dtype` — presumably the float
    # workspace convention of this wrapper, but worth confirming.
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, dtype=dtype, device="cuda"),
        kv_layout="NHD",
        backend="cutlass",
    )
    wrapper.plan(
        qo_segment_offsets,
        kv_segment_offsets,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    # One eager call both warms the kernel cache and yields `o` for I/O sizing.
    o = wrapper.run(q, k, v)
    timings = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )
    ms = np.median(timings)

    # 2 GEMMs x 2 flops/MAC -> factor 4 for dense attention; the causal
    # triangle halves the work -> factor 2.  Dividing flops by (ms * 1e9)
    # yields TFLOPs/s.
    flop_factor = 2 if causal else 4
    tflops = (
        batch_size * qkv_len * qkv_len * num_heads * head_dim * flop_factor / ms / 1e9
    )

    # Bytes moved for Q, K, V and O; bytes / (ms * 1e6) yields GB/s.
    total_bytes = sum(t.numel() * t.element_size() for t in (q, k, v, o))
    gbs = total_bytes / ms / 1e6

    print(
        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, "
        f"num_heads={num_heads}, head_dim={head_dim}, causal={causal}), "
        f"flops: {tflops:.3f} TFLOPs/s, io: {gbs:.3f} GB/s"
    )


def bench_fmha_cutedsl(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
    sm_scale=None,
):
    """Benchmark the CuTe DSL FMHA path via BatchPrefillCuteDSLWrapper.

    Mirrors :func:`bench_fmha_blackwell`: same ragged layout, warm-up call,
    timing loop, and TFLOPs/GB-per-second reporting.

    Args:
        batch_size: Number of sequences in the ragged batch.
        qkv_len: Length of every sequence (qo_len == kv_len here).
        num_heads: Number of query heads (== number of KV heads here).
        head_dim: Head dimension for Q/K and for V/O alike.
        causal: Whether a causal mask is applied.
        dtype: Torch dtype for Q/K/V.
        sm_scale: Softmax scale; defaults to 1/sqrt(head_dim).
    """
    if sm_scale is None:
        sm_scale = 1.0 / (head_dim**0.5)

    total_tokens = batch_size * qkv_len
    q = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    k = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")
    v = torch.randn(total_tokens, num_heads, head_dim, dtype=dtype, device="cuda")

    # Equal-length segments: indptrs are just multiples of qkv_len.
    qo_indptr = torch.arange(batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    kv_indptr = qo_indptr.clone()

    wrapper = BatchPrefillCuteDSLWrapper(
        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
    )
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    # One eager call both warms the kernel cache and yields `o` for I/O sizing.
    o = wrapper.run(q, k, v)
    timings = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )
    ms = np.median(timings)

    # 2 GEMMs x 2 flops/MAC -> factor 4 dense, halved to 2 under the causal
    # mask; flops / (ms * 1e9) yields TFLOPs/s.
    flop_factor = 2 if causal else 4
    tflops = (
        batch_size * qkv_len * qkv_len * num_heads * head_dim * flop_factor / ms / 1e9
    )

    # Bytes moved for Q, K, V and O; bytes / (ms * 1e6) yields GB/s.
    total_bytes = sum(t.numel() * t.element_size() for t in (q, k, v, o))
    gbs = total_bytes / ms / 1e6

    print(
        f"bench_fmha_cutedsl (batch_size={batch_size}, qkv_len={qkv_len}, "
        f"num_heads={num_heads}, head_dim={head_dim}, causal={causal}), "
        f"flops: {tflops:.3f} TFLOPs/s, io: {gbs:.3f} GB/s"
    )


if __name__ == "__main__":
if not is_sm100a_supported(torch.device("cuda")):
print("Skipping: requires SM100+")
sys.exit(0)

configs = [
Comment thread
pgera marked this conversation as resolved.
(128, 512, 32, 128, True, torch.bfloat16),
(64, 1024, 32, 128, True, torch.bfloat16),
(32, 2048, 32, 128, True, torch.bfloat16),
(16, 4096, 32, 128, True, torch.bfloat16),
(8, 8192, 32, 128, True, torch.bfloat16),
(4, 16384, 32, 128, True, torch.bfloat16),
(2, 32768, 32, 128, True, torch.bfloat16),
(1, 65536, 32, 128, True, torch.bfloat16),
]

print("=== CUTLASS (via BatchPrefillWithRaggedKVCacheWrapper) ===")
for cfg in configs:
bench_fmha_blackwell(*cfg)
print()
print("=== CuTe DSL (via BatchPrefillCuteDSLWrapper) ===")
for cfg in configs:
bench_fmha_cutedsl(*cfg)
79 changes: 79 additions & 0 deletions flashinfer/cute_dsl/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Modular attention kernels for CuTe DSL.

Kernels live at the top level of this package.
Building blocks (config, tmem_layout, roles, fusion, scheduler, wrappers) are
one level below in subdirectories.
"""

# Kernels
from .prefill import BlackwellFusedMultiHeadAttentionForward
from .mla_decode import BlackwellMultiLatentAttentionForward
from .mla_decode_fp8 import BlackwellMultiLatentAttentionForwardFP8

# Building blocks — FMHA prefill
from .config import AttentionConfig, AttentionFusion, HeadMapping, TileBounds
from .tmem_layout import TmemLayout
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .pipeline_topology import (
PipelineEdge,
PipelineType,
PipelineTopology,
make_prefill_topology,
make_mla_topology,
make_mla_fp8_topology,
)
from .mainloop_spec import (
MainloopSpec,
make_prefill_mainloop_spec,
MLAMainloopSpec,
make_mla_mainloop_spec,
make_mla_fp8_mainloop_spec,
)
from .fusion.mask import MaskType
from .fusion.variant import (
tanh_approx,
AttentionVariant,
StandardAttention,
AttentionWithSink,
SigmoidAttention,
SigmoidTanhAttention,
ALiBiAttention,
RPEAttention,
SoftCappingAttention,
)
from .scheduler.persistent import (
FmhaStaticTileScheduler,
FmhaStaticTileSchedulerParams,
create_fmha_static_tile_scheduler,
create_fmha_static_tile_scheduler_params,
)

# Building blocks — MLA decode
from .mla_config import MLAConfig
from .mla_warp_schedule import (
MLAWarpSchedule,
MLA_DECODE_SCHEDULE,
MLAWarpScheduleFP8,
MLA_DECODE_FP8_SCHEDULE,
)
from .scheduler.mla_persistent import (
MLAStaticTileScheduler,
MLAStaticTileSchedulerParams,
create_mla_static_tile_scheduler,
create_mla_static_tile_scheduler_params,
mla_get_split_kv,
mla_get_split_kv_simplified,
mla_get_workspace_size,
)

# Wrappers
from .wrappers.batch_prefill import (
BatchPrefillCuteDSLWrapper,
)
from .wrappers.batch_mla import (
BatchMLADecodeCuteDSLWrapper,
cute_dsl_mla_decode,
)
Loading
Loading