feat: Integrate CuTe DSL FMHA prefill kernels by loading cubin #3039
limin2021 wants to merge 28 commits into flashinfer-ai:main from
Conversation
Add cute-dsl backend support for single_prefill_with_kv_cache and BatchPrefillWithRaggedKVCacheWrapper, loading pre-compiled DSL FMHA kernels via ExternalBinaryModule. Pass through FP8 scale parameters (scale_q, scale_k, scale_v, scale_o) to the DSL kernel instead of hardcoding them as 1.0.
- flashinfer/attention_dsl/cute_dsl/: new module with kernel loader (fmha.py) supporting local .so and artifactory paths, plus PyTorch API wrappers for both fixed-length and variable-length (ragged) prefill
- flashinfer/prefill.py: add "cute-dsl" backend branches in single_prefill_with_kv_cache and BatchPrefillWithRaggedKVCacheWrapper
- tests/attention/test_cute_dsl_fmha_prefill.py: 81 tests covering direct API, FlashInfer integration, ragged/varlen, GQA, and cross-backend comparison
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add mark_layout_dynamic() for cute tensor conversion (FP8 via int8 view)
- Add FP8 direct prefill and ragged prefill tests (e4m3 in → fp16 out)
- Fix varlen ragged crash by padding tensors with max_seqlen (TMA overflow)
- Remove unused _to_cute_tensor helper
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add enable_tvm_ffi parameter to cute_dsl_fmha_prefill and cute_dsl_fmha_ragged_prefill (default True)
- TVM-FFI path: pass data_ptr() for Pointer args, torch.Tensor for Tensor args (cum_seqlen), no explicit stream (env stream)
- Add _tvmffi suffix to variant names to avoid overwriting native ABI .so
- Move imports to file top, use cuda_driver.CUstream for current stream
- Add enable_tvm_ffi parametrize to direct/FP8/ragged FP8 tests
- Add production-scale test shapes (8Kx8K, 8Kx32K, 4x1Kx80K)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Merge FP8 into main test functions via dtype parametrize (fp16/bf16/fp8)
- Extract helpers: _make_qkvo, _make_ragged_qkvo, _ragged_reference, _get_tolerances
- Remove duplicate test functions (9 → 6), reduce repeated code
- Add ragged shapes: GQA (H_q=16, H_k=4), long context, asymmetric long KV
- Asymmetric small seq_len ragged commented out (kernel TMA limitation)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The DSL FMHA kernel applies a negative offset to pointers in varlen mode (q_offset = -max_s_q * H * D), so valid GPU memory must exist before the data start. Changed from back-padding to front-padding for all dtypes (fp16/bf16/fp8), matching the DSL example's create_and_pad_tensor. This fixes:
- Non-FP8 ragged tests that crashed when run individually
- Asymmetric ragged (S_q != S_k) with small seq_lens
- Re-enabled asymmetric test case ([32,64,16] vs [128,256,64])
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
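The negative-offset arithmetic above can be sketched in plain Python. This is an illustrative model, not FlashInfer code: element counts stand in for GPU memory, and the helper names are hypothetical.

```python
# Why front-padding is required when a varlen kernel applies a negative
# base-pointer offset (q_offset = -max_s_q * H * D in the DSL kernel).

def front_pad_offset(max_s_q: int, num_heads: int, head_dim: int) -> int:
    """Elements of valid memory required *before* the data start."""
    return max_s_q * num_heads * head_dim

def allocate_front_padded(total_tokens: int, max_s_q: int, H: int, D: int):
    """Return (buffer_len, data_start) such that buffer[data_start:] holds the
    ragged tensor and buffer[data_start + q_offset] is still in bounds."""
    pad = front_pad_offset(max_s_q, H, D)
    buffer_len = pad + total_tokens * H * D
    return buffer_len, pad

buf_len, start = allocate_front_padded(total_tokens=96, max_s_q=64, H=8, D=128)
q_offset = -front_pad_offset(64, 8, 128)
# The most negative access the kernel can make stays inside the allocation:
assert start + q_offset >= 0
```

Back-padding fails this check (`start` would be 0, so `start + q_offset` is negative), which matches the crashes described above.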
- Register cute-dsl backend for BatchPrefillWithRaggedKVCacheWrapper in benchmark CC support table (SM10.0, SM10.3)
- Add cute-dsl to wrapper creation and run paths in attention benchmark
- Disable CUDA graph for cute-dsl (TVM-FFI env stream incompatible)
- Move max seq len computation from run() to plan() in prefill wrapper to avoid D2H copy during CUDA graph capture
- Add max_qo_len/max_kv_len params to cute_dsl_fmha_ragged_prefill
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
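The plan-time max-seq-len computation can be sketched as follows; a minimal stdlib sketch under the assumption that a host-side (CPU) copy of the indptr arrays is available at plan() time, so run() never triggers a device-to-host copy during CUDA graph capture. The helper name is hypothetical.

```python
# Derive max query/kv lengths from host-side CSR-style indptr arrays.
# indptr[i+1] - indptr[i] is the length of sequence i.

def max_seqlen_from_indptr(indptr: list[int]) -> int:
    return max(b - a for a, b in zip(indptr, indptr[1:]))

qo_indptr = [0, 32, 96, 112]    # 3 sequences with lengths 32, 64, 16
kv_indptr = [0, 128, 384, 448]  # lengths 128, 256, 64
max_qo_len = max_seqlen_from_indptr(qo_indptr)
max_kv_len = max_seqlen_from_indptr(kv_indptr)
assert (max_qo_len, max_kv_len) == (64, 256)
```

Doing this once in plan() on host data is graph-capture safe, whereas calling `.max()` on a device indptr tensor in run() would force a synchronizing D2H transfer.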
📝 Walkthrough
Adds a CuTe-DSL attention backend: new CuTe FMHA kernel loader and TVM/native wrappers, a ragged prefill runtime for CuTe, integration into prefill and benchmark flows, benchmark/backend mappings updated, and GPU tests exercising the new backend and ragged prefill paths.

Sequence Diagram
sequenceDiagram
participant App as PyTorch App
participant Prefill as Prefill API
participant Loader as DSL Kernel Loader
participant CuTe as CuTe Runtime
participant CUDA as CUDA Device
App->>Prefill: trtllm_ragged_attention_deepseek(..., backend="cute-dsl")
Prefill->>Prefill: validate params, compute scales, prepare (pad/slice) tensors
Prefill->>Loader: get_cute_dsl_fmha_kernel(dtypes, head_dim, causal, with_lse, varlen)
Loader->>Loader: check local dir or artifact, verify checksum, load module
Loader-->>Prefill: kernel callable
Prefill->>CuTe: call kernel (TVM-FFI or native iterators) with ragged indptrs
CuTe->>CUDA: launch kernel
CUDA-->>CuTe: complete
CuTe-->>Prefill: outputs written (slice off padding)
Prefill-->>App: return outputs
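The loader's "check local dir or artifact, verify checksum, load module" step in the diagram can be sketched with the standard library. This is a minimal sketch under assumptions: the file name and layout are hypothetical, and the real loader goes through FlashInfer's get_artifact/ExternalBinaryModule machinery rather than this helper.

```python
# Verify a cached kernel .so against a pinned SHA-256 before loading it.
import hashlib
import tempfile
from pathlib import Path

def verify_sha256(path: Path, expected: str) -> bool:
    """Hash the cached artifact and compare against the pinned checksum."""
    digest = hashlib.sha256(path.read_bytes()).hexdigest()
    return digest == expected

# Demonstration with a stand-in file:
with tempfile.TemporaryDirectory() as d:
    so = Path(d) / "fmha_variant.so"  # hypothetical variant name
    so.write_bytes(b"fake kernel bytes")
    pinned = hashlib.sha256(b"fake kernel bytes").hexdigest()
    assert verify_sha256(so, pinned)
    assert not verify_sha256(so, "0" * 64)
```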
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Code Review
This pull request introduces the "cute-dsl" backend for FMHA prefill kernels, specifically targeting SM10x (Blackwell) architectures. The changes include a new module for loading pre-compiled binary artifacts via ExternalBinaryModule, integration into the existing single and batch prefill APIs, and a comprehensive test suite. Feedback focuses on improving path construction safety in the artifact loader, avoiding performance-degrading GPU-CPU synchronizations during the planning phase, and addressing a known failing test case marked with a TODO.
Actionable comments posted: 7
🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_fmha_prefill.py (1)
309-315: Pin the reference backend instead of relying on auto.
These tests are supposed to compare cute-dsl against a different implementation, but auto is a moving target. If backend selection ever resolves auto to cute-dsl for these shapes, the assertions become self-comparisons and stop validating cross-backend correctness.
Suggested fix
-    backend="auto",
+    backend="fa3",
If you want a separate dispatch test for auto, keep that as its own test and assert the resolved backend explicitly.
Also applies to: 550-564
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_cute_dsl_fmha_prefill.py` around lines 309 - 315, The test currently calls flashinfer.single_prefill_with_kv_cache(..., backend="auto") which risks resolving to the same cute-dsl implementation and turning the comparison into a self-check; change the backend argument to a fixed, explicit reference backend (e.g., "reference" or the specific backend name used elsewhere) in the calls to single_prefill_with_kv_cache and any other occurrences (including the similar block around lines 550-564) so the test compares cute-dsl against a stable, non-moving implementation; if you want to keep an "auto" dispatch test, add a separate test that asserts the resolved backend explicitly before comparing outputs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1846-1847: The "cute-dsl" branch drops FP8 scale metadata: instead
of calling backend_wrappers[backend].run(q, k, v) it must forward q_scale,
k_scale, and v_scale so the wrapper sees FP8 scales; update the call in the
backend == "cute-dsl" branch to pass the scale variables (e.g., run(q, k, v,
q_scale, k_scale, v_scale) or as named args) and ensure
backend_wrappers["cute-dsl"].run signature accepts and uses those scale
parameters to preserve correct FP8 math for the `q`, `k`, and `v` tensors.
- Around line 1754-1761: The timer path still enables CUDA graph capture for
some backends; update the bench_gpu_time(...) call to mirror the wrapper logic
by passing use_cuda_graph only when backend not in {"fa2", "cute-dsl"}. Locate
the call site that currently passes use_cuda_graph=True for cute-dsl (near the
bench_gpu_time invocation) and change the argument to
use_cuda_graph=(is_cuda_graph_compatible if backend not in {"fa2", "cute-dsl"}
else False), keeping the same variable names and preserving behavior for other
backends and is_cuda_graph_compatible; this ensures
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper and bench_gpu_time use
the same disable-check.
In `@flashinfer/attention_dsl/cute_dsl/__init__.py`:
- Around line 21-23: The import of is_cute_dsl_available from
flashinfer.cute_dsl.utils causes eager execution of that module (and its bare
top-level import cutlass); move the is_cute_dsl_available implementation into
flashinfer.attention_dsl.cute_dsl.__init__.py and use a local function that
wraps import cutlass in a try/except returning a boolean, then replace the
top-level from flashinfer.cute_dsl.utils import is_cute_dsl_available with the
local definition and ensure any other code in this package uses this local
is_cute_dsl_available to gate CUTLASS-dependent imports or logic (referencing
the is_cute_dsl_available symbol and the module-level import gate in
__init__.py).
In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 113-125: The artifact path construction can produce a leading
slash when DSL_FMHA_ARTIFACT_PATH is empty, causing FLASHINFER_CUBIN_DIR /
artifact_path to escape the cache; fix by building the artifact path with
path-safe operations: ensure you join the directory and so_filename using Path
objects or strip any leading slashes from DSL_FMHA_ARTIFACT_PATH before
combining so_filename, then assign artifact_path and compute local_path =
FLASHINFER_CUBIN_DIR / artifact_path (or directly local_path =
FLASHINFER_CUBIN_DIR / Path(DSL_FMHA_ARTIFACT_PATH) / so_filename) so
get_artifact(artifact_path, sha256) and subsequent filesystem operations never
escape FLASHINFER_CUBIN_DIR; reference symbols: DSL_FMHA_ARTIFACT_PATH,
variant_name, so_filename, artifact_path, FLASHINFER_CUBIN_DIR, local_path,
get_artifact.
In `@flashinfer/prefill.py`:
- Around line 3071-3077: The cute-dsl compute-capability check in plan()
incorrectly queries get_compute_capability(qo_indptr.device) which fails when
qo_indptr is on CPU; change it to query the actual CUDA device used for
execution (e.g., use self.device or the wrapper CUDA device object) so
get_compute_capability(...) is called on the real GPU device rather than
qo_indptr.device; update the block around get_compute_capability and the
RuntimeError message to use that device’s compute capability (retain the same
error formatting and function name get_compute_capability to locate the code).
- Around line 1324-1338: The code currently accepts 1-element torch.Tensor
scales and then calls _split_scale_param(scale_q/scale_k/scale_v), which treats
any tensor input as a tensor and returns (tensor, 1.0), effectively dropping the
scalar value; fix by detecting 1-element tensors (isinstance(..., torch.Tensor)
and s.numel() == 1) and replace them with their Python scalar (e.g., s.item() or
float(s.item())) before calling _split_scale_param so the actual scalar scale is
preserved for the cute-dsl path; update the block handling
scale_q/scale_k/scale_v (and keep the existing multi-element validation) to
perform this conversion.
In `@tests/attention/test_cute_dsl_fmha_prefill.py`:
- Around line 452-454: The test matrix includes a known-failing asymmetric case
tuple ([32, 64, 16], [128, 256, 64], 8, 8) which is causing the suite to fail;
either remove that tuple from the default parametrization or mark it as an
expected failure using pytest.param(...,
marks=pytest.mark.xfail(reason="asymmetric S_q < S_k known issue", strict=False,
reason_or_issue="<link/issue-id>")) so the suite stays green while tracking the
bug; update the parametrization where the tuples are defined to apply this
change around the ([32, 64, 16], [128, 256, 64], 8, 8) entry.
📒 Files selected for processing (7)
benchmarks/routines/attention.py
benchmarks/routines/flashinfer_benchmark_utils.py
flashinfer/attention_dsl/__init__.py
flashinfer/attention_dsl/cute_dsl/__init__.py
flashinfer/attention_dsl/cute_dsl/fmha.py
flashinfer/prefill.py
tests/attention/test_cute_dsl_fmha_prefill.py
cc @leejnau for cute dsl prefill MLA kernels

If this is for DSR1 MLA prefill, please connect it to the
…mpat
- Preload cute-dsl kernel .so in plan() for fail-fast and reuse in run()
- Default FP8 input to bf16 output dtype in plan() (not just run())
- Use host-side indptr for max seq len to avoid D2H during graph capture
- Use self.device instead of qo_indptr.device for compute capability check
- Add front-padding for cute-dsl varlen kernel in benchmark
- Enable CUDA graph for cute-dsl backend in benchmark
- Re-enable asymmetric (S_q < S_k) test case
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Actionable comments posted: 3
♻️ Duplicate comments (4)
benchmarks/routines/attention.py (2)
1868-1869: ⚠️ Potential issue | 🟠 Major
Forward FP8 scales through the cute-dsl benchmark call.
q, k, and v are quantized above, but this branch drops q_scale, k_scale, and v_scale. For FP8 cases, the cute-dsl benchmark and refcheck path will run with the wrong dequantization math.
Suggested fix
 elif backend == "cute-dsl":
-    return backend_wrappers[backend].run(q, k, v)
+    return backend_wrappers[backend].run(
+        q,
+        k,
+        v,
+        q_scale=q_scale,
+        k_scale=k_scale,
+        v_scale=v_scale,
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1868 - 1869, The cute-dsl branch in benchmarks/routines/attention.py drops FP8 dequantization scales; update the backend == "cute-dsl" branch to pass q_scale, k_scale, and v_scale through to backend_wrappers[backend].run so the cute-dsl benchmark and its refcheck receive the FP8 scale values (mirror how other backends are called), ensuring run(...) accepts and uses these additional arguments for correct dequantization math.
1776-1783: ⚠️ Potential issue | 🟠 Major
Keep CUDA graphs disabled for cute-dsl in both the wrapper and timer path.
The PR explicitly treats cute-dsl as non-graph-compatible, but these checks still enable graph capture for it. That leaves the benchmark exercising the TVM-FFI path you were trying to exclude.
Suggested fix
 flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
     workspace_buffer,
     "NHD",
     use_cuda_graph=is_cuda_graph_compatible
-    if backend not in ["fa2"]
+    if backend not in ["fa2", "cute-dsl"]
     else False,
     qo_indptr_buf=qo_indptr,
     kv_indptr_buf=kv_indptr,
     backend=backend,
 )
-    use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2"]),
+    use_cuda_graph=(
+        is_cuda_graph_compatible
+        and cur_backend not in ["fa2", "cute-dsl"]
+    ),
Also applies to: 1954-1960
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1776 - 1783, The CUDA-graph flag is still being enabled for "cute-dsl"; update the graph-compatibility checks so CUDA graphs are disabled for both "fa2" and "cute-dsl". Concretely, change usages of is_cuda_graph_compatible in the BatchPrefillWithRaggedKVCacheWrapper call (backend_wrappers[backend]) to use is_cuda_graph_compatible if backend not in ["fa2", "cute-dsl"] else False, and make the identical change in the timer path where the timer/wrapper is constructed (the counterpart block around lines ~1954-1960) so both the wrapper and timer consistently treat "cute-dsl" as non-graph-compatible. Ensure you reference and update the same is_cuda_graph_compatible conditional in both places.
flashinfer/attention_dsl/cute_dsl/fmha.py (1)
113-125: ⚠️ Potential issue | 🟠 Major
Prevent artifact_path from escaping the cubin cache root.
When FLASHINFER_DSL_FMHA_ARTIFACT_PATH is unset, Line 114 becomes "/<variant>.so". FLASHINFER_CUBIN_DIR / artifact_path then resolves outside the cache directory, so this loader can read/write the artifact from the wrong location.
Suggested fix
 so_filename = f"{variant_name}.so"
-artifact_path = f"{DSL_FMHA_ARTIFACT_PATH}/{so_filename}"
+artifact_path = (
+    f"{DSL_FMHA_ARTIFACT_PATH.rstrip('/')}/{so_filename}"
+    if DSL_FMHA_ARTIFACT_PATH
+    else so_filename
+)
 sha256 = DSL_FMHA_CHECKSUMS.get(variant_name, "")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 113 - 125, The code builds artifact_path from DSL_FMHA_ARTIFACT_PATH and variant_name which can produce a leading slash and allow FLASHINFER_CUBIN_DIR / artifact_path to escape the cache; update the logic around artifact_path/local_path in fmha loader (referencing variant_name, artifact_path, FLASHINFER_CUBIN_DIR, FLASHINFER_DSL_FMHA_ARTIFACT_PATH, get_artifact, local_path) to ensure artifact_path is normalized to a relative path (strip any leading slashes, resolve .. segments) and then construct local_path and assert that local_path.resolve().is_relative_to(FLASHINFER_CUBIN_DIR.resolve()) (or compare prefixes) before calling get_artifact; if the check fails, raise a clear RuntimeError.flashinfer/prefill.py (1)
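The leading-slash hazard this comment flags comes straight from pathlib's join semantics: when the right-hand operand is absolute, it replaces the left operand entirely. A stdlib demonstration (paths here are illustrative, not FlashInfer's actual cache layout):

```python
# pathlib discards the left operand when the right-hand path is absolute, so
# an empty env-var prefix that yields "/<variant>.so" escapes the cache root.
from pathlib import PurePosixPath

cache_root = PurePosixPath("/home/user/.cache/flashinfer/cubins")

# Empty prefix -> leading slash -> absolute path wins the join:
bad = cache_root / "/fmha_variant.so"
assert str(bad) == "/fmha_variant.so"  # escaped the cache root

def safe_join(root: PurePosixPath, artifact_prefix: str, filename: str) -> PurePosixPath:
    """Normalize the prefix to a relative path so the join stays rooted."""
    prefix = artifact_prefix.strip("/")
    rel = f"{prefix}/{filename}" if prefix else filename
    return root / rel

assert safe_join(cache_root, "", "fmha_variant.so") == cache_root / "fmha_variant.so"
assert str(safe_join(cache_root, "dsl/v1", "fmha_variant.so")).startswith(str(cache_root))
```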
1324-1338: ⚠️ Potential issue | 🟠 Major
Normalize FP8 scale inputs before the cute-dsl dispatch.
This branch still breaks the FP8 user path in two ways: omitted scales were already expanded to per-head tensors above and now fail this validation, while a 1-element tensor that passes validation is then turned into 1.0 by _split_scale_param(). That means either an unexpected ValueError or silently wrong scaling.
Suggested fix
+def _scalarize_scale(scale, name: str) -> float:
+    if scale is None:
+        return 1.0
+    if isinstance(scale, torch.Tensor):
+        if scale.numel() == 1:
+            return float(scale.item())
+        if torch.all(scale == scale.reshape(-1)[0]).item():
+            return float(scale.reshape(-1)[0].item())
+        raise ValueError(
+            f"cute-dsl backend does not support per-head scale tensors ({name}), "
+            "only per-tensor scalar scales are supported"
+        )
+    return float(scale)
+
-if is_float8(q):
-    for s, name in (
-        (scale_q, "scale_q"),
-        (scale_k, "scale_k"),
-        (scale_v, "scale_v"),
-    ):
-        if isinstance(s, torch.Tensor) and s.numel() > 1:
-            raise ValueError(
-                f"cute-dsl backend does not support per-head scale tensors ({name}), "
-                "only per-tensor scalar scales are supported"
-            )
-    # Extract scalar scale values for DSL kernel
-    _, sq = _split_scale_param(scale_q)
-    _, sk = _split_scale_param(scale_k)
-    _, sv = _split_scale_param(scale_v)
+sq = _scalarize_scale(scale_q, "scale_q")
+sk = _scalarize_scale(scale_k, "scale_k")
+sv = _scalarize_scale(scale_v, "scale_v")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1324 - 1338, Normalize FP8 scale inputs before the cute-dsl dispatch: call _split_scale_param(scale_q/scale_k/scale_v) first to extract the original param and the scalar component (e.g., _, sq = _split_scale_param(scale_q)) and then perform the per-head tensor validation using the original param (the first return) rather than the possibly collapsed value; this ensures omitted scales that were expanded earlier are normalized to scalars and 1-element tensors are treated as scalars for the DSL kernel while still raising ValueError for true per-head tensors.
🧹 Nitpick comments (1)
flashinfer/attention_dsl/cute_dsl/fmha.py (1)
466-468: Drop the unused total_kv binding.
total_kv is never read, so it now fails Ruff and adds noise in the wrapper.
Suggested fix
-total_q, H_q, D = q.shape
-total_kv, H_k, _ = k.shape
+total_q, H_q, D = q.shape
+_, H_k, _ = k.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 466 - 468, The variable total_kv is unused and triggers linter noise; update the k.shape unpacking in fmha.py to ignore that element (e.g., replace total_kv with _ in the tuple assignment) so only the needed H_k (and any used dimensions) are bound, removing the unused binding from the wrapper logic.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3072-3093: When selecting the "cute-dsl" backend in the prefill
plan, reject wrapper layouts using HND to avoid silently feeding misordered
tensors into cute_dsl_fmha_ragged_prefill; add a check alongside the other
backend validations (the block that checks get_compute_capability,
pos_encoding_mode, packed_custom_mask, logits_soft_cap, use_fp16_qk_reduction)
to raise a ValueError if self._kv_layout (or the local kv_layout variable) ==
"HND", referencing self._kv_layout and the cute_dsl_fmha_ragged_prefill behavior
that ignores the wrapper layout.
- Around line 3334-3355: The DSL ragged kernel expects buffers to be
front-padded because it uses negative pointer offsets, but the current call to
cute_dsl_fmha_ragged_prefill forwards user tensors that start at storage offset
0; to fix, allocate padded versions of q, k, v, and out with extra prefix space
(pad length = max_len or total ragged padding as the tests do), copy the
original tensors into the tail slice of those padded tensors, adjust any
indptr/pointer buffers if needed, and pass these padded/tail-sliced tensors to
cute_dsl_fmha_ragged_prefill instead of the original q/k/v/out; reference the
call site around cute_dsl_fmha_ragged_prefill and the buffers
_qo_indptr_buf/_kv_indptr_buf to locate where to insert the padding/wrapping
logic.
In `@tests/attention/test_cute_dsl_fmha_prefill.py`:
- Around line 20-37: The skip condition currently blocks all SM10x devices by
calling is_sm100a_supported; replace that check with a function that accepts the
whole SM10x family (e.g., is_sm10x_supported) so SM10.3 is allowed: in the
pytestmark list change the second skipif predicate from "not
torch.cuda.is_available() or not is_sm100a_supported(torch.device('cuda'))" to
"not torch.cuda.is_available() or not is_sm10x_supported(torch.device('cuda'))"
(or, if such helper doesn't exist, implement a small helper in flashinfer.utils
that checks CUDA compute capability major==10 and call it here). Ensure you
reference the existing symbol is_sm100a_supported in the change so reviewers can
locate and replace it with is_sm10x_supported (or the new helper).
📒 Files selected for processing (4)
benchmarks/routines/attention.py
flashinfer/attention_dsl/cute_dsl/fmha.py
flashinfer/prefill.py
tests/attention/test_cute_dsl_fmha_prefill.py
…E and benchmark support
- Add `backend` param to trtllm_ragged_attention_deepseek to dispatch cute-dsl
- Add `with_lse` to DSL kernel variant selection for correct .so loading
- Front-pad output tensor for DSL varlen negative pointer offsets
- Add cute-dsl backend parametrize to test_trtllm_gen_prefill tests
- Route benchmark cute-dsl through trtllm_ragged_attention_deepseek directly
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Actionable comments posted: 2
♻️ Duplicate comments (5)
flashinfer/attention_dsl/cute_dsl/fmha.py (2)
523-550: ⚠️ Potential issue | 🔴 Critical
Front-pad ragged buffers inside this helper before launching the varlen kernel.
The DSL ragged kernel uses negative pointer offsets, but this wrapper still forwards the user views directly. The current tests and benchmarks only stay safe because they pass slices from front-padded allocations; a normal contiguous q/k/v/o tensor here can underflow the allocation on the first access.
Also applies to: 554-631
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 523 - 550, The wrapper is passing user views q/k/v/o directly to kernel_fn which expects ragged buffers with front-padding because the DSL kernel uses negative pointer offsets; fix by creating front-padded temporary tensors for q, k, v, o (e.g., pad_front_q = torch.empty(front_pad + q.size(0), ...); copy q into the padded region) before calling kernel_fn, and pass their .data_ptr() and the padded tensors (like q_4d) to the kernel invocation; update both the enable_tvm_ffi branch around kernel_fn (the block using q_4d/qo_indptr/etc.) and the similar branch later (lines 554-631) to ensure all varlen launches use front-padded buffers and preserve device/dtype/contiguity and stream detection.
115-127: ⚠️ Potential issue | 🟠 Major
Keep artifact_path relative to FLASHINFER_CUBIN_DIR.
When FLASHINFER_DSL_FMHA_ARTIFACT_PATH is unset, this builds "/<variant>.so". FLASHINFER_CUBIN_DIR / artifact_path then ignores the cache root and reads/writes outside the intended cubin directory.
Suggested fix
-artifact_path = f"{DSL_FMHA_ARTIFACT_PATH}/{so_filename}"
+artifact_path = (
+    f"{DSL_FMHA_ARTIFACT_PATH.rstrip('/')}/{so_filename}"
+    if DSL_FMHA_ARTIFACT_PATH
+    else so_filename
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 115 - 127, artifact_path may start with a leading slash when DSL_FMHA_ARTIFACT_PATH is empty, causing FLASHINFER_CUBIN_DIR / artifact_path to ignore the cache root; fix by ensuring artifact_path is always relative: build it by joining DSL_FMHA_ARTIFACT_PATH and so_filename while stripping any leading/trailing slashes or falling back to just so_filename (e.g., compute base = DSL_FMHA_ARTIFACT_PATH.strip("/") and set artifact_path = f"{base}/{so_filename}" if base else so_filename), then use local_path = FLASHINFER_CUBIN_DIR / artifact_path and pass artifact_path to get_artifact.
flashinfer/prefill.py (2)
1324-1338: ⚠️ Potential issue | 🟠 Major
Cute-dsl FP8 scale handling still breaks the scalar/default case.
By the time this branch runs, omitted FP8 scales were already materialized as per-head tensors, so backend="cute-dsl" rejects the default 1.0 case. A 1-element tensor also passes this guard, but _split_scale_param() converts it back to 1.0, dropping the caller's actual scalar.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 1324 - 1338, The cute-dsl per-head scale guard currently rejects 1-element tensors because it runs before converting scales to scalar values; change the logic to extract scalar scale values first (call _split_scale_param for scale_q, scale_k, scale_v) and then enforce the per-head restriction only when a scale is a tensor with numel() > 1, or alternatively detect and treat 1-element tensors as scalars before the is_float8/cute-dsl check so that 1-element/materialized default scales are accepted instead of raising in the block that references is_float8, scale_q/scale_k/scale_v and _split_scale_param.
3072-3093: ⚠️ Potential issue | 🟠 Major
Reject kv_layout="HND" for ragged cute-dsl plans.
run() forwards raw k/v tensors into cute_dsl_fmha_ragged_prefill() and never reorders them for the wrapper layout. A cute-dsl wrapper planned with HND will therefore feed misordered tensors to the kernel.
Suggested fix
 if self._backend == "cute-dsl":
     from .utils import get_compute_capability

     cc = get_compute_capability(self.device)
     if cc[0] != 10:
         raise RuntimeError(
             f"cute-dsl backend (FMHA prefill kernel) requires SM10x (Blackwell), got SM{cc[0]}{cc[1]}"
         )
+    if self._kv_layout != "NHD":
+        raise ValueError("cute-dsl backend only supports NHD layout")
     if pos_encoding_mode != "NONE":
         raise ValueError(
             f"cute-dsl backend does not support pos_encoding_mode={pos_encoding_mode}"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/prefill.py` around lines 3072 - 3093, The cute-dsl branch must reject wrapper plans that use kv_layout="HND" for ragged prefill because run() forwards raw k/v into cute_dsl_fmha_ragged_prefill() without reordering; in the cute-dsl handling block (the code that checks self._backend == "cute-dsl") add a guard that checks the kv_layout variable (and the plan ragged flag, e.g., plan.ragged or is_ragged) and raise a ValueError if kv_layout == "HND" and the plan is ragged, with a clear message like "cute-dsl ragged plans do not support kv_layout='HND'".

benchmarks/routines/attention.py (1)
1981-2006: ⚠️ Potential issue | 🟠 Major

Disable CUDA graph capture in both `cute-dsl` timer paths.

These benchmark calls still pass `use_cuda_graph=True` for `cute-dsl`, so non-CUPTI runs can graph-capture the TVM-FFI path you were trying to exclude.

Suggested fix

```diff
-    use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2"]),
+    use_cuda_graph=(
+        is_cuda_graph_compatible
+        and cur_backend not in {"fa2", "cute-dsl"}
+    ),
```

Apply the same change to the MLA `bench_gpu_time(...)` call as well.

Also applies to: 2443-2463
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 1981 - 2006, The bench_gpu_time calls for the TVM/FFI ("cute-dsl") backends still allow CUDA graph capture; update the use_cuda_graph argument in the bench_gpu_time invocation(s) (e.g., the call with fn=run_backend_wrapper using cur_backend, and the MLA bench_gpu_time call) so that use_cuda_graph is False when cur_backend == "cute-dsl" (e.g., use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2","cute-dsl"])). Ensure both occurrences (the shown run_backend_wrapper call and the MLA call around lines 2443-2463) are changed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 303-306: The code always loads the non-LSE cubin because
get_cute_dsl_fmha_kernel is invoked without selecting the LSE symbol; update the
call site in the kernel acquisition (kernel_fn = get_cute_dsl_fmha_kernel(...))
to request the fixed-length LSE variant when LSE/return_lse is requested (e.g.,
pass the lse/return_lse flag or construct/choose the symbol name with the "_lse"
suffix) so that the returned kernel_fn points to the correct LSE-enabled symbol
rather than the default non-LSE cubin.
In `@flashinfer/prefill.py`:
- Around line 3094-3105: The preload in plan() always sets self._cached_module
using get_cute_dsl_fmha_kernel without with_lse, which forces
run(return_lse=True) to reuse the non-LSE cubin; change the caching so you store
kernels keyed by the with_lse flag (e.g., a dict on self like
self._cached_modules[(q_data_type,o_data_type,head_dim_qk,causal,with_lse)]) and
call get_cute_dsl_fmha_kernel with with_lse=True when run requests
return_lse=True, ensuring run() looks up the correct cached variant instead of
always using self._cached_module.
---
Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1981-2006: The bench_gpu_time calls for the TVM/FFI ("cute-dsl")
backends still allow CUDA graph capture; update the use_cuda_graph argument in
the bench_gpu_time invocation(s) (e.g., the call with fn=run_backend_wrapper
using cur_backend, and the MLA bench_gpu_time call) so that use_cuda_graph is
False when cur_backend == "cute-dsl" (e.g.,
use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in
["fa2","cute-dsl"])). Ensure both occurrences (the shown run_backend_wrapper
call and the MLA call around lines 2443-2463) are changed.
In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 523-550: The wrapper is passing user views q/k/v/o directly to
kernel_fn which expects ragged buffers with front-padding because the DSL kernel
uses negative pointer offsets; fix by creating front-padded temporary tensors
for q, k, v, o (e.g., pad_front_q = torch.empty(front_pad + q.size(0), ...);
copy q into the padded region) before calling kernel_fn, and pass their
.data_ptr() and the padded tensors (like q_4d) to the kernel invocation; update
both the enable_tvm_ffi branch around kernel_fn (the block using
q_4d/qo_indptr/etc.) and the similar branch later (lines 554-631) to ensure all
varlen launches use front-padded buffers and preserve device/dtype/contiguity
and stream detection.
- Around line 115-127: artifact_path may start with a leading slash when
DSL_FMHA_ARTIFACT_PATH is empty, causing FLASHINFER_CUBIN_DIR / artifact_path to
ignore the cache root; fix by ensuring artifact_path is always relative: build
it by joining DSL_FMHA_ARTIFACT_PATH and so_filename while stripping any
leading/trailing slashes or falling back to just so_filename (e.g., compute base
= DSL_FMHA_ARTIFACT_PATH.strip("/") and set artifact_path =
f"{base}/{so_filename}" if base else so_filename), then use local_path =
FLASHINFER_CUBIN_DIR / artifact_path and pass artifact_path to get_artifact.
In `@flashinfer/prefill.py`:
- Around line 1324-1338: The cute-dsl per-head scale guard currently rejects
1-element tensors because it runs before converting scales to scalar values;
change the logic to extract scalar scale values first (call _split_scale_param
for scale_q, scale_k, scale_v) and then enforce the per-head restriction only
when a scale is a tensor with numel() > 1, or alternatively detect and treat
1-element tensors as scalars before the is_float8/cute-dsl check so that
1-element/materialized default scales are accepted instead of raising in the
block that references is_float8, scale_q/scale_k/scale_v and _split_scale_param.
- Around line 3072-3093: The cute-dsl branch must reject wrapper plans that use
kv_layout="HND" for ragged prefill because run() forwards raw k/v into
cute_dsl_fmha_ragged_prefill() without reordering; in the cute-dsl handling
block (the code that checks self._backend == "cute-dsl") add a guard that checks
the kv_layout variable (and the plan ragged flag, e.g., plan.ragged or
is_ragged) and raise a ValueError if kv_layout == "HND" and the plan is ragged,
with a clear message like "cute-dsl ragged plans do not support
kv_layout='HND'".
📒 Files selected for processing (4)
- benchmarks/routines/attention.py
- flashinfer/attention_dsl/cute_dsl/fmha.py
- flashinfer/prefill.py
- tests/attention/test_trtllm_gen_attention.py
@limin2021 please share some perf numbers if possible. thanks!
- Update DSL_FMHA artifact path to latest CI build (b0adf88)
- Add aarch64 checksums for sm_100a, sm_103a, sm_110a
- Make artifact paths and checksums arch-aware (cpu_arch/sm_arch)
- Fix varlen kernel to always use non-persistent mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…_dsl to attention/cute_dsl

- attention.py → attention/_core.py (re-exported via attention/__init__.py)
- attention_dsl/cute_dsl/ → attention/cute_dsl/ (consistent with mla/cute_dsl/, fused_moe/cute_dsl/)
- Update relative imports in _core.py (. → ..)
- Update import paths in prefill.py and test_cute_dsl_fmha_prefill.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove bf16/fp16 from parametrize (FP8 suffices for cubin load validation)
- Update docstring to reflect artifactory-first usage

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…quirement

- Reuse _get_host_cpu_arch from artifacts.py instead of duplicating in fmha.py
- Add clarifying comment on DSL_FMHA_CHECKSUMS (manifest hash, not kernel hash)
- Document front-padding requirement in cute_dsl_fmha_ragged_prefill docstring
- Add backend parameter docs in trtllm_ragged_attention_deepseek with front-padding note

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…_dim_qk

- Assert query dtype is fp16/bf16/fp8_e4m3fn in cute-dsl path
- Remove duplicate head_dim_qk assignment in test_trtllm_gen_prefill

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add test_trtllm_gen_prefill_fp8 with DeepSeek-R1 config (H=128, 8K seqlen)
- Test both mla_dimensions (h192/h128), causal/non-causal, skip-softmax
- Remove standalone test_cute_dsl_fmha_prefill.py (moved to fmha-cubin-integration)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
/bot run
…attention_deepseek

Aligns with the naming convention used across other flashinfer functions, since this codepath uses the trtllm-gen module.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
/bot run
/bot run
Update DSL_FMHA path hash from b0adf88 to c770c91c to point at the latest cubin release on artifactory, and refresh all 6 checksums.txt SHA256 hashes (x86_64/aarch64 × sm_100a/sm_103a/sm_110a). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/bot run
The test was not updated when DSL_FMHA artifact subdirectories were added to get_subdir_file_list. Without the new mocks, the test tried to download checksums.txt from real URLs and failed with FileNotFoundError. Pin cpu_arch to x86_64 for deterministic mocks regardless of the runner architecture. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/bot run
Previously _get_gpu_arch() was called inside get_cute_dsl_fmha_kernel and relied on the current default CUDA device. With @functools.cache the arch was frozen at first call, so on heterogeneous multi-GPU nodes subsequent calls on a different-arch device would silently reuse the wrong cubin. Align with the pattern used by get_fp4_quantization_module: caller computes arch from the tensor's device and passes it as a parameter, making arch part of the cache key. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
📌 Description
feat: Integrate CuTe DSL FMHA cubin kernels into FlashInfer prefill backend
Summary
Key features
Files changed
Test plan
Performance
Setup: B200 (sm_100a), causal, H_q=H_k=128, tested using FI benchmark (CUDA Graph, cupti)
FP8 e4m3 (D=192):
FP8 e4m3 (D=128):
BF16 (D=128):
TODO
(1) support scalar as tensor dtype.
(2) support pdl
(3) remove front-padding for q/k/v/o tensors
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed, and all tests are passing (`unittest`, etc.).

Reviewer Notes