
feat: Integrate CuTe DSL FMHA prefill kernels by loading cubin#3039

Open
limin2021 wants to merge 28 commits into flashinfer-ai:main from limin2021:integrate_dsl_cubin_fmha

Conversation

@limin2021
Contributor

@limin2021 limin2021 commented Apr 12, 2026

📌 Description

feat: Integrate CuTe DSL FMHA cubin kernels into FlashInfer prefill backend

Summary

  • Integrate pre-compiled CuTe DSL FMHA kernels (Blackwell SM100/SM103/SM110) into FlashInfer's prefill attention backend
  • Load AOT-compiled .so cubins from the NVIDIA artifactory at runtime; no JIT compilation is needed
  • Route through trtllm_ragged_attention_deepseek() API with backend="cute-dsl"

Key features

  • Dtype support: FP16, BF16, FP8 (E4M3) input with mixed-precision output (E4M3→BF16)
  • Head dimensions: 32, 64, 128, 192 (192 for FP8 only)
  • Varlen ragged prefill: variable-length sequences via cumulative seqlen tensors
  • TVM-FFI ABI: all variants use TVM-FFI for kernel invocation
  • Skip-softmax sparsity: optional skip-softmax optimization for sparse attention
  • LSE output: optional log-sum-exp output for numerically stable multi-pass attention
  • Causal & non-causal masking: both modes supported (all varlen variants use non-persistent scheduling)
  • Multi-arch cubin loading: per-CPU-arch (x86_64/aarch64) and per-SM-arch artifact paths
  • Checksum verification: SHA256 integrity check on downloaded .so files
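The varlen ragged path above consumes cumulative-seqlen (indptr) tensors instead of a padded batch: entry `i` is the start offset of sequence `i` in the packed `[total_tokens, H, D]` layout. A minimal sketch (hypothetical helper name; in practice the result becomes an int32 device tensor):

```python
def make_cu_seqlens(seq_lens):
    """Build the cumulative-seqlen array consumed by varlen ragged
    prefill: cu[i] is the token offset where sequence i starts, and
    cu[-1] is the total token count across the batch."""
    cu = [0]
    for s in seq_lens:
        cu.append(cu[-1] + s)
    return cu

# Three ragged sequences of 32, 64, and 16 tokens:
make_cu_seqlens([32, 64, 16])  # → [0, 32, 96, 112]
```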

Files changed

  • flashinfer/attention_dsl/cute_dsl/fmha.py — kernel loading, variant selection, ragged prefill entry point
  • flashinfer/artifacts.py — artifact paths and checksums for DSL FMHA (x86_64 + aarch64 layout)
  • flashinfer/prefill.py — trtllm_ragged_attention_deepseek() cute-dsl backend integration
  • tests/attention/test_cute_dsl_fmha_prefill.py — correctness tests against PyTorch reference

Test plan

  • test_cute_dsl_fmha_prefill.py passes
  • test_trtllm_gen_attention.py::test_trtllm_gen_prefill -k "cute-dsl" passes
  • Benchmark via bench_cute_dsl_ragged.sh on target hardware
  • Verify cubin download + checksum verification on clean install
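The checksum step amounts to recomputing SHA256 over the downloaded .so and comparing it against the pinned digest; a stdlib sketch (helper name is illustrative, not the actual artifacts.py API):

```python
import hashlib

def verify_sha256(path, expected_hex):
    """Stream the file through SHA256 and compare against the pinned
    checksum; a mismatch should abort kernel loading."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == expected_hex
```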

Performance

Setup: B200 (sm_100a), causal, H_q=H_k=128, tested using FI benchmark (CUDA Graph, cupti)

FP8 e4m3 (D=192):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
| --- | --- | --- | --- | --- |
| 1×8K×8K | 1.521 | 1.619 | 1808 / 1698 | +6.4% |
| 1×8K×32K | 8.466 | 9.451 | 2273 / 2036 | +11.6% |
| 1×8K×64K | 17.796 | 19.869 | 2317 / 2075 | +11.7% |
| 4×512×82K | 6.397 | 7.286 | 2142 / 1880 | +13.9% |
| 4×1K×82K | 12.285 | 13.834 | 2224 / 1975 | +12.6% |

FP8 e4m3 (D=128):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
| --- | --- | --- | --- | --- |
| 1×8K×8K | 1.484 | 1.560 | 1481 / 1410 | +5.1% |
| 1×8K×32K | 7.666 | 8.998 | 2008 / 1711 | +17.4% |
| 1×8K×64K | 16.074 | 18.606 | 2052 / 1773 | +15.8% |
| 4×512×82K | 5.735 | 6.460 | 1911 / 1697 | +12.6% |
| 4×1K×82K | 11.066 | 12.451 | 1975 / 1755 | +12.5% |

BF16 (D=128):

| Shape (B×S_q×S_kv) | cute-dsl (ms) | trtllm-native (ms) | TFLOPS (dsl/native) | Speedup |
| --- | --- | --- | --- | --- |
| 1×8K×8K | 1.737 | 1.764 | 1266 / 1247 | +1.6% |
| 1×8K×32K | 10.094 | 10.992 | 1525 / 1400 | +8.9% |
| 1×8K×64K | 21.745 | 23.000 | 1517 / 1434 | +5.8% |
| 4×512×82K | 8.457 | 8.513 | 1296 / 1288 | +0.7% |
| 4×1K×82K | 15.773 | 16.052 | 1385 / 1361 | +1.8% |
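The Speedup column in these tables is the latency ratio of the two backends; a quick sketch for reproducing it from the reported times:

```python
def speedup_pct(dsl_ms, native_ms):
    """Speedup of cute-dsl over trtllm-native as a percentage:
    positive means cute-dsl is faster."""
    return (native_ms / dsl_ms - 1.0) * 100.0

round(speedup_pct(1.521, 1.619), 1)  # 6.4, the first FP8 D=192 row
round(speedup_pct(8.466, 9.451), 1)  # 11.6, the second row
```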

TODO

  • Support scalar-as-tensor dtype
  • Support PDL
  • Remove front-padding for q/k/v/o tensors

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

limin2021 and others added 6 commits April 12, 2026 00:15
Add cute-dsl backend support for single_prefill_with_kv_cache and
BatchPrefillWithRaggedKVCacheWrapper, loading pre-compiled DSL FMHA
kernels via ExternalBinaryModule. Pass through FP8 scale parameters
(scale_q, scale_k, scale_v, scale_o) to the DSL kernel instead of
hardcoding them as 1.0.

- flashinfer/attention_dsl/cute_dsl/: New module with kernel loader
  (fmha.py) supporting local .so and artifactory paths, plus PyTorch
  API wrappers for both fixed-length and variable-length (ragged) prefill
- flashinfer/prefill.py: Add "cute-dsl" backend branches in
  single_prefill_with_kv_cache and BatchPrefillWithRaggedKVCacheWrapper
- tests/attention/test_cute_dsl_fmha_prefill.py: 81 tests covering
  direct API, FlashInfer integration, ragged/varlen, GQA, and
  cross-backend comparison

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Add mark_layout_dynamic() for cute tensor conversion (FP8 via int8 view)
- Add FP8 direct prefill and ragged prefill tests (e4m3 in → fp16 out)
- Fix varlen ragged crash by padding tensors with max_seqlen (TMA overflow)
- Remove unused _to_cute_tensor helper

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add enable_tvm_ffi parameter to cute_dsl_fmha_prefill and
  cute_dsl_fmha_ragged_prefill (default True)
- TVM-FFI path: pass data_ptr() for Pointer args, torch.Tensor for
  Tensor args (cum_seqlen), no explicit stream (env stream)
- Add _tvmffi suffix to variant names to avoid overwriting native ABI .so
- Move imports to file top, use cuda_driver.CUstream for current stream
- Add enable_tvm_ffi parametrize to direct/FP8/ragged FP8 tests
- Add production-scale test shapes (8Kx8K, 8Kx32K, 4x1Kx80K)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Merge FP8 into main test functions via dtype parametrize (fp16/bf16/fp8)
- Extract helpers: _make_qkvo, _make_ragged_qkvo, _ragged_reference, _get_tolerances
- Remove duplicate test functions (9 → 6), reduce repeated code
- Add ragged shapes: GQA (H_q=16 H_k=4), long context, asymmetric long KV
- Asymmetric small seq_len ragged commented out (kernel TMA limitation)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The DSL FMHA kernel applies a negative offset to pointers in varlen mode
(q_offset = -max_s_q * H * D), so valid GPU memory must exist before
the data start. Changed from back-padding to front-padding for all
dtypes (fp16/bf16/fp8), matching the DSL example's create_and_pad_tensor.

This fixes:
- Non-FP8 ragged tests that crashed when run individually
- Asymmetric ragged (S_q != S_k) with small seq_lens
- Re-enabled asymmetric test case ([32,64,16] vs [128,256,64])

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
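The front-padding described in this commit can be sketched as follows: because the kernel applies a negative pointer offset (-max_s_q * H * D), the buffer must be allocated with that many extra elements *before* the packed tokens. This is an illustrative numpy sketch (helper name hypothetical; the real code pads torch tensors on device):

```python
import numpy as np

def front_pad(flat, max_s, H, D, dtype=np.float16):
    """Allocate max_s*H*D extra elements before the packed tokens so a
    kernel dereferencing (data_ptr - max_s*H*D) still reads valid
    memory. Returns (backing buffer, view of the real data)."""
    pad = max_s * H * D
    buf = np.zeros(pad + flat.size, dtype=dtype)
    buf[pad:] = flat
    return buf, buf[pad:]
```

The kernel would be handed the view's data pointer; after the launch, outputs are read from the same tail slice, so the padding is invisible to callers.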
- Register cute-dsl backend for BatchPrefillWithRaggedKVCacheWrapper
  in benchmark CC support table (SM10.0, SM10.3)
- Add cute-dsl to wrapper creation and run paths in attention benchmark
- Disable CUDA graph for cute-dsl (TVM-FFI env stream incompatible)
- Move max seq len computation from run() to plan() in prefill wrapper
  to avoid D2H copy during CUDA graph capture
- Add max_qo_len/max_kv_len params to cute_dsl_fmha_ragged_prefill

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai Bot commented Apr 12, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

📝 Walkthrough


Adds a CuTe-DSL attention backend: new CuTe FMHA kernel loader and TVM/native wrappers, a ragged prefill runtime for CuTe, integration into prefill and benchmark flows, benchmark/backend mappings updated, and GPU tests exercising the new backend and ragged prefill paths.

Changes

Cohort / File(s) Summary
Benchmarking & Wrapper Logic
benchmarks/routines/attention.py, benchmarks/flashinfer_benchmark_utils.py
Permit cute-dsl in supported backends and backend lists; add front-padding/slicing for Q/K/V and outputs when using cute-dsl; adjust CUDA-graph enablement and wrapper instantiation conditions.
Attention DSL Package Surface
flashinfer/attention_dsl/__init__.py, flashinfer/attention_dsl/cute_dsl/__init__.py
Add attention_dsl package and cute_dsl subpackage; conditional exports based on CuTe-DSL availability.
CuTe-DSL FMHA Implementation
flashinfer/attention_dsl/cute_dsl/fmha.py
New module: variant naming, artifact/local binary lookup and checksum validation, memoized kernel retrieval, and runtime API cute_dsl_fmha_ragged_prefill (TVM-FFI and native CuTe-iterator invocation, FP8 handling, varlen metadata).
Prefill Integration & API
flashinfer/prefill.py
Add backend: str param to trtllm_ragged_attention_deepseek; new cute-dsl execution branch that adapts scale arguments and calls CuTe ragged prefill; update FP8 output dtype inference and out-dtype caching.
GPU Tests & Test Harness
tests/attention/test_cute_dsl_fmha_prefill.py, tests/attention/test_trtllm_gen_attention.py
Add CuTe-DSL FMHA ragged prefill tests (device-only, local-artifact gating); parametrize trtllm tests over backend, add cute-dsl front-padding/slicing for inputs/outputs, and relax workspace-zero assertion for non-TRT backends.

Sequence Diagram

sequenceDiagram
    participant App as PyTorch App
    participant Prefill as Prefill API
    participant Loader as DSL Kernel Loader
    participant CuTe as CuTe Runtime
    participant CUDA as CUDA Device

    App->>Prefill: trtllm_ragged_attention_deepseek(..., backend="cute-dsl")
    Prefill->>Prefill: validate params, compute scales, prepare (pad/slice) tensors
    Prefill->>Loader: get_cute_dsl_fmha_kernel(dtypes, head_dim, causal, with_lse, varlen)
    Loader->>Loader: check local dir or artifact, verify checksum, load module
    Loader-->>Prefill: kernel callable
    Prefill->>CuTe: call kernel (TVM-FFI or native iterators) with ragged indptrs
    CuTe->>CUDA: launch kernel
    CUDA-->>CuTe: complete
    CuTe-->>Prefill: outputs written (slice off padding)
    Prefill-->>App: return outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested labels

ready, benchmark

Suggested reviewers

  • cyx-6
  • nvmbreughe
  • jimmyzho
  • jiahanc
  • kahyunnam
  • yzh119
  • bkryu
  • joker-eph

Poem

🐰 I found a tiny cubin bright,
kernels hum in midnight light,
ragged lengths I pad and slice,
scales aligned and kernels nice,
hop—prefill done, all set for flight!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Docstring Coverage | ✅ Passed | Docstring coverage is 92.31%, which is sufficient. The required threshold is 80.00%. |
| Title check | ✅ Passed | The title 'feat: Integrate CuTe DSL FMHA prefill kernels by loading cubin' accurately summarizes the main change: integrating pre-compiled CuTe DSL FMHA kernels as a backend for prefill, with emphasis on kernel loading from cubin files. |
| Description check | ✅ Passed | The pull request description is comprehensive and well-structured, addressing the template requirements with detailed summary, key features, files changed, test plan, performance data, and checklist items. |



@limin2021 limin2021 changed the title from "Integrate cute dsl fmha (cubin)" to "feat: Integrate CuTe DSL FMHA cubin kernels into prefill backend" on Apr 12, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces the "cute-dsl" backend for FMHA prefill kernels, specifically targeting SM10x (Blackwell) architectures. The changes include a new module for loading pre-compiled binary artifacts via ExternalBinaryModule, integration into the existing single and batch prefill APIs, and a comprehensive test suite. Feedback focuses on improving path construction safety in the artifact loader, avoiding performance-degrading GPU-CPU synchronizations during the planning phase, and addressing a known failing test case marked with a TODO.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_fmha_prefill.py (1)

309-315: Pin the reference backend instead of relying on auto.

These tests are supposed to compare cute-dsl against a different implementation, but auto is a moving target. If backend selection ever resolves auto to cute-dsl for these shapes, the assertions become self-comparisons and stop validating cross-backend correctness.

Suggested fix
-        backend="auto",
+        backend="fa3",

If you want a separate dispatch test for auto, keep that as its own test and assert the resolved backend explicitly.

Also applies to: 550-564

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 1846-1847: The "cute-dsl" branch drops FP8 scale metadata: instead
of calling backend_wrappers[backend].run(q, k, v) it must forward q_scale,
k_scale, and v_scale so the wrapper sees FP8 scales; update the call in the
backend == "cute-dsl" branch to pass the scale variables (e.g., run(q, k, v,
q_scale, k_scale, v_scale) or as named args) and ensure
backend_wrappers["cute-dsl"].run signature accepts and uses those scale
parameters to preserve correct FP8 math for the `q`, `k`, and `v` tensors.
- Around line 1754-1761: The timer path still enables CUDA graph capture for
some backends; update the bench_gpu_time(...) call to mirror the wrapper logic
by passing use_cuda_graph only when backend not in {"fa2", "cute-dsl"}. Locate
the call site that currently passes use_cuda_graph=True for cute-dsl (near the
bench_gpu_time invocation) and change the argument to
use_cuda_graph=(is_cuda_graph_compatible if backend not in {"fa2", "cute-dsl"}
else False), keeping the same variable names and preserving behavior for other
backends and is_cuda_graph_compatible; this ensures
flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper and bench_gpu_time use
the same disable-check.

In `@flashinfer/attention_dsl/cute_dsl/__init__.py`:
- Around line 21-23: The import of is_cute_dsl_available from
flashinfer.cute_dsl.utils causes eager execution of that module (and its bare
top-level import cutlass); move the is_cute_dsl_available implementation into
flashinfer.attention_dsl.cute_dsl.__init__.py and use a local function that
wraps import cutlass in a try/except returning a boolean, then replace the
top-level from flashinfer.cute_dsl.utils import is_cute_dsl_available with the
local definition and ensure any other code in this package uses this local
is_cute_dsl_available to gate CUTLASS-dependent imports or logic (referencing
the is_cute_dsl_available symbol and the module-level import gate in
__init__.py).

In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 113-125: The artifact path construction can produce a leading
slash when DSL_FMHA_ARTIFACT_PATH is empty, causing FLASHINFER_CUBIN_DIR /
artifact_path to escape the cache; fix by building the artifact path with
path-safe operations: ensure you join the directory and so_filename using Path
objects or strip any leading slashes from DSL_FMHA_ARTIFACT_PATH before
combining so_filename, then assign artifact_path and compute local_path =
FLASHINFER_CUBIN_DIR / artifact_path (or directly local_path =
FLASHINFER_CUBIN_DIR / Path(DSL_FMHA_ARTIFACT_PATH) / so_filename) so
get_artifact(artifact_path, sha256) and subsequent filesystem operations never
escape FLASHINFER_CUBIN_DIR; reference symbols: DSL_FMHA_ARTIFACT_PATH,
variant_name, so_filename, artifact_path, FLASHINFER_CUBIN_DIR, local_path,
get_artifact.

In `@flashinfer/prefill.py`:
- Around line 3071-3077: The cute-dsl compute-capability check in plan()
incorrectly queries get_compute_capability(qo_indptr.device) which fails when
qo_indptr is on CPU; change it to query the actual CUDA device used for
execution (e.g., use self.device or the wrapper CUDA device object) so
get_compute_capability(...) is called on the real GPU device rather than
qo_indptr.device; update the block around get_compute_capability and the
RuntimeError message to use that device’s compute capability (retain the same
error formatting and function name get_compute_capability to locate the code).
- Around line 1324-1338: The code currently accepts 1-element torch.Tensor
scales and then calls _split_scale_param(scale_q/scale_k/scale_v), which treats
any tensor input as a tensor and returns (tensor, 1.0), effectively dropping the
scalar value; fix by detecting 1-element tensors (isinstance(..., torch.Tensor)
and s.numel() == 1) and replace them with their Python scalar (e.g., s.item() or
float(s.item())) before calling _split_scale_param so the actual scalar scale is
preserved for the cute-dsl path; update the block handling
scale_q/scale_k/scale_v (and keep the existing multi-element validation) to
perform this conversion.

In `@tests/attention/test_cute_dsl_fmha_prefill.py`:
- Around line 452-454: The test matrix includes a known-failing asymmetric case
tuple ([32, 64, 16], [128, 256, 64], 8, 8) which is causing the suite to fail;
either remove that tuple from the default parametrization or mark it as an
expected failure using pytest.param(...,
marks=pytest.mark.xfail(reason="asymmetric S_q < S_k known issue", strict=False,
reason_or_issue="<link/issue-id>")) so the suite stays green while tracking the
bug; update the parametrization where the tuples are defined to apply this
change around the ([32, 64, 16], [128, 256, 64], 8, 8) entry.



📥 Commits

Reviewing files that changed from the base of the PR and between a1166dc and 27cd0aa.

📒 Files selected for processing (7)
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/attention_dsl/__init__.py
  • flashinfer/attention_dsl/cute_dsl/__init__.py
  • flashinfer/attention_dsl/cute_dsl/fmha.py
  • flashinfer/prefill.py
  • tests/attention/test_cute_dsl_fmha_prefill.py

@nvpohanh
Contributor

cc @leejnau for cute dsl prefill MLA kernels

@nvpohanh
Contributor

If this is for DSR1 MLA prefill, please connect it to the trtllm_ragged_attention_deepseek() API. This is what the frameworks are using for MLA prefill on SM10x GPUs.
https://github.com/limin2021/flashinfer/blob/27cd0aab33fe043d7d17e9f1d9fee4ef3a4c3b38/flashinfer/prefill.py#L3696

…mpat

- Preload cute-dsl kernel .so in plan() for fail-fast and reuse in run()
- Default FP8 input to bf16 output dtype in plan() (not just run())
- Use host-side indptr for max seq len to avoid D2H during graph capture
- Use self.device instead of qo_indptr.device for compute capability check
- Add front-padding for cute-dsl varlen kernel in benchmark
- Enable CUDA graph for cute-dsl backend in benchmark
- Re-enable asymmetric (S_q < S_k) test case

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
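The "host-side indptr for max seq len" change in this commit can be sketched as: keep a CPU copy of the indptr in plan() and derive the max sequence length from its adjacent differences, so run() never issues a device-to-host copy during CUDA graph capture. Illustrative helper, plain Python standing in for the host tensor:

```python
def max_seq_len_from_indptr(indptr_host):
    """Max sequence length implied by a cumulative-seqlen array:
    the largest gap between adjacent indptr entries."""
    return max(
        (b - a for a, b in zip(indptr_host, indptr_host[1:])),
        default=0,
    )

max_seq_len_from_indptr([0, 32, 96, 112])  # → 64
```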
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (4)
benchmarks/routines/attention.py (2)

1868-1869: ⚠️ Potential issue | 🟠 Major

Forward FP8 scales through the cute-dsl benchmark call.

q, k, and v are quantized above, but this branch drops q_scale, k_scale, and v_scale. For FP8 cases, the cute-dsl benchmark and refcheck path will run with the wrong dequantization math.

Suggested fix
         elif backend == "cute-dsl":
-            return backend_wrappers[backend].run(q, k, v)
+            return backend_wrappers[backend].run(
+                q,
+                k,
+                v,
+                q_scale=q_scale,
+                k_scale=k_scale,
+                v_scale=v_scale,
+            )

1776-1783: ⚠️ Potential issue | 🟠 Major

Keep CUDA graphs disabled for cute-dsl in both the wrapper and timer path.

The PR explicitly treats cute-dsl as non-graph-compatible, but these checks still enable graph capture for it. That leaves the benchmark exercising the TVM-FFI path you were trying to exclude.

Suggested fix
                 flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
                     workspace_buffer,
                     "NHD",
                     use_cuda_graph=is_cuda_graph_compatible
-                    if backend not in ["fa2"]
+                    if backend not in ["fa2", "cute-dsl"]
                     else False,
                     qo_indptr_buf=qo_indptr,
                     kv_indptr_buf=kv_indptr,
                     backend=backend,
                 )
-            use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2"]),
+            use_cuda_graph=(
+                is_cuda_graph_compatible
+                and cur_backend not in ["fa2", "cute-dsl"]
+            ),

Also applies to: 1954-1960

flashinfer/attention_dsl/cute_dsl/fmha.py (1)

113-125: ⚠️ Potential issue | 🟠 Major

Prevent artifact_path from escaping the cubin cache root.

When FLASHINFER_DSL_FMHA_ARTIFACT_PATH is unset, Line 114 becomes "/<variant>.so". FLASHINFER_CUBIN_DIR / artifact_path then resolves outside the cache directory, so this loader can read/write the artifact from the wrong location.

Suggested fix
     so_filename = f"{variant_name}.so"
-    artifact_path = f"{DSL_FMHA_ARTIFACT_PATH}/{so_filename}"
+    artifact_path = (
+        f"{DSL_FMHA_ARTIFACT_PATH.rstrip('/')}/{so_filename}"
+        if DSL_FMHA_ARTIFACT_PATH
+        else so_filename
+    )
     sha256 = DSL_FMHA_CHECKSUMS.get(variant_name, "")
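The path-escape concern raised here can be handled defensively with pathlib: strip leading slashes from the (possibly empty) artifact prefix, then assert the resolved path stays under the cache root. Hypothetical helper whose names mirror the review comment, not the actual fmha.py code:

```python
from pathlib import Path, PurePosixPath

def safe_artifact_path(cubin_dir, artifact_subdir, so_filename):
    """Join the cache root with an artifact prefix that may be empty or
    start with '/', refusing any result that escapes the cache root."""
    rel = PurePosixPath(artifact_subdir.strip("/")) / so_filename
    root = Path(cubin_dir).resolve()
    local = (root / rel).resolve()
    if local != root and root not in local.parents:
        raise RuntimeError(f"artifact path escapes cache root: {local}")
    return local
```

With an empty prefix this yields `<root>/<so_filename>` rather than the absolute `/<so_filename>` the original string concatenation produced, and `..` segments are rejected by the containment check.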
flashinfer/prefill.py (1)

1324-1338: ⚠️ Potential issue | 🟠 Major

Normalize FP8 scale inputs before the cute-dsl dispatch.

This branch still breaks the FP8 user path in two ways: omitted scales were already expanded to per-head tensors above and now fail this validation, while a 1-element tensor that passes validation is then turned into 1.0 by _split_scale_param(). That means either an unexpected ValueError or silently wrong scaling.

Suggested fix
+        def _scalarize_scale(scale, name: str) -> float:
+            if scale is None:
+                return 1.0
+            if isinstance(scale, torch.Tensor):
+                if scale.numel() == 1:
+                    return float(scale.item())
+                if torch.all(scale == scale.reshape(-1)[0]).item():
+                    return float(scale.reshape(-1)[0].item())
+                raise ValueError(
+                    f"cute-dsl backend does not support per-head scale tensors ({name}), "
+                    "only per-tensor scalar scales are supported"
+                )
+            return float(scale)
+
-        if is_float8(q):
-            for s, name in (
-                (scale_q, "scale_q"),
-                (scale_k, "scale_k"),
-                (scale_v, "scale_v"),
-            ):
-                if isinstance(s, torch.Tensor) and s.numel() > 1:
-                    raise ValueError(
-                        f"cute-dsl backend does not support per-head scale tensors ({name}), "
-                        "only per-tensor scalar scales are supported"
-                    )
-        # Extract scalar scale values for DSL kernel
-        _, sq = _split_scale_param(scale_q)
-        _, sk = _split_scale_param(scale_k)
-        _, sv = _split_scale_param(scale_v)
+        sq = _scalarize_scale(scale_q, "scale_q")
+        sk = _scalarize_scale(scale_k, "scale_k")
+        sv = _scalarize_scale(scale_v, "scale_v")
🧹 Nitpick comments (1)
flashinfer/attention_dsl/cute_dsl/fmha.py (1)

466-468: Drop the unused total_kv binding.

total_kv is never read, so this now fails Ruff and adds noise in the wrapper.

Suggested fix
-    total_q, H_q, D = q.shape
-    total_kv, H_k, _ = k.shape
+    total_q, H_q, D = q.shape
+    _, H_k, _ = k.shape
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/prefill.py`:
- Around line 3072-3093: When selecting the "cute-dsl" backend in the prefill
plan, reject wrapper layouts using HND to avoid silently feeding misordered
tensors into cute_dsl_fmha_ragged_prefill; add a check alongside the other
backend validations (the block that checks get_compute_capability,
pos_encoding_mode, packed_custom_mask, logits_soft_cap, use_fp16_qk_reduction)
to raise a ValueError if self._kv_layout (or the local kv_layout variable) ==
"HND", referencing self._kv_layout and the cute_dsl_fmha_ragged_prefill behavior
that ignores the wrapper layout.
- Around line 3334-3355: The DSL ragged kernel expects buffers to be
front-padded because it uses negative pointer offsets, but the current call to
cute_dsl_fmha_ragged_prefill forwards user tensors that start at storage offset
0; to fix, allocate padded versions of q, k, v, and out with extra prefix space
(pad length = max_len or total ragged padding as the tests do), copy the
original tensors into the tail slice of those padded tensors, adjust any
indptr/pointer buffers if needed, and pass these padded/tail-sliced tensors to
cute_dsl_fmha_ragged_prefill instead of the original q/k/v/out; reference the
call site around cute_dsl_fmha_ragged_prefill and the buffers
_qo_indptr_buf/_kv_indptr_buf to locate where to insert the padding/wrapping
logic.

In `@tests/attention/test_cute_dsl_fmha_prefill.py`:
- Around line 20-37: The skip condition currently blocks all SM10x devices by
calling is_sm100a_supported; replace that check with a function that accepts the
whole SM10x family (e.g., is_sm10x_supported) so SM10.3 is allowed: in the
pytestmark list change the second skipif predicate from "not
torch.cuda.is_available() or not is_sm100a_supported(torch.device('cuda'))" to
"not torch.cuda.is_available() or not is_sm10x_supported(torch.device('cuda'))"
(or, if such helper doesn't exist, implement a small helper in flashinfer.utils
that checks CUDA compute capability major==10 and call it here). Ensure you
reference the existing symbol is_sm100a_supported in the change so reviewers can
locate and replace it with is_sm10x_supported (or the new helper).
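The suggested helper is simple to sketch. A minimal version, assuming the capability tuple comes from torch.cuda.get_device_capability(device); the name is_sm10x and the tuple-based signature are hypothetical, chosen here so the check is testable without a GPU:

```python
def is_sm10x(capability: "tuple[int, int]") -> bool:
    """Accept the whole SM10x (Blackwell) family, e.g. (10, 0) and (10, 3).

    An SM100a-only check would wrongly exclude SM10.3 devices.
    """
    major, _minor = capability
    return major == 10
```

In the test file this would replace the is_sm100a_supported predicate in the pytestmark skipif.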

---

Duplicate comments:
In `@benchmarks/routines/attention.py`:
- Around line 1868-1869: The cute-dsl branch in benchmarks/routines/attention.py
drops FP8 dequantization scales; update the backend == "cute-dsl" branch to pass
q_scale, k_scale, and v_scale through to backend_wrappers[backend].run so the
cute-dsl benchmark and its refcheck receive the FP8 scale values (mirror how
other backends are called), ensuring run(...) accepts and uses these additional
arguments for correct dequantization math.
- Around line 1776-1783: The CUDA-graph flag is still being enabled for
"cute-dsl"; update the graph-compatibility checks so CUDA graphs are disabled
for both "fa2" and "cute-dsl". Concretely, change usages of
is_cuda_graph_compatible in the BatchPrefillWithRaggedKVCacheWrapper call
(backend_wrappers[backend]) to use is_cuda_graph_compatible if backend not in
["fa2", "cute-dsl"] else False, and make the identical change in the timer path
where the timer/wrapper is constructed (the counterpart block around lines
~1954-1960) so both the wrapper and timer consistently treat "cute-dsl" as
non-graph-compatible. Ensure you reference and update the same
is_cuda_graph_compatible conditional in both places.

In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 113-125: The code builds artifact_path from DSL_FMHA_ARTIFACT_PATH
and variant_name which can produce a leading slash and allow
FLASHINFER_CUBIN_DIR / artifact_path to escape the cache; update the logic
around artifact_path/local_path in fmha loader (referencing variant_name,
artifact_path, FLASHINFER_CUBIN_DIR, FLASHINFER_DSL_FMHA_ARTIFACT_PATH,
get_artifact, local_path) to ensure artifact_path is normalized to a relative
path (strip any leading slashes, resolve .. segments) and then construct
local_path and assert that
local_path.resolve().is_relative_to(FLASHINFER_CUBIN_DIR.resolve()) (or compare
prefixes) before calling get_artifact; if the check fails, raise a clear
RuntimeError.

In `@flashinfer/prefill.py`:
- Around line 1324-1338: Normalize FP8 scale inputs before the cute-dsl
dispatch: call _split_scale_param(scale_q/scale_k/scale_v) first to extract the
original param and the scalar component (e.g., _, sq =
_split_scale_param(scale_q)) and then perform the per-head tensor validation
using the original param (the first return) rather than the possibly collapsed
value; this ensures omitted scales that were expanded earlier are normalized to
scalars and 1-element tensors are treated as scalars for the DSL kernel while
still raising ValueError for true per-head tensors.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: dc23d733-dca0-4057-b2b1-e199e8144a8a

📥 Commits

Reviewing files that changed from the base of the PR and between 27cd0aa and 7162590.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py
  • flashinfer/attention_dsl/cute_dsl/fmha.py
  • flashinfer/prefill.py
  • tests/attention/test_cute_dsl_fmha_prefill.py

Comment thread flashinfer/prefill.py Outdated
Comment thread flashinfer/prefill.py Outdated
Comment thread tests/attention/test_cute_dsl_fmha_prefill.py Outdated
…E and benchmark support

- Add `backend` param to trtllm_ragged_attention_deepseek to dispatch cute-dsl
- Add `with_lse` to DSL kernel variant selection for correct .so loading
- Front-pad output tensor for DSL varlen negative pointer offsets
- Add cute-dsl backend parametrize to test_trtllm_gen_prefill tests
- Route benchmark cute-dsl through trtllm_ragged_attention_deepseek directly

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (5)
flashinfer/attention_dsl/cute_dsl/fmha.py (2)

523-550: ⚠️ Potential issue | 🔴 Critical

Front-pad ragged buffers inside this helper before launching the varlen kernel.

The DSL ragged kernel uses negative pointer offsets, but this wrapper still forwards the user views directly. The current tests and benchmarks only stay safe because they pass slices from front-padded allocations; a normal contiguous q/k/v/o tensor here can underflow the allocation on the first access.

Also applies to: 554-631

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 523 - 550, The
wrapper is passing user views q/k/v/o directly to kernel_fn which expects ragged
buffers with front-padding because the DSL kernel uses negative pointer offsets;
fix by creating front-padded temporary tensors for q, k, v, o (e.g., pad_front_q
= torch.empty(front_pad + q.size(0), ...); copy q into the padded region) before
calling kernel_fn, and pass their .data_ptr() and the padded tensors (like q_4d)
to the kernel invocation; update both the enable_tvm_ffi branch around kernel_fn
(the block using q_4d/qo_indptr/etc.) and the similar branch later (lines
554-631) to ensure all varlen launches use front-padded buffers and preserve
device/dtype/contiguity and stream detection.
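As a rough illustration of the proposed fix, here is a minimal front-padding sketch. numpy stands in for torch so the idea stays self-contained, and front_pad and the pad size are hypothetical — the real wrapper would allocate torch tensors on the kernel's device and keep the storage alive across the launch:

```python
import numpy as np

def front_pad(arr: np.ndarray, pad_rows: int):
    """Return (storage, view): `view` holds arr's data but starts
    pad_rows rows into `storage`, so kernel reads at small negative
    offsets from the view's base pointer stay inside the allocation."""
    storage = np.zeros((pad_rows + arr.shape[0],) + arr.shape[1:], dtype=arr.dtype)
    storage[pad_rows:] = arr
    return storage, storage[pad_rows:]

# q has 8 ragged rows of shape (H=4, D=64); pad 16 rows of prefix space.
q = np.ones((8, 4, 64), dtype=np.float16)
q_storage, q_padded = front_pad(q, pad_rows=16)
# pass q_padded to the varlen kernel; q_storage must outlive the launch
```

The same treatment would apply to k, v, and the output buffer in both the TVM-FFI branch and the later launch path.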

115-127: ⚠️ Potential issue | 🟠 Major

Keep artifact_path relative to FLASHINFER_CUBIN_DIR.

When FLASHINFER_DSL_FMHA_ARTIFACT_PATH is unset, this builds "/<variant>.so". FLASHINFER_CUBIN_DIR / artifact_path then ignores the cache root and reads/writes outside the intended cubin directory.

Suggested fix
-    artifact_path = f"{DSL_FMHA_ARTIFACT_PATH}/{so_filename}"
+    artifact_path = (
+        f"{DSL_FMHA_ARTIFACT_PATH.rstrip('/')}/{so_filename}"
+        if DSL_FMHA_ARTIFACT_PATH
+        else so_filename
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/attention_dsl/cute_dsl/fmha.py` around lines 115 - 127,
artifact_path may start with a leading slash when DSL_FMHA_ARTIFACT_PATH is
empty, causing FLASHINFER_CUBIN_DIR / artifact_path to ignore the cache root;
fix by ensuring artifact_path is always relative: build it by joining
DSL_FMHA_ARTIFACT_PATH and so_filename while stripping any leading/trailing
slashes or falling back to just so_filename (e.g., compute base =
DSL_FMHA_ARTIFACT_PATH.strip("/") and set artifact_path =
f"{base}/{so_filename}" if base else so_filename), then use local_path =
FLASHINFER_CUBIN_DIR / artifact_path and pass artifact_path to get_artifact.
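The underlying pathlib behavior is easy to demonstrate: joining onto an absolute right-hand side silently discards the left-hand cache root. The paths below are illustrative, not the real cache layout:

```python
from pathlib import PurePosixPath

cache_root = PurePosixPath("/root/.cache/flashinfer/cubins")

# With an empty DSL_FMHA_ARTIFACT_PATH, f"{prefix}/{name}" becomes
# "/<name>": an absolute path, which replaces the cache root on join.
escaped = cache_root / "/fmha_variant.so"
assert str(escaped) == "/fmha_variant.so"

# Stripping leading slashes keeps the artifact under the cache root.
base = "".strip("/")
artifact_path = f"{base}/fmha_variant.so" if base else "fmha_variant.so"
safe = cache_root / artifact_path
assert str(safe) == "/root/.cache/flashinfer/cubins/fmha_variant.so"
```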
flashinfer/prefill.py (2)

1324-1338: ⚠️ Potential issue | 🟠 Major

Cute-dsl FP8 scale handling still breaks the scalar/default case.

By the time this branch runs, omitted FP8 scales were already materialized as per-head tensors, so backend="cute-dsl" rejects the default 1.0 case. A 1-element tensor also passes this guard but _split_scale_param() converts it back to 1.0, dropping the caller’s actual scalar.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 1324 - 1338, The cute-dsl per-head scale
guard currently rejects 1-element tensors because it runs before converting
scales to scalar values; change the logic to extract scalar scale values first
(call _split_scale_param for scale_q, scale_k, scale_v) and then enforce the
per-head restriction only when a scale is a tensor with numel() > 1, or
alternatively detect and treat 1-element tensors as scalars before the
is_float8/cute-dsl check so that 1-element/materialized default scales are
accepted instead of raising in the block that references is_float8,
scale_q/scale_k/scale_v and _split_scale_param.
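A minimal sketch of the suggested ordering — collapse scalar-like inputs first, then reject only genuine per-head tensors. The function name mirrors but is not the real _split_scale_param, and the tensor protocol is reduced to numel()/item() for illustration:

```python
def normalize_fp8_scale(scale):
    """Return a Python float for None, numbers, and 1-element tensors;
    raise only for true per-head tensors (numel() > 1)."""
    if scale is None:
        return 1.0  # the default materialized scale collapses to 1.0
    if isinstance(scale, (int, float)):
        return float(scale)
    if scale.numel() == 1:
        return float(scale.item())  # keep the caller's actual scalar
    raise ValueError(
        "cute-dsl backend requires scalar FP8 scales; got a per-head tensor"
    )
```

Running this before the is_float8 guard means a 1-element tensor passes through as the caller's scalar instead of being silently replaced by 1.0.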

3072-3093: ⚠️ Potential issue | 🟠 Major

Reject kv_layout="HND" for ragged cute-dsl plans.

run() forwards raw k/v tensors into cute_dsl_fmha_ragged_prefill() and never reorders them for the wrapper layout. A cute-dsl wrapper planned with HND will therefore feed misordered tensors to the kernel.

Suggested fix
             if self._backend == "cute-dsl":
                 from .utils import get_compute_capability

                 cc = get_compute_capability(self.device)
                 if cc[0] != 10:
                     raise RuntimeError(
                         f"cute-dsl backend (FMHA prefill kernel) requires SM10x (Blackwell), got SM{cc[0]}{cc[1]}"
                     )
+                if self._kv_layout != "NHD":
+                    raise ValueError("cute-dsl backend only supports NHD layout")
                 if pos_encoding_mode != "NONE":
                     raise ValueError(
                         f"cute-dsl backend does not support pos_encoding_mode={pos_encoding_mode}"
                     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/prefill.py` around lines 3072 - 3093, The cute-dsl branch must
reject wrapper plans that use kv_layout="HND" for ragged prefill because run()
forwards raw k/v into cute_dsl_fmha_ragged_prefill() without reordering; in the
cute-dsl handling block (the code that checks self._backend == "cute-dsl") add a
guard that checks the kv_layout variable (and the plan ragged flag, e.g.,
plan.ragged or is_ragged) and raise a ValueError if kv_layout == "HND" and the
plan is ragged, with a clear message like "cute-dsl ragged plans do not support
kv_layout='HND'".
benchmarks/routines/attention.py (1)

1981-2006: ⚠️ Potential issue | 🟠 Major

Disable CUDA graph capture in both cute-dsl timer paths.

These benchmark calls still pass use_cuda_graph=True for cute-dsl, so non-CUPTI runs can graph-capture the TVM-FFI path you were trying to exclude.

Suggested fix
-            use_cuda_graph=(is_cuda_graph_compatible and cur_backend not in ["fa2"]),
+            use_cuda_graph=(
+                is_cuda_graph_compatible
+                and cur_backend not in {"fa2", "cute-dsl"}
+            ),

Apply the same change to the MLA bench_gpu_time(...) call as well.

Also applies to: 2443-2463

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 1981 - 2006, The
bench_gpu_time calls for the TVM/FFI ("cute-dsl") backends still allow CUDA
graph capture; update the use_cuda_graph argument in the bench_gpu_time
invocation(s) (e.g., the call with fn=run_backend_wrapper using cur_backend, and
the MLA bench_gpu_time call) so that use_cuda_graph is False when cur_backend ==
"cute-dsl" (e.g., use_cuda_graph=(is_cuda_graph_compatible and cur_backend not
in ["fa2","cute-dsl"])). Ensure both occurrences (the shown run_backend_wrapper
call and the MLA call around lines 2443-2463) are changed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/attention_dsl/cute_dsl/fmha.py`:
- Around line 303-306: The code always loads the non-LSE cubin because
get_cute_dsl_fmha_kernel is invoked without selecting the LSE symbol; update the
call site in the kernel acquisition (kernel_fn = get_cute_dsl_fmha_kernel(...))
to request the fixed-length LSE variant when LSE/return_lse is requested (e.g.,
pass the lse/return_lse flag or construct/choose the symbol name with the "_lse"
suffix) so that the returned kernel_fn points to the correct LSE-enabled symbol
rather than the default non-LSE cubin.

In `@flashinfer/prefill.py`:
- Around line 3094-3105: The preload in plan() always sets self._cached_module
using get_cute_dsl_fmha_kernel without with_lse, which forces
run(return_lse=True) to reuse the non-LSE cubin; change the caching so you store
kernels keyed by the with_lse flag (e.g., a dict on self like
self._cached_modules[(q_data_type,o_data_type,head_dim_qk,causal,with_lse)]) and
call get_cute_dsl_fmha_kernel with with_lse=True when run requests
return_lse=True, ensuring run() looks up the correct cached variant instead of
always using self._cached_module.
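The suggested per-variant caching amounts to making with_lse part of the lookup key; a toy sketch, where the symbol naming is illustrative rather than the real cubin layout:

```python
import functools

@functools.lru_cache(maxsize=None)
def get_fmha_kernel(q_dtype: str, head_dim: int, causal: bool, with_lse: bool) -> str:
    """Resolve a kernel symbol; every flag combination gets its own cache
    entry, so a plan that preloaded the non-LSE variant can never be
    handed back for run(return_lse=True)."""
    mask = "causal" if causal else "full"
    suffix = "_lse" if with_lse else ""
    return f"fmha_{q_dtype}_d{head_dim}_{mask}{suffix}"
```

plan() would warm both variants (or just the one it expects), and run() would look up the variant matching its own return_lse flag.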


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f6d09c2b-8008-4f12-8f9c-a7f7e71e9dbb

📥 Commits

Reviewing files that changed from the base of the PR and between 7162590 and f574404.

📒 Files selected for processing (4)
  • benchmarks/routines/attention.py
  • flashinfer/attention_dsl/cute_dsl/fmha.py
  • flashinfer/prefill.py
  • tests/attention/test_trtllm_gen_attention.py

Comment thread flashinfer/attention_dsl/cute_dsl/fmha.py Outdated
Comment thread flashinfer/prefill.py Outdated
@nvpohanh
Contributor

@limin2021 please share some perf numbers if possible. thanks!

- Update DSL_FMHA artifact path to latest CI build (b0adf88)
- Add aarch64 checksums for sm_100a, sm_103a, sm_110a
- Make artifact paths and checksums arch-aware (cpu_arch/sm_arch)
- Fix varlen kernel to always use non-persistent mode

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@limin2021 limin2021 requested a review from qsang-nv as a code owner April 16, 2026 05:42
limin2021 and others added 3 commits April 15, 2026 23:09
…_dsl to attention/cute_dsl

- attention.py → attention/_core.py (re-exported via attention/__init__.py)
- attention_dsl/cute_dsl/ → attention/cute_dsl/ (consistent with mla/cute_dsl/, fused_moe/cute_dsl/)
- Update relative imports in _core.py (. → ..)
- Update import paths in prefill.py and test_cute_dsl_fmha_prefill.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove bf16/fp16 from parametrize (FP8 suffices for cubin load validation)
- Update docstring to reflect artifactory-first usage

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…quirement

- Reuse _get_host_cpu_arch from artifacts.py instead of duplicating in fmha.py
- Add clarifying comment on DSL_FMHA_CHECKSUMS (manifest hash, not kernel hash)
- Document front-padding requirement in cute_dsl_fmha_ragged_prefill docstring
- Add backend parameter docs in trtllm_ragged_attention_deepseek with front-padding note

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@limin2021 limin2021 changed the title Title: feat: Integrate CuTe DSL FMHA cubin kernels into prefill backend feat: Integrate CuTe DSL FMHA cubin kernels into prefill backend Apr 16, 2026
@limin2021 limin2021 changed the title feat: Integrate CuTe DSL FMHA cubin kernels into prefill backend feat: Integrate CuTe DSL FMHA prefill kernels by loading from cubin Apr 16, 2026
@limin2021 limin2021 changed the title feat: Integrate CuTe DSL FMHA prefill kernels by loading from cubin feat: Integrate CuTe DSL FMHA prefill kernels by loading cubin Apr 16, 2026
limin2021 and others added 2 commits April 16, 2026 00:41
…_dim_qk

- Assert query dtype is fp16/bf16/fp8_e4m3fn in cute-dsl path
- Remove duplicate head_dim_qk assignment in test_trtllm_gen_prefill

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add test_trtllm_gen_prefill_fp8 with DeepSeek-R1 config (H=128, 8K seqlen)
- Test both mla_dimensions (h192/h128), causal/non-causal, skip-softmax
- Remove standalone test_cute_dsl_fmha_prefill.py (moved to fmha-cubin-integration)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@aleozlx aleozlx added the run-ci label Apr 16, 2026
@aleozlx
Collaborator

aleozlx commented Apr 16, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !559 has been created, and the CI pipeline #48673616 is currently running. I'll report back once the pipeline job completes.

Comment thread flashinfer/prefill.py Outdated
…attention_deepseek

Aligns with the naming convention used across other flashinfer functions,
since this codepath uses the trtllm-gen module.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@limin2021
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !559 has been updated with latest changes, and the CI pipeline #48944891 is currently running. I'll report back once the pipeline job completes.

@limin2021
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !559 has been updated with latest changes, and the CI pipeline #48952893 is currently running. I'll report back once the pipeline job completes.

Update DSL_FMHA path hash from b0adf88 to c770c91c to point at the
latest cubin release on artifactory, and refresh all 6 checksums.txt
SHA256 hashes (x86_64/aarch64 × sm_100a/sm_103a/sm_110a).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@limin2021
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !559 has been updated with latest changes, and the CI pipeline #48957301 is currently running. I'll report back once the pipeline job completes.

The test was not updated when DSL_FMHA artifact subdirectories were
added to get_subdir_file_list. Without the new mocks, the test tried
to download checksums.txt from real URLs and failed with
FileNotFoundError. Pin cpu_arch to x86_64 for deterministic mocks
regardless of the runner architecture.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@limin2021
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !559 has been updated with latest changes, and the CI pipeline #49051547 is currently running. I'll report back once the pipeline job completes.

Comment thread flashinfer/attention/cute_dsl/fmha.py
Comment thread flashinfer/prefill.py
yzh119 and others added 2 commits April 21, 2026 00:42
Previously _get_gpu_arch() was called inside get_cute_dsl_fmha_kernel
and relied on the current default CUDA device. With @functools.cache
the arch was frozen at first call, so on heterogeneous multi-GPU nodes
subsequent calls on a different-arch device would silently reuse the
wrong cubin.

Align with the pattern used by get_fp4_quantization_module: caller
computes arch from the tensor's device and passes it as a parameter,
making arch part of the cache key.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
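The commit's pattern — compute the arch at the call site and make it part of the cache key — can be sketched as follows; the loader name and return value are hypothetical:

```python
import functools

@functools.cache
def _load_cubin(sm_arch: str) -> str:
    """sm_arch is part of the cache key, so each architecture gets its
    own entry even if the first call happened on a different-arch GPU."""
    return f"kernel-for-{sm_arch}"

def get_kernel_for_device(capability: "tuple[int, int]") -> str:
    # The caller derives the arch from the tensor's device (e.g. via
    # torch.cuda.get_device_capability) and passes it down explicitly,
    # instead of reading the current default device inside a cached fn.
    major, minor = capability
    return _load_cubin(f"sm_{major}{minor}a")
```

With the arch frozen inside the cached function instead, a heterogeneous node would silently reuse the first GPU's cubin everywhere.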