[NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) #22921
kaixih wants to merge 3 commits into sgl-project:main
Conversation
The model has a repeated block pattern of 3× linear attention (GDN) + 1× full attention.
The GDN kernel itself is ~19% faster with FlashInfer; the modest system-level gain (~5%) is expected, since the GDN kernels account for only a fraction of end-to-end prefill time.

FlashInfer GDN prefill — kernel breakdown (per layer, 11 launches)
Triton GDN prefill — kernel breakdown (per layer, 12 launches)
The ~80 µs gap between summed kernel times and wall time reflects Python-level kernel launch overhead.
This PR is ready for review.

The CuteDSL kernel's performance is limited by low parallelism when the batch size and number of heads are small, as the kernel benchmark in flashinfer-ai/flashinfer#3001 clearly shows. Depending on how the prefill benchmark is configured, the e2e speedup varies a lot. For example, with 1k or 8k ISL, --chunked-prefill-size 163840, and TP4, the effective batch sizes are 160 and 20 respectively, which hits the higher end of the speedup; with --chunked-prefill-size 8192 the effective batch size is smaller and hits the lower end. In practice, the real speedup will depend on the actual ISL of the workload, and we likely won't see much speedup for long-ISL workloads. The arithmetic is sketched below.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed from 23b04c0 to b6c0d39.
q_fi = l2norm_fwd(q[0].contiguous())
k_fi = l2norm_fwd(k[0].contiguous())
We could modify the Triton l2norm_fwd kernel to support strided inputs and eliminate the .contiguous() calls; a sketch follows.
/tag-and-rerun-ci

/rerun-failed-ci
# SM100+ FlashInfer GDN prefill requires CUDA 13+ (CuTe DSL kernel)
# for correctness and best performance.
prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
if (
We should add bf16 state-dtype validation for the SM100+ FlashInfer prefill backend, just as the SM100+ FlashInfer decode backend does:
if (
    decode == "flashinfer"
    and self.mamba_ssm_dtype != "bfloat16"
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 10
):
Otherwise, a user could run SM100+ FlashInfer prefill with a float32 state, which is unsupported (the module docstring states "SM100+: decode and prefill with bf16 state"), likely causing kernel errors or incorrect results at runtime.
The FlashInfer prefill kernel actually supports fp32, so the current status from FlashInfer is:

prefill: fp32/bf16
decode: bf16

Note that I'm referring to the "fast" kernels we recommend for Blackwell (there are also some "legacy" kernels that are not the focus of this PR).

So with the current code:

- fp32 states: prefill works, but decode will complain
- bf16 states: both prefill and decode work
/rerun-failed-ci

/rerun-failed-ci
[GDN] Add FlashInfer prefill support for SM100+ (Blackwell)
Summary
Extends FlashInfer GDN kernel support to cover the prefill/extend path on SM100+
(Blackwell) hardware, which previously raised NotImplementedError. SM90 (Hopper) prefill was already supported; this PR completes SM100+ coverage.
Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)
gsm8k (200 examples, baseline threshold: 0.95)
GPQA diamond (198 examples, repeat=8, temperature=0.6)
Throughput Benchmark (B200, Qwen3.5-397B-A17B-NVFP4, TP=8)
More detailed perf numbers in the PR comments below.
Server settings:
--tp-size 8 --max-running-requests 256 --chunked-prefill-size 163840
--mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128
--attention-backend trtllm_mha --linear-attn-decode-backend flashinfer
--linear-attn-prefill-backend <triton|flashinfer> (varied per run)
--disable-radix-cache --quantization modelopt_fp4

Benchmark settings:
--dataset-name random --random-input-len 8192 --random-output-len 128
--max-concurrency 256 --num-prompts 512
Requirements

- chunk_gated_delta_rule (SM100 path)
- nvidia-cutlass-dsl[cu13] >= 4.4.2 (SM100+ only)
- CUDA 13+ (_cuda_major >= 13)