[CuTe DSL] Add modular FMHA prefill and MLA decode attention kernels #2805
Merged · Changes from 36 commits (41 commits total)
- `a50f52e` pgera: [cutedsl] Add modular attention package with FMHA and MLA kernels
- `a501c5e` pgera: Remove monolithic kernels from original PR, separate benchmarks
- `24b92d9` pgera: Fix causal mask boundary + PV accumulate bugs in modular attention ke…
- `4f37d5a` pgera: Fix attention sink precision bugs: wrapper dtype mismatch and M_D_upd…
- `a66b7e4` pgera: Fix sliding window mask, sink M_D_update scaling, and add comprehensi…
- `93111ab` pgera: Remove MLA/decode files from prefill PR
- `b82b7d7` pgera: Replace patch/pipeline.py with upstream cutlass.pipeline
- `e950529` pgera: Address review feedback: add @cute.jit to get_trip_count, remove unus…
- `6b38a2d` pgera: Simplify RESIDUAL_MASK check and remove redundant qo_head_idx guard (…
- `3848236` pgera: Unify variant runtime data into single params mechanism with score_mo…
- `776eac9` pgera: Update copyright year to 2026 for new CuTe DSL attention files
- `8e83012` pgera: Update copyright year to 2026 for benchmark file
- `1296216` pgera: Standardize license headers to abbreviated SPDX BSD-3-Clause
- `8892d98` pgera: Add validation and robustness improvements to CuTe DSL attention (AI-…
- `6473af0` pgera: Fix SigmoidAttention bias not converted to log-base-2 domain (AI-assi…
- `06300de` pgera: Add transform_logits support: coordinate-free API, composable with sc…
- `16d31a5` pgera: Port monolithic MLA decode kernel to modular role-based architecture …
- `315380c` pgera: Port FP8 MLA decode to modular attention framework (AI-assisted)
- `84adb73` pgera: Move modular FMHA prefill test to tests/attention/ (AI-assisted)
- `8892f81` pgera: Add FP8 tests to modular MLA decode test suite (AI-assisted)
- `d1ea471` pgera: Fix FP16 MLA decode 2x attenuation regression from FP8 port (AI-assis…
- `0fba944` pgera: Delete monolithic MLA decode kernels, wire modular in place (AI-assis…
- `0cf787e` pgera: Remove flashinfer.mla.cute_dsl shim, import modular directly (AI-assi…
- `471f5f0` pgera: Wire cute-dsl backend into BatchPrefillWithRaggedKVCacheWrapper (AI-a…
- `f9d7727` yzh119: Merge branch 'main' into cutedsl-fmha-prefill
- `05b0361` pgera: Address PR review: TVM-FFI compile path, @flashinfer_api, benchmark g…
- `eb1068d` yzh119: Merge branch 'main' into cutedsl-fmha-prefill
- `4a3ceaa` pgera: Fix pre-commit: ruff check/format, mypy barrier assertions (AI-assisted)
- `f87b0b2` pgera: Fix mypy type errors across CuTe DSL attention package (AI-assisted)
- `58537d1` pgera: Add KV cache shape validation and document sliding window limitation …
- `aa48df0` pgera: Fix sliding window mask for qo_len != kv_len (AI-assisted)
- `1f89f93` pgera: Merge remote-tracking branch 'origin/main' into cutedsl-fmha-prefill
- `7f3a907` pgera: Add @flashinfer_api to BatchMLADecodeCuteDSLWrapper (AI-assisted)
- `472c864` pgera: Align prefill and MLA wrapper patterns (AI-assisted)
- `e697268` pgera: Standardize license headers to abbreviated SPDX BSD-3-Clause (2026)
- `1d4e6c6` pgera: Add is_cute_dsl_available() guard to benchmark and FMHA test (AI-assi…
- `e74065a` yzh119: Merge branch 'main' into cutedsl-fmha-prefill
- `57e45ce` pgera: Add attention variant support to MLA decode and unify caching (AI-ass…
- `2972bb3` pgera: Fix cute-dsl prefill out= contract and address PR review comments (AI…
- `876b6c2` pgera: Merge remote-tracking branch 'origin/main' into cutedsl-fmha-prefill
- `b359788`: Guard against cute-dsl backend in paged KV cache wrapper (AI-assisted)
pgera File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
New benchmark file (@@ -0,0 +1,185 @@):

```python
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import sys

import numpy as np
import torch

import flashinfer
from flashinfer.cute_dsl.utils import is_cute_dsl_available
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import is_sm100a_supported

if not is_cute_dsl_available():
    print("Skipping: nvidia-cutlass-dsl package not installed")
    sys.exit(0)

from flashinfer.cute_dsl.attention import BatchPrefillCuteDSLWrapper


def bench_fmha_blackwell(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
):
    q = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    k = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    v = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )

    qo_segment_offsets = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    kv_segment_offsets = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, dtype=dtype, device="cuda"),
        kv_layout="NHD",
        backend="cutlass",
    )
    wrapper.plan(
        qo_segment_offsets,
        kv_segment_offsets,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    o = wrapper.run(q, k, v)
    measurements = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )
    ms = np.median(measurements)

    def flops(ms):
        if causal:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
        else:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9

    def io(ms):
        mem_size = (
            q.numel() * q.element_size()
            + k.numel() * k.element_size()
            + v.numel() * v.element_size()
            + o.numel() * o.element_size()
        )
        return mem_size / ms / 1e6

    print(
        f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s"
    )


def bench_fmha_cutedsl(
    batch_size,
    qkv_len,
    num_heads,
    head_dim,
    causal,
    dtype,
    sm_scale=None,
):
    if sm_scale is None:
        sm_scale = 1.0 / (head_dim**0.5)

    q = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    k = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )
    v = torch.randn(
        batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
    )

    qo_indptr = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
    )

    wrapper = BatchPrefillCuteDSLWrapper(
        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
    )
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_heads,
        num_heads,
        head_dim,
        head_dim_vo=head_dim,
        causal=causal,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
    o = wrapper.run(q, k, v)
    measurements = bench_gpu_time(
        lambda: wrapper.run(q, k, v),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )
    ms = np.median(measurements)

    def flops(ms):
        if causal:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
        else:
            return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9

    def io(ms):
        mem_size = (
            q.numel() * q.element_size()
            + k.numel() * k.element_size()
            + v.numel() * v.element_size()
            + o.numel() * o.element_size()
        )
        return mem_size / ms / 1e6

    print(
        f"bench_fmha_cutedsl (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s, io: {io(ms):.3f} GB/s"
    )


if __name__ == "__main__":
    if not is_sm100a_supported(torch.device("cuda")):
        print("Skipping: requires SM100+")
        sys.exit(0)

    configs = [
        (128, 512, 32, 128, True, torch.bfloat16),
        (64, 1024, 32, 128, True, torch.bfloat16),
        (32, 2048, 32, 128, True, torch.bfloat16),
        (16, 4096, 32, 128, True, torch.bfloat16),
        (8, 8192, 32, 128, True, torch.bfloat16),
        (4, 16384, 32, 128, True, torch.bfloat16),
        (2, 32768, 32, 128, True, torch.bfloat16),
        (1, 65536, 32, 128, True, torch.bfloat16),
    ]

    print("=== CUTLASS (via BatchPrefillWithRaggedKVCacheWrapper) ===")
    for cfg in configs:
        bench_fmha_blackwell(*cfg)
    print()
    print("=== CuTe DSL (via BatchPrefillCuteDSLWrapper) ===")
    for cfg in configs:
        bench_fmha_cutedsl(*cfg)
```
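The `flops` and `io` helpers in the benchmark above are worth sanity-checking in isolation: non-causal FMHA performs two GEMMs per head (QK^T and P@V), each costing 2·L²·D FLOPs, for 4·B·L²·H·D total, and the causal case is credited with half of that since the mask discards roughly half of the score matrix. A standalone sketch of the same accounting (the function names here are illustrative, not part of the PR):

```python
def attention_tflops_per_s(batch_size, qkv_len, num_heads, head_dim, causal, ms):
    """Standalone restatement of the benchmark's flops() helper."""
    # Two GEMMs per head (QK^T and P@V), 2 FLOPs per multiply-add:
    # 4 * B * L^2 * H * D total. A causal mask touches roughly half
    # of the L x L score matrix, so the causal count is halved.
    factor = 2 if causal else 4
    flop = batch_size * qkv_len * qkv_len * num_heads * head_dim * factor
    # flop / (ms * 1e-3) FLOPs/s equals flop / ms / 1e9 TFLOPs/s
    return flop / ms / 1e9


def attention_gb_per_s(q_bytes, k_bytes, v_bytes, o_bytes, ms):
    """Standalone restatement of the benchmark's io() helper."""
    # Total bytes moved for Q, K, V, and O; bytes / ms / 1e6 is GB/s.
    return (q_bytes + k_bytes + v_bytes + o_bytes) / ms / 1e6


if __name__ == "__main__":
    # e.g. the (1, 65536, 32, 128, True, bfloat16) config at a
    # hypothetical median of 100 ms per run
    print(f"{attention_tflops_per_s(1, 65536, 32, 128, True, 100.0):.3f} TFLOPs/s")
```

Note that both helpers are pure arithmetic on the config and the median time, so the causal number is by construction exactly half the non-causal one for the same shape.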
New file (@@ -0,0 +1,79 @@), the modular attention package's public exports:

```python
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""Modular attention kernels for CuTe DSL.

Kernels live at the top level of this package.
Building blocks (config, tmem_layout, roles, fusion, scheduler, wrappers) are
one level below in subdirectories.
"""

# Kernels
from .prefill import BlackwellFusedMultiHeadAttentionForward
from .mla_decode import BlackwellMultiLatentAttentionForward
from .mla_decode_fp8 import BlackwellMultiLatentAttentionForwardFP8

# Building blocks — FMHA prefill
from .config import AttentionConfig, AttentionFusion, HeadMapping, TileBounds
from .tmem_layout import TmemLayout
from .warp_schedule import WarpSchedule, PREFILL_SCHEDULE
from .pipeline_topology import (
    PipelineEdge,
    PipelineType,
    PipelineTopology,
    make_prefill_topology,
    make_mla_topology,
    make_mla_fp8_topology,
)
from .mainloop_spec import (
    MainloopSpec,
    make_prefill_mainloop_spec,
    MLAMainloopSpec,
    make_mla_mainloop_spec,
    make_mla_fp8_mainloop_spec,
)
from .fusion.mask import MaskType
from .fusion.variant import (
    tanh_approx,
    AttentionVariant,
    StandardAttention,
    AttentionWithSink,
    SigmoidAttention,
    SigmoidTanhAttention,
    ALiBiAttention,
    RPEAttention,
    SoftCappingAttention,
)
from .scheduler.persistent import (
    FmhaStaticTileScheduler,
    FmhaStaticTileSchedulerParams,
    create_fmha_static_tile_scheduler,
    create_fmha_static_tile_scheduler_params,
)

# Building blocks — MLA decode
from .mla_config import MLAConfig
from .mla_warp_schedule import (
    MLAWarpSchedule,
    MLA_DECODE_SCHEDULE,
    MLAWarpScheduleFP8,
    MLA_DECODE_FP8_SCHEDULE,
)
from .scheduler.mla_persistent import (
    MLAStaticTileScheduler,
    MLAStaticTileSchedulerParams,
    create_mla_static_tile_scheduler,
    create_mla_static_tile_scheduler_params,
    mla_get_split_kv,
    mla_get_split_kv_simplified,
    mla_get_workspace_size,
)

# Wrappers
from .wrappers.batch_prefill import (
    BatchPrefillCuteDSLWrapper,
)
from .wrappers.batch_mla import (
    BatchMLADecodeCuteDSLWrapper,
    cute_dsl_mla_decode,
)
```