[GDN] Enable FI Blackwell GDN prefill kernel#40717

Open
arpera wants to merge 3 commits intovllm-project:mainfrom
arpera:fl-backwell-gdn-prefill

Conversation

@arpera
Contributor

@arpera arpera commented Apr 23, 2026

IMPORTANT!!!

This PR MUST be merged only after flashinfer-ai/flashinfer#3155 is merged in FlashInfer and vLLM starts using that FlashInfer version: there is a bug in the GDN implementation in FlashInfer that the linked change fixes.

Purpose

Enable FlashInfer's new Blackwell (SM100) CuTe-DSL GDN prefill kernel (flashinfer-ai/flashinfer#3001) by default in vLLM.
For reference, the same change was also proposed in SGLang ("Add FlashInfer prefill support for SM100+"), just in case.

On Blackwell the dispatcher in ChunkGatedDeltaRule.__init__ now routes GDN prefill to FlashInfer when all of the following hold (logged once at init):

  • the requested backend is in ["flashinfer", "auto"];
  • the platform is cuda;
  • one of:
    • Hopper (SM90) — no further constraints (pre-existing path, unchanged);
    • Blackwell (SM10.x) with head_k_dim == 128, nvidia-cutlass-dsl-libs-cu13 installed, cuda_runtime >= 13.

Otherwise we stay on the Triton/FLA path.
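The dispatch conditions above can be sketched as a small predicate. This is an illustrative sketch only — the function name and argument names below are hypothetical, not vLLM's actual API; the real logic lives in ChunkGatedDeltaRule.__init__.

```python
def select_gdn_prefill_backend(
    requested: str,          # value of --gdn-prefill-backend ("auto", "flashinfer", "triton")
    platform: str,           # e.g. "cuda"
    sm_major: int,           # 9 for Hopper (SM90), 10 for Blackwell (SM10.x)
    head_k_dim: int,
    has_cutlass_dsl: bool,   # nvidia-cutlass-dsl-libs-cu13 importable
    cuda_runtime_major: int,
) -> str:
    """Return "flashinfer" when all dispatch constraints hold, else "triton"."""
    if requested not in ("flashinfer", "auto") or platform != "cuda":
        return "triton"
    if sm_major == 9:
        # Hopper: pre-existing FlashInfer path, no further constraints.
        return "flashinfer"
    if (sm_major == 10 and head_k_dim == 128
            and has_cutlass_dsl and cuda_runtime_major >= 13):
        # Blackwell SM100 CuTe-DSL prefill kernel.
        return "flashinfer"
    return "triton"
```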

Test Result

Hardware: 8xB200

Functional

e2e gsm8k:

# Server
vllm serve nvidia/Qwen3.5-397B-A17B-NVFP4 \
    --port 8000 -tp 1 -pp 1 -dp 8 \
    --enable-expert-parallel \
    --language-model-only \
    --reasoning-parser qwen3 \
    --stream-interval 100 \
    --gdn-prefill-backend {triton,flashinfer}
    
# Client
python3 tests/evals/gsm8k/gsm8k_eval.py
Backend               Accuracy   Invalid Rate   Tokens/sec
Triton                   0.873          0.027        1,290
FlashInfer Blackwell     0.874          0.024        1,611

Accuracy remains the same; no degradation.

Performance

GDN prefill kernel micro-benchmark:

gdn_prefill_bench.py

python3 gdn_prefill_bench.py

FlashInfer Blackwell SM100 vs FLA/Triton on B200 across Qwen3.5 configurations — speedup ranges from 1.01× (TP8, small heads, small seqlen) to 5.46× (TP1, full head count, balanced split).

Full table
GPU: NVIDIA B200 [Blackwell (SM100)]
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128

Heads            Seqlens           h_qk  h_v    FI Blackwell (SM100)  FLA/Triton   Speedup
------------------------------------------------------------------------------------------
397B/122B TP8    1x8192               2    8                  0.333ms      0.335ms     1.01x
397B/122B TP8    1x4096               2    8                  0.177ms      0.222ms     1.26x
397B/122B TP8    1x2048               2    8                  0.099ms      0.223ms     2.26x
397B/122B TP8    6144+2048            2    8                  0.255ms      0.282ms     1.11x
397B/122B TP8    4096+4096            2    8                  0.177ms      0.230ms     1.30x
397B/122B TP8    2048+6144            2    8                  0.255ms      0.283ms     1.11x
397B/122B TP8    1024+7168            2    8                  0.294ms      0.308ms     1.05x
397B/122B TP8    2048x4               2    8                  0.100ms      0.240ms     2.41x
397B/122B TP8    1024x8               2    8                  0.062ms      0.242ms     3.87x

397B/122B TP4    1x8192               4   16                  0.333ms      0.417ms     1.25x
397B/122B TP4    1x4096               4   16                  0.177ms      0.250ms     1.41x
397B/122B TP4    1x2048               4   16                  0.099ms      0.251ms     2.53x
397B/122B TP4    6144+2048            4   16                  0.255ms      0.363ms     1.43x
397B/122B TP4    4096+4096            4   16                  0.178ms      0.311ms     1.75x
397B/122B TP4    2048+6144            4   16                  0.256ms      0.365ms     1.43x
397B/122B TP4    1024+7168            4   16                  0.295ms      0.391ms     1.32x
397B/122B TP4    2048x4               4   16                  0.101ms      0.281ms     2.78x
397B/122B TP4    1024x8               4   16                  0.066ms      0.280ms     4.25x

397B/122B TP2    1x8192               8   32                  0.335ms      0.583ms     1.74x
397B/122B TP2    1x4096               8   32                  0.178ms      0.315ms     1.77x
397B/122B TP2    1x2048               8   32                  0.099ms      0.249ms     2.50x
397B/122B TP2    6144+2048            8   32                  0.256ms      0.552ms     2.15x
397B/122B TP2    4096+4096            8   32                  0.180ms      0.522ms     2.91x
397B/122B TP2    2048+6144            8   32                  0.259ms      0.594ms     2.30x
397B/122B TP2    1024+7168            8   32                  0.298ms      0.633ms     2.13x
397B/122B TP2    2048x4               8   32                  0.105ms      0.517ms     4.90x
397B/122B TP2    1024x8               8   32                  0.122ms      0.523ms     4.29x

397B/122B TP1    1x8192              16   64                  0.337ms      0.963ms     2.86x
397B/122B TP1    1x4096              16   64                  0.179ms      0.504ms     2.81x
397B/122B TP1    1x2048              16   64                  0.101ms      0.278ms     2.75x
397B/122B TP1    6144+2048           16   64                  0.261ms      0.965ms     3.70x
397B/122B TP1    4096+4096           16   64                  0.185ms      0.971ms     5.24x
397B/122B TP1    2048+6144           16   64                  0.263ms      0.975ms     3.70x
397B/122B TP1    1024+7168           16   64                  0.302ms      0.978ms     3.24x
397B/122B TP1    2048x4              16   64                  0.201ms      0.978ms     4.86x
397B/122B TP1    1024x8              16   64                  0.234ms      0.959ms     4.10x

35B/9B/4B TP1    1x8192              16   32                  0.335ms      0.586ms     1.75x
35B/9B/4B TP1    1x4096              16   32                  0.179ms      0.318ms     1.78x
35B/9B/4B TP1    1x2048              16   32                  0.100ms      0.251ms     2.51x
35B/9B/4B TP1    6144+2048           16   32                  0.257ms      0.555ms     2.16x
35B/9B/4B TP1    4096+4096           16   32                  0.181ms      0.525ms     2.90x
35B/9B/4B TP1    2048+6144           16   32                  0.259ms      0.595ms     2.30x
35B/9B/4B TP1    1024+7168           16   32                  0.300ms      0.637ms     2.13x
35B/9B/4B TP1    2048x4              16   32                  0.105ms      0.521ms     4.94x
35B/9B/4B TP1    1024x8              16   32                  0.122ms      0.525ms     4.30x

27B TP1          1x8192              16   48                  0.333ms      0.801ms     2.40x
27B TP1          1x4096              16   48                  0.177ms      0.426ms     2.41x
27B TP1          1x2048              16   48                  0.100ms      0.252ms     2.53x
27B TP1          6144+2048           16   48                  0.257ms      0.759ms     2.95x
27B TP1          4096+4096           16   48                  0.181ms      0.721ms     3.98x
27B TP1          2048+6144           16   48                  0.259ms      0.763ms     2.95x
27B TP1          1024+7168           16   48                  0.298ms      0.783ms     2.63x
27B TP1          2048x4              16   48                  0.198ms      0.718ms     3.62x
27B TP1          1024x8              16   48                  0.178ms      0.713ms     4.01x

2B/0.8B TP1      1x8192              16   16                  0.330ms      0.425ms     1.29x
2B/0.8B TP1      1x4096              16   16                  0.175ms      0.262ms     1.49x
2B/0.8B TP1      1x2048              16   16                  0.099ms      0.251ms     2.53x
2B/0.8B TP1      6144+2048           16   16                  0.253ms      0.371ms     1.47x
2B/0.8B TP1      4096+4096           16   16                  0.178ms      0.318ms     1.79x
2B/0.8B TP1      2048+6144           16   16                  0.253ms      0.372ms     1.47x
2B/0.8B TP1      1024+7168           16   16                  0.297ms      0.398ms     1.34x
2B/0.8B TP1      2048x4              16   16                  0.101ms      0.284ms     2.80x
2B/0.8B TP1      1024x8              16   16                  0.066ms      0.287ms     4.35x

Sym h32          1x8192              32   32                  0.335ms      0.594ms     1.77x
Sym h32          1x4096              32   32                  0.179ms      0.322ms     1.80x
Sym h32          1x2048              32   32                  0.100ms      0.249ms     2.48x
Sym h32          6144+2048           32   32                  0.257ms      0.559ms     2.18x
Sym h32          4096+4096           32   32                  0.180ms      0.523ms     2.91x
Sym h32          2048+6144           32   32                  0.258ms      0.591ms     2.29x
Sym h32          1024+7168           32   32                  0.296ms      0.629ms     2.13x
Sym h32          2048x4              32   32                  0.106ms      0.524ms     4.94x
Sym h32          1024x8              32   32                  0.122ms      0.531ms     4.35x
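The speedup column above is simply baseline latency divided by candidate latency. The PR's gdn_prefill_bench.py is not included here, so the harness below is a hedged CPU-only sketch of the measurement pattern: the real benchmark would wrap the kernel launches in CUDA events (or torch.cuda.synchronize()) rather than wall-clock timing.

```python
import time

def bench(fn, *args, warmup: int = 3, iters: int = 20) -> float:
    """Return the mean latency of fn(*args) in milliseconds.

    Wall-clock timing is only valid for synchronous (CPU) work; a GPU
    kernel benchmark must synchronize the device around the timed region.
    """
    for _ in range(warmup):          # warm caches / JIT before timing
        fn(*args)
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - t0) / iters * 1e3

def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Speedup of candidate over baseline, as reported in the table."""
    return baseline_ms / candidate_ms
```

For example, the TP1 1x8192 row (0.963 ms FLA/Triton vs 0.337 ms FlashInfer) gives speedup(0.963, 0.337) ≈ 2.86×.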

e2e prefill-only benchmark:

# Server
vllm serve nvidia/Qwen3.5-397B-A17B-NVFP4 \
    --port 8000 -tp 1 -pp 1 -dp 8 \
    --enable-expert-parallel \
    --language-model-only \
    --reasoning-parser qwen3 \
    --stream-interval 100 \
    --gdn-prefill-backend {triton,flashinfer}

# Client, 3 runs
vllm bench serve \
    --backend vllm \
    --model nvidia/Qwen3.5-397B-A17B-NVFP4 \
    --port 8000 \
    --endpoint /v1/completions \
    --dataset-name random \
    --random-input 8192 \
    --random-output 1 \
    --max-concurrency 128 \
    --num-prompt 128 \
    --ignore-eos \
    --temperature 0.0
Metric                            Triton (3 runs)             FlashInfer Blackwell (3 runs)
Benchmark duration (s)            6.91 / 6.15 / 6.19          6.26 / 5.44 / 5.51
Request throughput (req/s)        18.52 / 20.83 / 20.69       20.46 / 23.54 / 23.22
Total token throughput (tok/s)    151735 / 170623 / 169486    167635 / 192878 / 190251
Mean TTFT (ms)                    3789.0 / 3302.5 / 3335.7    3596.4 / 2915.6 / 2942.1
Median TTFT (ms)                  3783.0 / 3310.6 / 3321.7    3551.1 / 2906.7 / 2926.8
P99 TTFT (ms)                     6638.3 / 6093.1 / 5938.7    6154.4 / 5344.4 / 5461.4
  • Throughput 1.13×: 20.83 → 23.54 req/s, 170,623 → 192,878 tok/s
  • Mean TTFT −12%: 3302 ms → 2916 ms



@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the ci/build label Apr 23, 2026
@mergify
Contributor

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @arpera.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 23, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for FlashInfer's Blackwell SM100 GDN prefill kernel by introducing the nvidia-cutlass-dsl dependency for CUDA 13 builds and implementing logic to select the appropriate backend based on hardware and software requirements. Feedback suggests adding an explicit CUDA version check in the backend selection logic to align with the stated requirements and restoring a warning message for cases where the user-requested FlashInfer backend cannot be used.

Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py Outdated
Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py
arpera added 2 commits April 23, 2026 21:16
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
@arpera arpera force-pushed the fl-backwell-gdn-prefill branch from 57a19c7 to 560797f Compare April 23, 2026 18:19
@mergify mergify Bot removed the needs-rebase label Apr 23, 2026
@mergify
Contributor

mergify Bot commented Apr 23, 2026

Hi @arpera, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
@vadiklyutiy vadiklyutiy moved this to In review in Qwen3.5 Apr 24, 2026
@sighingnow
Collaborator

@arpera Thanks for the effort. I gave it a try and encountered the following error; do you have any ideas about it?

[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600] Traceback (most recent call last):
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/dsl.py", line 980, in compile_and_jit
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     kernel = self.compiler_provider.compile_and_jit(
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 186, in compile_and_jit
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     self.compile(
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 161, in compile
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     raise e
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]   File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/compiler.py", line 148, in compile
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     pm.run(module.operation)
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     │      └ <cutlass._mlir._mlir_libs._cutlass_ir._mlir.ir.Module object at 0x7f00c02babb0>
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]     └ <cutlass._mlir._mlir_libs._cutlass_ir._mlir.passmanager.PassManager object at 0x7f05d0171790>
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600] cutlass._mlir._mlir_libs._site_initialize.<locals>.MLIRError: Failure while executing pass pipeline:
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600] error: "tiled_state_t2r = tcgen05.make_tmem_copy(atom_state_t2r, tCtState_for_t2r)"("/usr/local/lib/python3.12/dist-packages/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py":3176:22): failed to legalize unresolved materialization from ('!cute_nvgpu.atom.tmem_load<f32, 32 DP, 32 bit, x32>') to ('!cute.tiled_copy<!cute_nvgpu.atom.tmem_load<f32, 32 DP, 32 bit, x32>, layout_copy_tv = <"((32,4),(32,32)):((0,1),(128,4))">, tiler_mn = <"[(4,32):(32,1);32:1]">>') that remained live after conversion
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]  note: "tiled_state_t2r = tcgen05.make_tmem_copy(atom_state_t2r, tCtState_for_t2r)"("/usr/local/lib/python3.12/dist-packages/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py":3176:22): see current operation: %12086 = "builtin.unrealized_conversion_cast"(%12085) : (!cute_nvgpu.atom.tmem_load<f32, 32 DP, 32 bit, x32>) -> !cute.tiled_copy<!cute_nvgpu.atom.tmem_load<f32, 32 DP, 32 bit, x32>, layout_copy_tv = <"((32,4),(32,32)):((0,1),(128,4))">, tiler_mn = <"[(4,32):(32,1);32:1]">>
[1/4,TP1][pid=879630] ERROR 04-26 09:35:18.710 [forward_context.py:600]  note: unknown: see existing live user here: 

@arpera
Contributor Author

arpera commented Apr 26, 2026

@sighingnow, first of all, are you using this FlashInfer patch: flashinfer-ai/flashinfer#3155?
If so, could you please share more details about your hardware, environment, and a reproducible example?

@sighingnow
Collaborator

@arpera I have managed to resolve the problem with the following change to flashinfer:

diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 53fe44ce..2c22c8e1 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -2406 +2406 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tKKrKK.shape[2]):
+        for sub in cutlass.range_constexpr(tKKrKK.shape[2]):
@@ -2446 +2446 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tQKrQK.shape[2]):
+        for sub in cutlass.range_constexpr(tQKrQK.shape[2]):
@@ -2982 +2982 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tRT_tCrState.shape[2]):
+        for sub in cutlass.range_constexpr(tRT_tCrState.shape[2]):
@@ -3066 +3066 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tTR_rState.shape[2]):
+        for sub in cutlass.range_constexpr(tTR_rState.shape[2]):
@@ -3347 +3347 @@ class GatedDeltaNetChunkedKernel:
-            for sub in cutlass.range(tRT_rState_inp.shape[2]):
+            for sub in cutlass.range_constexpr(tRT_rState_inp.shape[2]):
@@ -3389 +3389 @@ class GatedDeltaNetChunkedKernel:
-            for sub in cutlass.range(tTR_rState.shape[2]):
+            for sub in cutlass.range_constexpr(tTR_rState.shape[2]):
@@ -3474 +3474 @@ class GatedDeltaNetChunkedKernel:
-            for sub in cutlass.range(tTR_rQS.shape[1]):
+            for sub in cutlass.range_constexpr(tTR_rQS.shape[1]):
@@ -3502 +3502 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tTR_rNv.shape[1]):
+        for sub in cutlass.range_constexpr(tTR_rNv.shape[1]):
@@ -3526 +3526 @@ class GatedDeltaNetChunkedKernel:
-        for sub in cutlass.range(tTR_rDv.shape[1]):
+        for sub in cutlass.range_constexpr(tTR_rDv.shape[1]):
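The diff swaps cutlass.range for cutlass.range_constexpr, i.e. it forces the loops to be unrolled at trace/compile time instead of being lowered as dynamic loop ops in the IR, which sidesteps the tmem-copy materialization failure in the traceback above. The toy tracer below is an analogy only — it is not the CuTe DSL implementation, and the function names are invented for illustration.

```python
def trace_range(n, body):
    """Dynamic loop (like cutlass.range): the IR gets a single loop op
    and the body is traced once with a symbolic index."""
    return [("for", n), ("body", body(0))]

def trace_range_constexpr(n, body):
    """Compile-time unroll (like cutlass.range_constexpr): no loop op;
    the body is traced n times with concrete indices."""
    return [("body", body(i)) for i in range(n)]
```

In the unrolled form every iteration sees concrete values, so constructs that must be resolved at compile time (such as the tmem copy atoms here) can be materialized per iteration.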

@arpera
Contributor Author

arpera commented Apr 26, 2026

I’m still interested in the following: have you used this patch for Flashinfer (flashinfer-ai/flashinfer#3155)?
If yes, could you please share more details about your hardware, environment, and a reproducible example?

@sighingnow
Collaborator

I’m still interested in the following: have you used this patch for Flashinfer (flashinfer-ai/flashinfer#3155)? If yes, could you please share more details about your hardware, environment, and a reproducible example?

This patch was already included.

@arpera
Contributor Author

arpera commented Apr 27, 2026

Then could you please show an example of how to reproduce the problem you reported?

