[GDN] Enable FI Blackwell GDN prefill kernel #40717
arpera wants to merge 3 commits into vllm-project:main
Conversation
Code Review
This pull request adds support for FlashInfer's Blackwell SM100 GDN prefill kernel by introducing the nvidia-cutlass-dsl dependency for CUDA 13 builds and by selecting the appropriate backend based on hardware and software requirements. Feedback: add an explicit CUDA version check to the backend selection logic so it matches the stated requirements, and restore the warning message for cases where the user-requested FlashInfer backend cannot be used.
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
arpera force-pushed from 57a19c7 to 560797f
Hi @arpera, the pre-commit checks have failed. Please run:
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
Then commit the changes and push to your branch.
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
@arpera Thanks for the effort. I gave it a try and encountered the following error; I'm wondering if you have any ideas about it?
@sighingnow, first of all, are you using this patch for FI: flashinfer-ai/flashinfer#3155?
@arpera I have managed to resolve the problem with the following change to flashinfer:
diff --git a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
index 53fe44ce..2c22c8e1 100644
--- a/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
+++ b/flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py
@@ -2406 +2406 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tKKrKK.shape[2]):
+ for sub in cutlass.range_constexpr(tKKrKK.shape[2]):
@@ -2446 +2446 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tQKrQK.shape[2]):
+ for sub in cutlass.range_constexpr(tQKrQK.shape[2]):
@@ -2982 +2982 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tRT_tCrState.shape[2]):
+ for sub in cutlass.range_constexpr(tRT_tCrState.shape[2]):
@@ -3066 +3066 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tTR_rState.shape[2]):
+ for sub in cutlass.range_constexpr(tTR_rState.shape[2]):
@@ -3347 +3347 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tRT_rState_inp.shape[2]):
+ for sub in cutlass.range_constexpr(tRT_rState_inp.shape[2]):
@@ -3389 +3389 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tTR_rState.shape[2]):
+ for sub in cutlass.range_constexpr(tTR_rState.shape[2]):
@@ -3474 +3474 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tTR_rQS.shape[1]):
+ for sub in cutlass.range_constexpr(tTR_rQS.shape[1]):
@@ -3502 +3502 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tTR_rNv.shape[1]):
+ for sub in cutlass.range_constexpr(tTR_rNv.shape[1]):
@@ -3526 +3526 @@ class GatedDeltaNetChunkedKernel:
- for sub in cutlass.range(tTR_rDv.shape[1]):
+ for sub in cutlass.range_constexpr(tTR_rDv.shape[1]):
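For anyone hitting the same error before the upstream fix lands, the substitution above could be applied locally with a small script. This is purely a convenience sketch of mine, not part of the PR: it rewrites every cutlass.range( call in that one file, so review the resulting diff against the hunks above.

from pathlib import Path

# Blanket-replace cutlass.range(...) with cutlass.range_constexpr(...)
# in the affected flashinfer kernel file (workaround only; prefer the
# proper upstream fix once merged).
path = Path("flashinfer/gdn_kernels/blackwell/gated_delta_net_chunked.py")
src = path.read_text()
path.write_text(src.replace("cutlass.range(", "cutlass.range_constexpr("))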
I'm still interested in the following: have you used this patch for Flashinfer (flashinfer-ai/flashinfer#3155)?
This patch was already included. |
Then could you please show an example of how to reproduce the problem that you reported?
IMPORTANT!!!
This PR MUST be merged only after flashinfer-ai/flashinfer#3155 is merged in FlashInfer and vLLM starts to use that FI version: the current GDN implementation in FI has a bug.
Purpose
Enable FlashInfer's new Blackwell (SM100) CuTe-DSL GDN prefill kernel (flashinfer-ai/flashinfer#3001) by default in vLLM.
For reference, the equivalent sglang PR is "Add FlashInfer prefill support for SM100+", linked just in case.
On Blackwell, the dispatcher in ChunkGatedDeltaRule.__init__ now routes GDN prefill to FlashInfer when all of the following hold (logged once at init):
- the requested backend is in ["flashinfer", "auto"];
- platform == cuda;
- head_k_dim == 128;
- nvidia-cutlass-dsl-libs-cu13 is installed;
- cuda_runtime >= 13.
Otherwise we stay on the Triton/FLA path; a sketch of these checks follows below.
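A minimal sketch of what those conditions could look like, assuming standard torch/importlib probes; the helper name and the exact probing calls are my illustration, not the PR's actual code in ChunkGatedDeltaRule.__init__:

import importlib.util

import torch


def can_use_flashinfer_gdn_prefill(requested: str, head_k_dim: int) -> bool:
    """Illustrative approximation of the backend-selection conditions above."""
    if requested not in ("flashinfer", "auto"):
        return False
    if not torch.cuda.is_available():  # platform == cuda
        return False
    major, _minor = torch.cuda.get_device_capability()
    if major < 10:  # Blackwell is SM100, i.e. compute capability 10.x
        return False
    if head_k_dim != 128:
        return False
    # nvidia-cutlass-dsl ships the `cutlass` Python module used by the kernel
    if importlib.util.find_spec("cutlass") is None:
        return False
    # CUDA runtime >= 13; torch.version.cuda is a string like "13.0"
    cuda = torch.version.cuda
    return cuda is not None and int(cuda.split(".")[0]) >= 13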
Test Result
Hardware: 8xB200
Functional
e2e gsm8k:
Accuracy remains the same; no degradation.
Performance
GDN prefill kernel micro-benchmark:
gdn_prefill_bench.py
FlashInfer Blackwell SM100 vs FLA/Triton on B200 across Qwen3.5 configurations — speedup ranges from 1.01× (TP8, small heads, small seqlen) to 5.46× (TP1, full head count, balanced split).
Full table
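For anyone wanting to run a similar comparison, a minimal CUDA-event timing harness could look like the following. This is illustrative only; gdn_prefill_bench.py attached above is the script actually used, and nothing here is taken from it.

import torch


def bench_ms(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    """Mean per-call latency in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters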
e2e prefill-only benchmark: