
[feat] Add blackwell GDN prefill kernel #3001

Merged
kahyunnam merged 12 commits into flashinfer-ai:main from jiahanc:BlackwellGDN
Apr 13, 2026

Conversation

@jiahanc
Collaborator

@jiahanc jiahanc commented Apr 7, 2026

📌 Description

  • Add Blackwell GDN prefill kernel written in CuTe DSL
GPU: NVIDIA B200 [Blackwell (SM100)]
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128
FLA 0.4.2, Triton 3.5.1, PyTorch 2.11.0+cu130

Heads            Seqlens           h_qk  h_v    FI SM100 (ms)  FI TFLOPS  FLA/Triton (ms)  Speedup
---------------------------------------------------------------------------------------------------
397B/122B TP8    1x8192               2    8                  0.330ms    12.8      0.338ms     1.02x
397B/122B TP8    1x4096               2    8                  0.180ms    12.0      0.215ms     1.19x
397B/122B TP8    1x2048               2    8                  0.101ms    10.7      0.213ms     2.11x
397B/122B TP8    6144+2048            2    8                  0.258ms    16.6      0.285ms     1.10x
397B/122B TP8    4096+4096            2    8                  0.179ms    23.9      0.239ms     1.34x
397B/122B TP8    2048+6144            2    8                  0.256ms    16.8      0.285ms     1.11x
397B/122B TP8    1024+7168            2    8                  0.296ms    14.5      0.308ms     1.04x
397B/122B TP8    2048x4               2    8                  0.101ms    42.3      0.208ms     2.06x
397B/122B TP8    1024x8               2    8                  0.063ms    68.2      0.214ms     3.40x

397B/122B TP4    1x8192               4   16                  0.368ms    23.3      0.424ms     1.15x
397B/122B TP4    1x4096               4   16                  0.178ms    24.2      0.235ms     1.32x
397B/122B TP4    1x2048               4   16                  0.100ms    21.4      0.220ms     2.20x
397B/122B TP4    6144+2048            4   16                  0.257ms    33.5      0.377ms     1.47x
397B/122B TP4    4096+4096            4   16                  0.179ms    47.9      0.329ms     1.84x
397B/122B TP4    2048+6144            4   16                  0.259ms    33.2      0.376ms     1.45x
397B/122B TP4    1024+7168            4   16                  0.298ms    28.9      0.401ms     1.35x
397B/122B TP4    2048x4               4   16                  0.104ms    83.0      0.334ms     3.21x
397B/122B TP4    1024x8               4   16                  0.068ms   126.3      0.342ms     5.03x

397B/122B TP2    1x8192               8   32                  0.336ms    51.1      0.602ms     1.79x
397B/122B TP2    1x4096               8   32                  0.180ms    47.7      0.334ms     1.86x
397B/122B TP2    1x2048               8   32                  0.102ms    42.3      0.228ms     2.24x
397B/122B TP2    6144+2048            8   32                  0.258ms    66.5      0.605ms     2.35x
397B/122B TP2    4096+4096            8   32                  0.182ms    94.4      0.605ms     3.32x
397B/122B TP2    2048+6144            8   32                  0.260ms    66.0      0.606ms     2.33x
397B/122B TP2    1024+7168            8   32                  0.299ms    57.4      0.606ms     2.03x
397B/122B TP2    2048x4               8   32                  0.107ms   160.8      0.613ms     5.73x
397B/122B TP2    1024x8               8   32                  0.124ms   138.7      0.605ms     4.88x

397B/122B TP1    1x8192              16   64                  0.339ms   101.4      1.021ms     3.01x
397B/122B TP1    1x4096              16   64                  0.182ms    94.4      0.539ms     2.96x
397B/122B TP1    1x2048              16   64                  0.103ms    83.2      0.302ms     2.93x
397B/122B TP1    6144+2048           16   64                  0.263ms   130.6      1.017ms     3.87x
397B/122B TP1    4096+4096           16   64                  0.187ms   184.0      1.021ms     5.46x
397B/122B TP1    2048+6144           16   64                  0.265ms   129.5      1.022ms     3.86x
397B/122B TP1    1024+7168           16   64                  0.304ms   113.0      1.024ms     3.37x
397B/122B TP1    2048x4              16   64                  0.203ms   169.4      1.033ms     5.09x
397B/122B TP1    1024x8              16   64                  0.235ms   146.0      1.031ms     4.39x

35B/9B/4B TP1    1x8192              16   32                  0.339ms    50.7      0.602ms     1.78x
35B/9B/4B TP1    1x4096              16   32                  0.181ms    47.5      0.333ms     1.84x
35B/9B/4B TP1    1x2048              16   32                  0.102ms    42.2      0.220ms     2.16x
35B/9B/4B TP1    6144+2048           16   32                  0.259ms    66.3      0.605ms     2.34x
35B/9B/4B TP1    4096+4096           16   32                  0.181ms    94.7      0.604ms     3.34x
35B/9B/4B TP1    2048+6144           16   32                  0.261ms    65.8      0.606ms     2.32x
35B/9B/4B TP1    1024+7168           16   32                  0.300ms    57.3      0.607ms     2.02x
35B/9B/4B TP1    2048x4              16   32                  0.106ms   162.6      0.613ms     5.78x
35B/9B/4B TP1    1024x8              16   32                  0.123ms   139.6      0.606ms     4.93x

27B TP1          1x8192              16   48                  0.338ms    76.3      0.847ms     2.51x
27B TP1          1x4096              16   48                  0.180ms    71.7      0.461ms     2.56x
27B TP1          1x2048              16   48                  0.102ms    63.2      0.254ms     2.49x
27B TP1          6144+2048           16   48                  0.261ms    98.9      0.789ms     3.02x
27B TP1          4096+4096           16   48                  0.184ms   140.4      0.850ms     4.62x
27B TP1          2048+6144           16   48                  0.262ms    98.5      0.853ms     3.26x
27B TP1          1024+7168           16   48                  0.300ms    85.8      0.854ms     2.85x
27B TP1          2048x4              16   48                  0.200ms   129.1      0.801ms     4.01x
27B TP1          1024x8              16   48                  0.180ms   143.6      0.812ms     4.51x

2B/0.8B TP1      1x8192              16   16                  0.334ms    25.7      0.424ms     1.27x
2B/0.8B TP1      1x4096              16   16                  0.178ms    24.1      0.235ms     1.32x
2B/0.8B TP1      1x2048              16   16                  0.100ms    21.4      0.222ms     2.22x
2B/0.8B TP1      6144+2048           16   16                  0.255ms    33.7      0.378ms     1.48x
2B/0.8B TP1      4096+4096           16   16                  0.179ms    48.0      0.330ms     1.84x
2B/0.8B TP1      2048+6144           16   16                  0.255ms    33.7      0.377ms     1.48x
2B/0.8B TP1      1024+7168           16   16                  0.294ms    29.2      0.401ms     1.36x
2B/0.8B TP1      2048x4              16   16                  0.102ms    84.3      0.335ms     3.28x
2B/0.8B TP1      1024x8              16   16                  0.066ms   129.6      0.342ms     5.18x

Sym h32          1x8192              32   32                  0.335ms    51.3      0.602ms     1.80x
Sym h32          1x4096              32   32                  0.179ms    48.0      0.334ms     1.87x
Sym h32          1x2048              32   32                  0.101ms    42.4      0.221ms     2.19x
Sym h32          6144+2048           32   32                  0.258ms    66.6      0.604ms     2.34x
Sym h32          4096+4096           32   32                  0.181ms    94.7      0.605ms     3.34x
Sym h32          2048+6144           32   32                  0.258ms    66.6      0.605ms     2.35x
Sym h32          1024+7168           32   32                  0.296ms    58.0      0.606ms     2.05x
Sym h32          2048x4              32   32                  0.106ms   162.3      0.613ms     5.78x
Sym h32          1024x8              32   32                  0.123ms   139.5      0.605ms     4.92x

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added Blackwell (SM100/SM100A) GPU path for chunked gated delta-net prefill (requires head_size=128, CUDA 13+).
  • Chores

    • Optional kernel imports made more robust; package exposes SM100 probe flags and optional support.
    • Project and installer updated to declare and install SM100-capable CUTLASS extras when appropriate.
  • Tests

    • Test skips updated for SM100; relaxed numeric tolerances.
  • Benchmarks

    • New SM100-focused benchmark scripts and updated benchmark headers/output.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 7, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds Blackwell (SM100/SM100A) chunked GDN prefill support: new Blackwell package and scheduler, an SM100-specific compile-once kernel adapter with workspace management, top-level export handling updates, runtime branching in gdn_prefill between SM100 and SM90 paths, test/benchmark updates, and optional CUDA extras in project config.

Changes

Cohort / File(s) Summary
Blackwell package & kernels
flashinfer/gdn_kernels/blackwell/__init__.py, flashinfer/gdn_kernels/blackwell/gdn_prefill.py, flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
New Blackwell package exposing chunk_gated_delta_rule_sm100 and _has_blackwell_prefill; adds SM100 adapter with compile-once kernel cache, device workspace management, tensor conversion, and GDNTileScheduler (persistent/non-persistent).
Top-level kernel exports
flashinfer/gdn_kernels/__init__.py
Optional CuTe imports now treat ImportError and RuntimeError as failures; added _has_blackwell_prefill and chunk_gated_delta_rule_sm100 to __all__ (set to None/False when unavailable).
GDN prefill runtime
flashinfer/gdn_prefill.py
Runtime branching: select SM100 path when Blackwell + SM100A support and CUDA ≥13, else fallback to SM90. Default scale now computed as 1/sqrt(head_size) when not provided; SM100 path enforces head_size==128, materializes g/beta if needed, and uses int32 cu_seqlens.
Benchmarks & micro-bench refactor
benchmarks/bench_gdn_prefill.py, benchmarks/bench_blackwell_gdn_prefill.py
Refactor benchmark to use local run() wrapper, adjust warmup sync placement, pass kernel args by keyword, and add a new Blackwell benchmark file with CLI/sweep, device-name printing, and SM100+ enforcement.
Tests
tests/gdn/test_prefill_delta_rule.py
Generalized test skip helper to allow SM90 or SM100A (with CUDA version check) and relaxed non-bfloat16 numeric tolerance (atol increase).
Installer & project metadata
docker/install/install_python_packages.sh, pyproject.toml
Added optional dependency groups (cu12, cu13) for nvidia-cutlass-dsl; installer upgrades nvidia-cutlass-dsl[cu13]>=4.4.2 when CUDA_VERSION matches *cu13*.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Prefill as GDN Prefill
    participant DeviceCheck as Device/Version Check
    participant KernelCache as Compiled Kernel Cache
    participant CuTe as CuTe/CUTLASS (compile/exec)
    participant JIT as SM90 JIT Kernel

    User->>Prefill: call prefill(q,k,v,...)
    Prefill->>DeviceCheck: query _has_blackwell_prefill, is_sm100a_supported, CUDA version
    alt SM100 path (Blackwell + CUDA ≥13)
        DeviceCheck-->>Prefill: choose SM100 route
        Prefill->>KernelCache: lookup compiled kernel (dtype,HQ,HV,is_GQA,...)
        alt cache hit
            KernelCache-->>Prefill: return compiled callable + workspace info
        else cache miss
            Prefill->>CuTe: convert tensors, compile kernel, allocate workspace
            CuTe-->>KernelCache: store compiled callable & device info
            KernelCache-->>Prefill: return compiled callable
        end
        Prefill->>CuTe: execute compiled kernel with workspace & CUDA stream
        CuTe-->>Prefill: outputs + optional final state
    else SM90 fallback
        DeviceCheck-->>Prefill: choose SM90 route
        Prefill->>JIT: allocate workspace buffer, call SM90 JIT kernel
        JIT-->>Prefill: outputs
    end
    Prefill-->>User: return results
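For orientation, here is a minimal Python sketch of the dispatch rule described above. The flag and predicate names (`_has_blackwell_prefill`, `is_sm100a_supported`) come from the PR summary, but the function itself is illustrative rather than flashinfer's actual code.

```python
import math
import torch

def pick_gdn_prefill_backend(has_blackwell_prefill: bool,
                             sm100a_supported: bool,
                             head_size: int) -> str:
    # Mirror of the branching described above: the SM100 (Blackwell CuTe-DSL)
    # path needs the optional kernel import to have succeeded, an SM100A-capable
    # device, a CUDA 13+ toolkit, and head_size == 128 (which the SM100 path
    # enforces); otherwise fall back to the SM90 JIT kernel.
    cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
    if has_blackwell_prefill and sm100a_supported and cuda_major >= 13 and head_size == 128:
        return "sm100"
    return "sm90"

def default_scale(head_size: int) -> float:
    # Default applied when `scale` is not provided.
    return 1.0 / math.sqrt(head_size)
```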

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • yzh119
  • cyx-6
  • aleozlx
  • sricketts
  • samuellees
  • bkryu
  • jimmyzho
  • kahyunnam
  • nv-yunzheq
  • saltyminty

Poem

🐇 I hopped through kernels, bright and spry,
Cached the compile, watched tensors fly,
Blackwell tiles aligned just right,
Prefill hums through day and night—
Hop, compile, execute—bye-bye!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 30.30%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check: ❓ Inconclusive. The PR description includes performance benchmarks and a feature overview but lacks detailed explanation of implementation details, architecture choices, and specific changes across multiple files. Resolution: expand the description to explain key implementation details (e.g., tile scheduling, persistent kernels, CuTe integration), the rationale for SM100+ support with the CUDA 13 requirement, and summarize major changes in each modified/added file.

✅ Passed checks (1 passed)

  • Title check: ✅ Passed. The title '[feat] Add blackwell GDN prefill kernel' is directly related to the main change, which is adding a Blackwell (SM100) GDN prefill kernel implementation.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for Gated Delta Net (GDN) chunked prefill kernels on Blackwell (SM100) GPUs. Key changes include the addition of a new tile scheduler and a Blackwell-specific adapter for the GDN kernel, along with updates to benchmarks and tests to incorporate the new SM100 path. Feedback highlights documentation inconsistencies regarding state tensor layouts in the Blackwell adapter, a redundant calculation in the tile scheduler, and an unused variable in the prefill logic.

Comment thread flashinfer/gdn_kernels/blackwell/gdn_prefill.py Outdated
Comment thread flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
Comment thread flashinfer/gdn_prefill.py
Comment thread benchmarks/bench_gdn_prefill.py Outdated
Comment thread tests/gdn/test_prefill_delta_rule.py Outdated
@nvpohanh
Contributor

nvpohanh commented Apr 7, 2026

cc @kaixih @hlu1 @YAMY1234

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
tests/gdn/test_prefill_delta_rule.py (1)

32-42: Reuse the shared arch predicates in this skip helper.

Please lean on is_sm90a_supported() / is_sm100a_supported() for the architecture half of this check and keep the CUDA-major gate layered on top if SM100 still needs it. That keeps the tests aligned with the runtime support policy in one place.

As per coding guidelines, tests/**/*.py: Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gdn/test_prefill_delta_rule.py` around lines 32 - 42, Replace the
manual compute-capability checks in _skip_if_unsupported() with the shared
predicates: call is_sm90a_supported() and is_sm100a_supported() to decide
support, and only if is_sm100a_supported() is True still enforce the CUDA-major
gate by parsing torch.version.cuda (as currently done) to require CUDA 13+;
remove direct get_compute_capability() checks for SM90/SM100 and use those
utility functions so the test skip logic aligns with the runtime support policy.
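A hedged sketch of what the suggested helper could look like, using the `flashinfer.utils` predicates named in the guideline (illustrative; the real helper's exact signature may differ):

```python
import pytest
import torch
from flashinfer.utils import is_sm90a_supported, is_sm100a_supported

def _skip_if_unsupported(device: torch.device) -> None:
    # Skip unless the GPU is SM90A, or SM100A paired with a CUDA 13+ toolkit.
    cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
    if is_sm90a_supported(device):
        return
    if is_sm100a_supported(device) and cuda_major >= 13:
        return
    pytest.skip("GDN prefill requires SM90A, or SM100A with CUDA 13+")
```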
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Around line 49-57: The cached mutable CUDA workspace is currently shared
across devices in _get_compiled_cache (and the other cached spots around lines
121-123 and 224-228), causing cross-device reuse; update the cache to be
device-safe by scoping the cache entry to the current CUDA device: when building
the cache key for _get_compiled_cache (and the other cache-holding functions),
include the current device id (torch.cuda.current_device() or torch.device) or
change the cached value to a dict keyed by device id so each GPU gets its own
workspace tensor; ensure the workspace tensor is created on the correct device
before storing and returned only for that device.
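One possible shape for such a device-scoped cache, sketched under the assumptions above (the adapter's actual cache key and workspace layout may differ):

```python
import torch
from typing import Dict, Tuple

# Hypothetical device-scoped workspace cache: key by (device index, size) so
# each GPU gets its own buffer instead of sharing one cached tensor.
_workspace_cache: Dict[Tuple[int, int], torch.Tensor] = {}

def get_workspace(device: torch.device, nbytes: int) -> torch.Tensor:
    device = torch.device(device)
    idx = device.index if device.index is not None else torch.cuda.current_device()
    key = (idx, nbytes)
    if key not in _workspace_cache:
        _workspace_cache[key] = torch.empty(nbytes, dtype=torch.uint8, device=device)
    return _workspace_cache[key]
```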

In `@flashinfer/gdn_prefill.py`:
- Around line 201-233: The code currently allocates the full float32 scratch
state (output_state) before choosing the backend, causing an unnecessary large
allocation when output_final_state is False; change the logic so output_state is
only allocated when output_final_state is True and the selected backend requires
it (i.e., before calling chunk_gated_delta_rule_sm100), or pass an
already-conditional None otherwise. Concretely, move or guard the allocation of
output_state behind the backend selection branch that calls
chunk_gated_delta_rule_sm100 and only create the [num_seqs, H, 128, 128] tensor
when output_final_state is True; ensure the call to chunk_gated_delta_rule_sm100
still receives output_state when needed and None when not.
- Around line 198-201: The code currently treats scale==0.0 inconsistently
between SM90 and SM100 paths; fix this by normalizing the incoming scale before
any backend dispatch: in the prefill function compute a concrete _scale value
(e.g. if scale is None use 1.0/math.sqrt(head_size), and if scale == 0.0 also
set _scale = 1.0/math.sqrt(head_size) or alternatively raise ValueError) so that
the same _scale is used for both the SM100 branch (is_sm100a_supported(device) /
_has_blackwell_prefill) and the other path; update the logic around _scale,
scale, head_size, is_sm100a_supported and the dispatch blocks so they read this
single resolved _scale.

---

Nitpick comments:
In `@tests/gdn/test_prefill_delta_rule.py`:
- Around line 32-42: Replace the manual compute-capability checks in
_skip_if_unsupported() with the shared predicates: call is_sm90a_supported() and
is_sm100a_supported() to decide support, and only if is_sm100a_supported() is
True still enforce the CUDA-major gate by parsing torch.version.cuda (as
currently done) to require CUDA 13+; remove direct get_compute_capability()
checks for SM90/SM100 and use those utility functions so the test skip logic
aligns with the runtime support policy.


📥 Commits

Reviewing files that changed from the base of the PR and between e7f630c and b5e5188.

📒 Files selected for processing (9)
  • benchmarks/bench_gdn_prefill.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/blackwell/__init__.py
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py
  • flashinfer/gdn_prefill.py
  • pyproject.toml
  • tests/gdn/test_prefill_delta_rule.py

Comment thread flashinfer/gdn_kernels/blackwell/gdn_prefill.py
Comment thread flashinfer/gdn_prefill.py
Comment on lines +198 to +201
_scale = scale if scale is not None else 1.0 / math.sqrt(head_size)

_cuda_major = int(torch.version.cuda.split(".")[0]) if torch.version.cuda else 0
if _has_blackwell_prefill and is_sm100a_supported(device) and _cuda_major >= 13:
Contributor


⚠️ Potential issue | 🟠 Major

Make scale=0.0 backend-independent.

The SM90 launcher still treats 0.0 as the “use default 1 / sqrt(d)” sentinel, but the SM100 path forwards 0.0 literally. The same API call can therefore produce different numerics on Hopper vs. Blackwell. Please reject scale == 0.0 at the Python boundary or resolve it to one concrete value before both dispatches.

Also applies to: 222-233, 241-252

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 198 - 201, The code currently treats
scale==0.0 inconsistently between SM90 and SM100 paths; fix this by normalizing
the incoming scale before any backend dispatch: in the prefill function compute
a concrete _scale value (e.g. if scale is None use 1.0/math.sqrt(head_size), and
if scale == 0.0 also set _scale = 1.0/math.sqrt(head_size) or alternatively
raise ValueError) so that the same _scale is used for both the SM100 branch
(is_sm100a_supported(device) / _has_blackwell_prefill) and the other path;
update the logic around _scale, scale, head_size, is_sm100a_supported and the
dispatch blocks so they read this single resolved _scale.
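A minimal sketch of the requested normalization, resolving both `None` and the legacy `0.0` sentinel to the same default before any backend dispatch (illustrative; the exact integration point is up to the author):

```python
import math
from typing import Optional

def resolve_scale(scale: Optional[float], head_size: int) -> float:
    # Treat None and the legacy 0.0 sentinel identically so the SM90 and SM100
    # paths receive the same concrete value.
    if scale is None or scale == 0.0:
        return 1.0 / math.sqrt(head_size)
    return scale
```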

Comment thread flashinfer/gdn_prefill.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
flashinfer/gdn_prefill.py (2)

182-195: ⚠️ Potential issue | 🟠 Major

Avoid the eager output_state allocation on the SM100 no-final-state path.

When output_final_state=False, Lines 189-195 still allocate the full float32 state buffer up front, but Line 231 drops it for the SM100 launch. That buffer is unused on the Blackwell path and can be very large for bigger num_seqs/head counts.

Also applies to: 231-233

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 182 - 195, The code eagerly allocates
output_state even when output_final_state is False (in gdn_prefill.py where
output_state is set), which wastes memory for SM100/Blackwell because that path
drops the buffer later; change the logic so output_state is only allocated when
output_final_state is True or when the backend/device requires a CPU/GPU buffer
(e.g., detect device/backend used for the non-SM100 launch) — move or guard the
allocation away from the current elif block and defer it until after the
device/launch-path decision (see symbols output_final_state, output_state and
the SM100/Blackwell path around the later kernel launch where the buffer is
discarded at lines ~231-233) so no float32 buffer is created unnecessarily for
the SM100 no-final-state path.
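A small sketch of the deferred allocation (hypothetical helper; the shape follows the `[num_seqs, H, 128, 128]` float32 buffer mentioned above):

```python
import torch
from typing import Optional

def maybe_alloc_output_state(output_final_state: bool, num_seqs: int, num_heads: int,
                             device: torch.device) -> Optional[torch.Tensor]:
    # Allocate the float32 scratch state only when the caller asked for the
    # final state; the SM100 path drops the buffer otherwise.
    if not output_final_state:
        return None
    return torch.zeros(num_seqs, num_heads, 128, 128, dtype=torch.float32, device=device)
```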

197-201: ⚠️ Potential issue | 🟠 Major

Normalize scale == 0.0 before backend selection.

Lines 198-201 resolve None only. A caller that passes 0.0 still gets backend-dependent behavior: the SM100 path forwards literal zero at Line 232, while the SM90 path keeps using 0.0 as the auto-scale sentinel at Line 251. The same API call can therefore change numerics depending on the backend.

🧮 Proposed fix
-    _scale = scale if scale is not None else 1.0 / math.sqrt(head_size)
+    default_scale = 1.0 / math.sqrt(head_size)
+    _scale = default_scale if scale is None or scale == 0.0 else scale
@@
-            scale if scale is not None else 0.0,
+            _scale,

Also applies to: 222-233, 241-252

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_prefill.py` around lines 197 - 201, Normalize the sentinel
value for scale before choosing the backend: treat scale==0.0 the same as scale
is None and compute a single resolved value (e.g., resolved_scale = scale if
scale not None and scale != 0.0 else 1.0/math.sqrt(head_size)) once before the
SM100/SM90 conditional so both branches use the same numeric _scale; update
references to _scale (and any downstream use at the SM100 path that currently
forwards literal zero) to use resolved_scale to avoid backend-dependent
behavior. Ensure you change the same pattern around the other occurrences (the
blocks around the current 222-233 and 241-252 regions) so all branches read the
same normalized value.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell/__init__.py`:
- Around line 8-14: The optional-import guard in
flashinfer.gdn_kernels.blackwell currently only catches ImportError; update the
except clause in __init__.py to also catch RuntimeError (e.g., except
(ImportError, RuntimeError)) so that failures from the SM100 adapter are treated
as "backend unavailable", and ensure _has_blackwell_prefill is set to False and
chunk_gated_delta_rule_sm100 remains None (with the existing type ignore) in
that branch.
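Sketch of the widened guard (module and symbol names taken from the PR summary; treat the exact layout as an assumption):

```python
# flashinfer/gdn_kernels/blackwell/__init__.py (sketch)
try:
    from .gdn_prefill import chunk_gated_delta_rule_sm100
    _has_blackwell_prefill = True
except (ImportError, RuntimeError):  # RuntimeError also means "backend unavailable"
    chunk_gated_delta_rule_sm100 = None  # type: ignore[assignment]
    _has_blackwell_prefill = False
```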

In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Line 195: The code creates a CUDA stream with
cuda.CUstream(torch.cuda.current_stream().cuda_stream) which uses the current
device rather than the tensor's device; update the calls that construct the
stream (the usage of torch.cuda.current_stream()) to pass the tensor's device
explicitly (use torch.cuda.current_stream(device=q.device)) so cuda.CUstream is
created from the correct device's stream; locate occurrences around the stream
variable creation and replace the plain torch.cuda.current_stream() calls with
torch.cuda.current_stream(device=q.device) to ensure kernels run on q.device's
stream.
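In sketch form (the `cuda` bindings import is an assumption inferred from the call shown in the comment):

```python
import torch
from cuda import cuda  # cuda-python driver bindings (assumed import path)

def stream_for(q: torch.Tensor):
    # Build the driver-level handle from the tensor's own device, not whatever
    # device happens to be current.
    return cuda.CUstream(torch.cuda.current_stream(device=q.device).cuda_stream)
```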

In `@flashinfer/gdn_prefill.py`:
- Around line 159-162: Add the `@backend_requirement` decorator to the API whose
docstring mentions SM90/SM100 (the function or class that contains these
docstring lines) and import backend_requirement; implement the required
introspection methods is_compute_capability_supported(cc) and
is_backend_supported() on that API so they mirror the runtime gating (check
SM90/SM100 compute-capability logic, the SM100 head_size==128 constraint, and
any backend dependency check used at runtime), and ensure these methods return
booleans so callers can query support before dispatch.

---

Duplicate comments:
In `@flashinfer/gdn_prefill.py`:
- Around line 182-195: The code eagerly allocates output_state even when
output_final_state is False (in gdn_prefill.py where output_state is set), which
wastes memory for SM100/Blackwell because that path drops the buffer later;
change the logic so output_state is only allocated when output_final_state is
True or when the backend/device requires a CPU/GPU buffer (e.g., detect
device/backend used for the non-SM100 launch) — move or guard the allocation
away from the current elif block and defer it until after the device/launch-path
decision (see symbols output_final_state, output_state and the SM100/Blackwell
path around the later kernel launch where the buffer is discarded at lines
~231-233) so no float32 buffer is created unnecessarily for the SM100
no-final-state path.
- Around line 197-201: Normalize the sentinel value for scale before choosing
the backend: treat scale==0.0 the same as scale is None and compute a single
resolved value (e.g., resolved_scale = scale if scale not None and scale != 0.0
else 1.0/math.sqrt(head_size)) once before the SM100/SM90 conditional so both
branches use the same numeric _scale; update references to _scale (and any
downstream use at the SM100 path that currently forwards literal zero) to use
resolved_scale to avoid backend-dependent behavior. Ensure you change the same
pattern around the other occurrences (the blocks around the current 222-233 and
241-252 regions) so all branches read the same normalized value.


📥 Commits

Reviewing files that changed from the base of the PR and between 87787a8 and 9f5c7c8.

📒 Files selected for processing (9)
  • benchmarks/bench_gdn_prefill.py
  • flashinfer/gdn_kernels/__init__.py
  • flashinfer/gdn_kernels/blackwell/__init__.py
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
  • flashinfer/gdn_kernels/blackwell/gated_delta_net_tile_scheduler.py
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py
  • flashinfer/gdn_prefill.py
  • pyproject.toml
  • tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (2)
  • pyproject.toml
  • benchmarks/bench_gdn_prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gdn/test_prefill_delta_rule.py

Comment thread flashinfer/gdn_kernels/blackwell/__init__.py
Comment thread flashinfer/gdn_kernels/blackwell/gdn_prefill.py Outdated
Comment thread flashinfer/gdn_prefill.py
@jiahanc
Collaborator Author

jiahanc commented Apr 8, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !523 has been created, and the CI pipeline #47980418 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
flashinfer/gdn_kernels/blackwell/__init__.py (1)

16-19: Consider sorting __all__ alphabetically.

Static analysis flags that __all__ is not sorted. While functionally correct, sorting improves consistency.

♻️ Proposed fix
 __all__ = [
-    "chunk_gated_delta_rule_sm100",
     "_has_blackwell_prefill",
+    "chunk_gated_delta_rule_sm100",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/__init__.py` around lines 16 - 19, The
__all__ list is not alphabetically sorted; update the list containing
"chunk_gated_delta_rule_sm100" and "_has_blackwell_prefill" so its entries are
in alphabetical order (i.e., place "_has_blackwell_prefill" before
"chunk_gated_delta_rule_sm100") to satisfy static-analysis sorting rules.
flashinfer/gdn_kernels/blackwell/gdn_prefill.py (1)

106-108: Add a clarifying comment for the is_GQA condition.

The condition is_GQA = HQ >= HV treats equal head counts (standard MHA) the same as cases where query heads exceed value heads. While the kernel logic handles this correctly (the output head count and repetition factor calculations both produce correct results when HQ == HV), the naming conflates GQA (grouped query attention, where HQ > HV) with MHA (where HQ == HV). Consider adding a brief comment explaining that this condition captures all cases where HQ >= HV for the kernel's internal dispatch logic, even though standard MHA (HQ == HV) is semantically distinct from true GQA.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py` around lines 106 - 108, Add
a short clarifying comment above the is_GQA assignment explaining that the
boolean is used for the kernel's internal dispatch logic and intentionally
groups HQ == HV (standard MHA) with HQ > HV (true GQA) so the kernel treats both
as the same path; reference the symbols HQ, HV, and is_GQA and mention that this
is a pragmatic choice for output head count and repetition factor calculations,
not a semantic conflation of MHA and GQA.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Around line 127-129: Replace the incorrect SM count retrieval: instead of
calling HardwareInfo.get_max_active_clusters(1) for num_sm, import and call
get_num_sm(q.device) from flashinfer.cute_dsl.utils and use that value for
num_sm passed into GatedDeltaNetChunkedKernel and workspace size calculations;
remove or leave HardwareInfo usage only for max_active_clusters as needed and
ensure get_max_active_clusters(1) is not used to compute num_sm.

---

Nitpick comments:
In `@flashinfer/gdn_kernels/blackwell/__init__.py`:
- Around line 16-19: The __all__ list is not alphabetically sorted; update the
list containing "chunk_gated_delta_rule_sm100" and "_has_blackwell_prefill" so
its entries are in alphabetical order (i.e., place "_has_blackwell_prefill"
before "chunk_gated_delta_rule_sm100") to satisfy static-analysis sorting rules.

In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py`:
- Around line 106-108: Add a short clarifying comment above the is_GQA
assignment explaining that the boolean is used for the kernel's internal
dispatch logic and intentionally groups HQ == HV (standard MHA) with HQ > HV
(true GQA) so the kernel treats both as the same path; reference the symbols HQ,
HV, and is_GQA and mention that this is a pragmatic choice for output head count
and repetition factor calculations, not a semantic conflation of MHA and GQA.


📥 Commits

Reviewing files that changed from the base of the PR and between 9f5c7c8 and 8b88000.

📒 Files selected for processing (4)
  • docker/install/install_python_packages.sh
  • flashinfer/gdn_kernels/blackwell/__init__.py
  • flashinfer/gdn_kernels/blackwell/gdn_prefill.py
  • tests/gdn/test_prefill_delta_rule.py
✅ Files skipped from review due to trivial changes (1)
  • docker/install/install_python_packages.sh

Comment on lines +127 to +129
hardware_info = cutlass_utils.HardwareInfo()
num_sm = hardware_info.get_max_active_clusters(1)
max_active_clusters = hardware_info.get_max_active_clusters(1)
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how num_sm is typically obtained in similar kernels
rg -n "num_sm\s*=" --type=py flashinfer/gdn_kernels/ -A2 -B2

Repository: flashinfer-ai/flashinfer

Length of output: 1964


🏁 Script executed:

#!/bin/bash
# Check HardwareInfo API for SM count method
ast-grep --pattern $'class HardwareInfo {
  $$$
}'

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Locate the HardwareInfo class definition
fd -type f -name "*.py" | xargs rg -l "class HardwareInfo" | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

# Check cutlass_utils module location
fd -type f -name "cutlass_utils.py"

Repository: flashinfer-ai/flashinfer

Length of output: 238


🏁 Script executed:

# Also search for HardwareInfo in pyx/pxd files (Cython)
fd -type f \( -name "*.pyx" -o -name "*.pxd" \) | xargs rg -l "HardwareInfo" 2>/dev/null | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 559


🏁 Script executed:

# Find files containing HardwareInfo class definition
rg "class HardwareInfo" --type py

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check imports in gdn_prefill.py to locate cutlass_utils
head -30 flashinfer/gdn_kernels/blackwell/gdn_prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1080


🏁 Script executed:

# Continue reading gdn_prefill.py imports
head -60 flashinfer/gdn_kernels/blackwell/gdn_prefill.py | tail -30

Repository: flashinfer-ai/flashinfer

Length of output: 891


🏁 Script executed:

# Search for get_max_active_clusters method
rg "get_max_active_clusters" -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 8677


🏁 Script executed:

# Search for get_num_sm function definition and usage
rg "def get_num_sm|get_num_sm" -B 2 -A 3

Repository: flashinfer-ai/flashinfer

Length of output: 9982


🏁 Script executed:

# Check mla_decode imports more carefully
rg "from.*import|import" flashinfer/mla/cute_dsl/mla_decode.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 446


🏁 Script executed:

# Check what device is available in gdn_prefill.py context at lines 127-129
sed -n '110,150p' flashinfer/gdn_kernels/blackwell/gdn_prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 1615


🏁 Script executed:

# Verify GatedDeltaNetChunkedKernel use of num_sm
sed -n '215,235p' flashinfer/gdn_kernels/blackwell/gdn_prefill.py

Repository: flashinfer-ai/flashinfer

Length of output: 681


num_sm should be computed using device SM count, not get_max_active_clusters(1).

Line 128 incorrectly assigns num_sm = hardware_info.get_max_active_clusters(1), which returns the max active clusters for a cluster size of 1—not the actual SM count. This value is then passed to GatedDeltaNetChunkedKernel (line 139) and used to calculate workspace size (line 221). The correct approach, used throughout the codebase (e.g., mxfp4_quantize.py, nvfp4_quantize.py, mla_decode.py), is to call get_num_sm(q.device), which returns torch.cuda.get_device_properties(q.device).multi_processor_count.

Import get_num_sm from flashinfer.cute_dsl.utils and replace line 128 with num_sm = get_num_sm(q.device).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gdn_kernels/blackwell/gdn_prefill.py` around lines 127 - 129,
Replace the incorrect SM count retrieval: instead of calling
HardwareInfo.get_max_active_clusters(1) for num_sm, import and call
get_num_sm(q.device) from flashinfer.cute_dsl.utils and use that value for
num_sm passed into GatedDeltaNetChunkedKernel and workspace size calculations;
remove or leave HardwareInfo usage only for max_active_clusters as needed and
ensure get_max_active_clusters(1) is not used to compute num_sm.
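For reference, the suggested helper reduces to the device's physical multiprocessor count (per the comment above; a sketch, not flashinfer's exact implementation):

```python
import torch

def get_num_sm(device: torch.device) -> int:
    # Physical SM count of the target device.
    return torch.cuda.get_device_properties(device).multi_processor_count
```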

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
benchmarks/bench_blackwell_gdn_prefill.py (2)

341-344: Add an explicit CUDA-availability guard for clearer failure mode.

If CUDA is unavailable, the current flow may fail with a less actionable message before the SM100 check.

🔧 Proposed fix
 def main():
@@
-    device = torch.device("cuda")
+    if not torch.cuda.is_available():
+        print("Error: CUDA is not available.")
+        sys.exit(1)
+    device = torch.device("cuda")
     if not is_sm100a_supported(device):
         print("Error: This benchmark requires SM100+ (Blackwell) GPU.")
         sys.exit(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 341 - 344, Add an
explicit CUDA availability check before selecting a CUDA device: call
torch.cuda.is_available() and if it returns False print a clear error and
sys.exit(1) before creating device or invoking is_sm100a_supported; ensure the
guard is placed prior to the line that sets device = torch.device("cuda") so
that subsequent calls like is_sm100a_supported(device) only run when CUDA is
present.

294-295: Avoid blind Exception catches in sweep loops.

Line 294 and Line 315 swallow all exceptions, including interruption/system-level signals, and make failures harder to triage.

🔧 Proposed fix
-        except Exception as e:
+        except KeyboardInterrupt:
+            raise
+        except RuntimeError as e:
             print(f" FAILED: {e}")
@@
-        except Exception as e:
+        except KeyboardInterrupt:
+            raise
+        except RuntimeError as e:
             print(f" FAILED: {e}")

Also applies to: 315-316

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_blackwell_gdn_prefill.py` around lines 294 - 295, The broad
"except Exception as e" handlers in the sweep loop should not swallow
system-level interrupts; update the two except blocks that currently read
"except Exception as e" so they re-raise KeyboardInterrupt and SystemExit (e.g.,
if isinstance(e, (KeyboardInterrupt, SystemExit)): raise) and only handle other
exceptions by logging the error and traceback for triage, rather than silently
printing; apply this change to both occurrences so interruption signals
propagate and failures are logged with full context.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@benchmarks/bench_blackwell_gdn_prefill.py`:
- Around line 341-344: Add an explicit CUDA availability check before selecting
a CUDA device: call torch.cuda.is_available() and if it returns False print a
clear error and sys.exit(1) before creating device or invoking
is_sm100a_supported; ensure the guard is placed prior to the line that sets
device = torch.device("cuda") so that subsequent calls like
is_sm100a_supported(device) only run when CUDA is present.
- Around line 294-295: The broad "except Exception as e" handlers in the sweep
loop should not swallow system-level interrupts; update the two except blocks
that currently read "except Exception as e" so they re-raise KeyboardInterrupt
and SystemExit (e.g., if isinstance(e, (KeyboardInterrupt, SystemExit)): raise)
and only handle other exceptions by logging the error and traceback for triage,
rather than silently printing; apply this change to both occurrences so
interruption signals propagate and failures are logged with full context.


📥 Commits

Reviewing files that changed from the base of the PR and between 8b88000 and 2ec880b.

📒 Files selected for processing (1)
  • benchmarks/bench_blackwell_gdn_prefill.py

@jiahanc
Collaborator Author

jiahanc commented Apr 13, 2026

> tests look fine. pls address pre-commit check, which blocks merging

pre-commits failed at files not changed in this PR. Wonder if we should fix in separate PR

@jiahanc
Collaborator Author

jiahanc commented Apr 13, 2026

> tests look fine. pls address pre-commit check, which blocks merging
>
> pre-commits failed at files not changed in this PR. Wonder if we should fix in separate PR

@aleozlx, fixed in #3040

jiahanc added 12 commits April 13, 2026 07:29
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> (all 12 commits)
@kahyunnam kahyunnam merged commit 7c562d5 into flashinfer-ai:main Apr 13, 2026
32 checks passed
Comment on the relaxed output tolerance in tests/gdn/test_prefill_delta_rule.py:

    rtol_kv = 1e-3
else:
-   atol_o = 1e-3
+   atol_o = 2e-3

Hello, excellent work! I have a question that has been puzzling me: what causes the error to increase? I experimented with the #2742 branch with the TF32 inverse implementation, and for long input sequences the results aligned much more closely with the flash-linear-attention implementation. If we disregard potential FP16 numerical overflows during the inversion process, the computational precision of TF32 is theoretically equivalent to that of FP16. Could this discrepancy be attributed to the use of finer partitioning granularity (8x8 -> 16x16 -> 32x32 -> 64x64) during the inversion calculation?

Furthermore, QKV is obtained from qkv_factory, and the input range for most Q, K, and V tensors appears to fall within [-0.4, 0.4]. Is this numerical range sufficient? For K, after applying L2 normalization, the data range becomes [-1, 1]. In Qwen3.5, Q also undergoes L2 normalization; consequently, when these inputs are fed into the gated_delta_rule computation, their range effectively becomes [-1, 1].

Looking forward to your reply, thanks in advance.

Collaborator


@bestzsq The inversion unfortunately propagates error from the diagonal (8x8) to the lower-left corner due to repeated truncation from the fp32 accumulator to fp16 operands. Previously, the Blackwell version had a 128(Q)x128(K/V) block size config; this might be the root cause.

The Blackwell kernel copied my Hopper inversion implementation strategy. I have experimented with 3xFP16 and 3xTF32 inversion on Hopper; they are much more accurate, as you can store the truncation error in the second FP16 value. But due to the large kernel performance penalty (on Hopper), they were not upstreamed to FlashInfer. We may upstream it someday if proven to be needed.
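To make the error-propagation argument concrete, here is a small, self-contained PyTorch experiment (illustrative only, not the kernel's actual tiling): it inverts a unit-lower-triangular matrix by 8x8 blocked substitution, truncating each matmul operand to fp16 while accumulating in fp32, and compares against a float64 reference. The error is typically largest in the lower-left corner, matching the explanation above.

```python
import torch

torch.manual_seed(0)
n, blk = 128, 8
A = torch.tril(torch.randn(n, n) * 0.1, diagonal=-1)    # strictly lower triangular
T = torch.eye(n) + A                                     # unit lower triangular
ref = torch.linalg.inv(T.double()).float()               # high-precision reference

def inv_blocked_fp16_operands(T: torch.Tensor, blk: int) -> torch.Tensor:
    n = T.shape[0]
    X = torch.zeros_like(T)
    for i in range(0, n, blk):                           # invert diagonal blocks
        X[i:i+blk, i:i+blk] = torch.linalg.inv(T[i:i+blk, i:i+blk])
    for i in range(blk, n, blk):                         # blocked forward substitution
        for j in range(0, i, blk):
            acc = torch.zeros(blk, blk)
            for k in range(j, i, blk):
                # operands truncated to fp16, accumulation kept in fp32
                acc += (T[i:i+blk, k:k+blk].half().float()
                        @ X[k:k+blk, j:j+blk].half().float())
            X[i:i+blk, j:j+blk] = -X[i:i+blk, i:i+blk] @ acc
    return X

err = (inv_blocked_fp16_operands(T, blk) - ref).abs()
print("max abs error overall    :", err.max().item())
print("max abs error lower-left :", err[-blk:, :blk].max().item())
```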


@bestzsq bestzsq Apr 17, 2026


@guangyunh-nv Thanks for your reply. If the inversion calculation were to begin with 16x16 diagonal blocks, followed by 32x32 blocks, and finally 64x64 blocks, could the calculation error be reduced?

I experimented with modifying the matrix multiplication within the inversion process of the Triton chunk implementation in flash-linear-attention, changing the inputs to FP16 and the outputs to FP32. Although this involved some truncation from FP32 to FP16, the results demonstrated excellent consistency when compared to the unmodified Triton chunk implementation. However, in contrast to the current CuTe DSL implementation, the Triton chunk implementation uses sixteen 16x16 matrix multiplications for the non-diagonal blocks.

Collaborator


The main problem with diagonal block processing (aka substitution) is that it requires O(n^2) FMAs to compute, which implies O(n^2) instructions. So the smaller the better; 8x8 is a sweet spot, as we can start to use Ampere-style tensor cores immediately after that.

Collaborator


On Blackwell, due to its asynchronous nature, it is worth exploring this a little further. I tried 4x4 as the starting point on Hopper, but saw no further improvement. If Blackwell can tolerate 16x16 as its starting point with no obvious perf penalty, I think it should be made a configurable parameter.


@Observer007 Can the provided example be reproduced? I look forward to receiving any updates regarding this issue, thanks in advance!

Contributor


Yes, I can reproduce it now.

Contributor


BTW, do you think the precision of chunk size 128 is good enough or not?

Contributor


@bestzsq It is really a deeply hidden functional bug; we have a fix in #3156. After the fix, the MAE of chunk size 128 and chunk size 64 is the same using your reproducer. Does the MAE look good to you? Anyway, thanks again for the thorough inspection! And thanks for the explanation from @guangyunh-nv.


@Observer007 Sorry for the late reply and thanks for the fix! I have tested it on some inputs, and the discrepancies between cute_dsl (chunk size 64), flash-linear-attention's chunk_gated_delta_rule, and fused_recurrent_gated_delta_rule are all within the same order of magnitude. I think the precision of chunk size 128/64 is good enough.

kahyunnam pushed a commit that referenced this pull request Apr 24, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

Fixed the accuracy issue in blackwell gdn kernel found by @bestzsq. 

The root cause is that the legacy `max_coord` is not the actual last
coord of the `sCumprod`. We changed it to the last coord instead. It's a
deeply hidden bug that we hadn't discovered previously. Thanks to
@bestzsq.


Reproducer test link from @bestzsq:
#3001 (comment)

Reproducer test output before this pr:
```
# flash-linear-attention==0.4.2
fla vs cute64: mae: 2.82288e-03, ulp: 9040.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64: mae: 2.82288e-03, ulp: 9064.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
```

Reproducer test output after this pr:
```
# flash-linear-attention==0.4.2
fla vs cute64: mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
# flash-linear-attention==0.5.0
fla vs cute64: mae: 3.05176e-05, ulp: 74.0
fla vs cute128: mae: 3.05176e-05, ulp: 74.0
```

The previous local test tolerance was loosened from `1e-3` to `2e-3` in #3001:
https://github.com/flashinfer-ai/flashinfer/pull/3001/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6L132

This PR tightens the tolerance from `2e-3` to `1e-3`:
https://github.com/flashinfer-ai/flashinfer/pull/3156/changes#diff-d3d322c588f461c03200b8a16ce676dbcab99e11a6b225d230ea0d51c9e8dbf6R148

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

## Release Notes

* **Refactor**
* Improved kernel computation efficiency by consolidating internal
calculation steps and removing redundant intermediate operations,
reducing code complexity while preserving all existing functionality and
performance characteristics.

* **Tests**
* Strengthened numerical validation by reducing tolerance thresholds in
computational accuracy tests for greater precision, ensuring more
stringent verification of output correctness and numerical consistency
across test scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
kahyunnam pushed a commit that referenced this pull request Apr 25, 2026
…3155)

## 📌 Description

Fixes the `num_sm` issue CodeRabbit flagged on #3001 but which was not
applied before merge:
#3001 (comment)

The raw `HardwareInfo().get_max_active_clusters(1)` call returns 0 /
stale values in spawned subprocesses (e.g. vLLM's EngineCore workers)
where the CUDA driver API context has not been made current yet. The
persistent tile scheduler then leaves some CTAs without any work and the
kernel deadlocks at first call. Switch to `get_num_sm(q.device)`,
matching the SM120 MoE dispatch.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Kernel compilation now derives device-specific SM and cluster counts
at runtime, improving GPU resource allocation and leading to more
consistent performance across different CUDA devices.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
kahyunnam pushed a commit that referenced this pull request Apr 27, 2026
## 📌 Description

Addresses the two remaining CodeRabbit findings on
[#3001](#3001) that
weren't applied before merge:

* **Normalize `scale=0.0` to the default `1/sqrt(d_k)`** before backend
dispatch so the same call gives matching numerics on SM90 and SM100. The
SM90 C++ kernel treats `0.0` as a sentinel for "use default", but the
SM100 CuTe-DSL kernel forwarded the literal `0.0` → zeroed QK → broken
attention.

* **Don't eagerly allocate `output_state`** on the SM100 path when
`output_final_state=False`. The CuTe-DSL kernel drops the buffer anyway,
so the old code wasted a full `[num_seqs, H, 128, 128]` float32 scratch
per call. SM90 still allocates unconditionally because its C++ kernel
always writes into `output_state`.

Dispatcher callsites now pass `output_state` directly on both branches
(no inline `output_state if output_final_state else None`), so SM90 and
SM100 read identically.


## 🔍 Related Issues

* [[feat] Add blackwell GDN prefill
kernel](#3001)
* [fix(gdn): use physical SM count for SM100 persistent prefill
kernel#3155](#3155)
* [[fix] fix blackwell gdn accuracy
issue#3156](#3156)

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Bug Fixes**
* Fixed scale parameter handling to correctly interpret explicit values
and apply default scaling behavior.
* Improved memory efficiency by avoiding unnecessary state allocations
in certain configurations.

* **Improvements**
* Enhanced consistency in kernel invocation logic across different
hardware architectures.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->