
[NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) #22921

Open
kaixih wants to merge 3 commits into sgl-project:main from kaixih:add_flashinfer_gdn_prefill

Conversation


@kaixih kaixih commented Apr 16, 2026

[GDN] Add FlashInfer prefill support for SM100+ (Blackwell)

Summary

Extends FlashInfer GDN kernel support to cover the prefill/extend path on SM100+
(Blackwell) hardware, which previously raised NotImplementedError. SM90 (Hopper)
prefill was already supported; this PR completes SM100+ coverage.

Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)

gsm8k (200 examples, baseline threshold: 0.95)

| Backend                       | Score |
|-------------------------------|-------|
| Triton (prefill + decode)     | 0.985 |
| FlashInfer (prefill + decode) | 0.985 |

GPQA diamond (198 examples, repeat=8, temperature=0.6)

| Backend                       | Scores                                                  | Mean  |
|-------------------------------|---------------------------------------------------------|-------|
| FlashInfer (prefill + decode) | 0.848, 0.879, 0.904, 0.879, 0.848, 0.864, 0.869, 0.869  | 0.870 |

Throughput Benchmark (B200, Qwen3.5-397B-A17B-NVFP4, TP=8)

More detailed perf numbers in the PR comments below.

Server settings:

  • --tp-size 8 --max-running-requests 256 --chunked-prefill-size 163840
  • --mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128
  • --attention-backend trtllm_mha --linear-attn-decode-backend flashinfer
  • --linear-attn-prefill-backend <triton|flashinfer> (varied per run)
  • --disable-radix-cache --quantization modelopt_fp4

Benchmark settings:

  • --dataset-name random --random-input-len 8192 --random-output-len 128
  • --max-concurrency 256 --num-prompts 512
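
For reference, the settings above roughly correspond to the following commands (a sketch only: <model-path> is a placeholder, the standard sglang entry points are assumed, and all flags are taken verbatim from the lists above):

```bash
# Server (sketch; <model-path> is a placeholder; swap flashinfer -> triton on
# --linear-attn-prefill-backend for the baseline run)
python -m sglang.launch_server --model-path <model-path> \
  --tp-size 8 --max-running-requests 256 --chunked-prefill-size 163840 \
  --mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128 \
  --attention-backend trtllm_mha --linear-attn-decode-backend flashinfer \
  --linear-attn-prefill-backend flashinfer \
  --disable-radix-cache --quantization modelopt_fp4

# Benchmark client
python -m sglang.bench_serving --dataset-name random \
  --random-input-len 8192 --random-output-len 128 \
  --max-concurrency 256 --num-prompts 512
```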

| Metric                   | Triton prefill | FlashInfer prefill | Speedup |
|--------------------------|----------------|--------------------|---------|
| Benchmark duration (s)   | 53.27          | 50.87              | 1.05x   |
| Input throughput (tok/s) | 78,734         | 82,445             | 1.05x   |
| Total throughput (tok/s) | 79,964         | 83,733             | 1.05x   |
| Mean TTFT (ms)           | 12,742         | 12,042             | 1.06x   |
| Mean TPOT (ms)           | 109.08         | 105.14             | 1.04x   |

Requirements

  • FlashInfer >= 0.6.8 (for chunk_gated_delta_rule SM100 path)
  • nvidia-cutlass-dsl[cu13] >= 4.4.2 (SM100+ only)
  • CUDA 13 (SM100+ path requires _cuda_major >= 13)
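
A minimal capability/version gate along these lines (a sketch, not the PR's actual validation code; the flashinfer-python distribution name and the helper name are assumptions) captures the requirements above:

```python
# Sketch of an environment check for the SM100+ FlashInfer GDN prefill path.
from importlib.metadata import PackageNotFoundError, version

import torch
from packaging.version import Version


def supports_flashinfer_gdn_prefill_sm100() -> bool:
    if not torch.cuda.is_available():
        return False
    if torch.cuda.get_device_capability()[0] < 10:  # SM100+ (Blackwell)
        return False
    if torch.version.cuda is None or int(torch.version.cuda.split(".")[0]) < 13:
        return False  # the CuTe DSL kernel needs CUDA 13+
    try:
        # nvidia-cutlass-dsl >= 4.4.2 is also required on SM100+ (check omitted here).
        return Version(version("flashinfer-python")) >= Version("0.6.8")
    except PackageNotFoundError:
        return False
```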



kaixih commented Apr 16, 2026

cc @hlu1 @YAMY1234 @wenscarl

@kaixih kaixih changed the title [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) [Draft] [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) Apr 16, 2026
@kaixih kaixih changed the title [Draft] [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) [NVIDIA] [GDN] Add FlashInfer prefill support for SM100+ (Blackwell) Apr 16, 2026

kaixih commented Apr 16, 2026

The model has a repeated block pattern of 3× linear attention (GDN) + 1× full attention.
Profiling one such block during prefill:

| Backend    | Block wall time | GDN prefill (3 layers) | GDN per layer | Kernels/layer |
|------------|-----------------|------------------------|---------------|---------------|
| Triton     | 12,784 µs       | 1,518 µs (506×3)       | 506 µs        | 12            |
| FlashInfer | 12,379 µs       | 1,275 µs (425×3)       | 425 µs        | 11            |
| Speedup    | 1.03x           | 1.19x                  | 1.19x         |               |

The GDN kernel itself is ~19% faster with FlashInfer; the modest system-level gain (~5%)
reflects that GDN is a small fraction of the total forward pass (MoE GEMM, attention,
all-reduce account for the rest).
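
A rough Amdahl's-law check on these numbers shows why a ~19% kernel-level win turns into a low-single-digit block-level win (the ~5% figure above is end-to-end throughput):

```python
# Back-of-the-envelope estimate from the block profile above: GDN prefill is
# ~1,518 µs of a ~12,784 µs block (~12%), and that portion gets ~19% faster.
gdn_fraction = 1518 / 12784          # ~0.12
gdn_speedup = 506 / 425              # ~1.19x per-layer speedup
block_speedup = 1 / (1 - gdn_fraction + gdn_fraction / gdn_speedup)
print(f"{block_speedup:.2f}x")       # ~1.02x, in line with the measured ~1.03x block speedup
```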

FlashInfer GDN prefill — kernel breakdown (per layer, 11 launches)

| Kernel                                                   | Calls | Time                   |
|----------------------------------------------------------|-------|------------------------|
| GatedDeltaNetChunkedKernel (fused main)                  | 1     | 328.2 µs               |
| elementwise_kernel (bf16 contiguity copy, packed QKV)    | 3     | 58.2 µs (19.4 µs each) |
| l2norm_fwd_kernel                                        | 2     | 7.5 µs (3.7 µs each)   |
| index_elementwise_kernel (index_copy scatter)            | 1     | 2.9 µs                 |
| vectorized_gather_kernel (state gather)                  | 1     | 2.5 µs                 |
| vectorized_elementwise_kernel (exp)                      | 1     | 2.4 µs                 |
| unrolled_elementwise_kernel (int64 cast for index_copy)  | 1     | 2.2 µs                 |
| vectorized_elementwise_kernel (clamp)                    | 1     | 2.0 µs                 |
| Total                                                    | 11    | ≈406 µs (wall: 425 µs) |

Triton GDN prefill — kernel breakdown (per layer, 12 launches)

| Kernel                                                            | Calls | Time                   |
|-------------------------------------------------------------------|-------|------------------------|
| chunk_gated_delta_rule_fwd_kernel_h_blockdim64 (main recurrence)   | 1     | 257.9 µs               |
| chunk_fwd_kernel_o (output projection)                             | 1     | 63.5 µs                |
| elementwise_kernel (bf16 contiguity copy, packed QKV)              | 3     | 56.8 µs (18.9 µs each) |
| chunk_gated_delta_rule_fwd_kkt_solve_kernel                        | 1     | 42.2 µs                |
| recompute_w_u_fwd_kernel                                           | 1     | 34.2 µs                |
| vectorized_elementwise_kernel (fill bf16)                          | 2     | 15.6 µs (7.8 µs each)  |
| l2norm_fwd_kernel                                                  | 2     | 9.0 µs (4.5 µs each)   |
| chunk_local_cumsum_scalar_kernel                                   | 1     | 4.8 µs                 |
| Total                                                              | 12    | ≈484 µs (wall: 506 µs) |

The ~80 µs gap between summed kernel times and wall time reflects Python-level kernel
launch overhead (gaps between dispatches). The FlashInfer overhead items above
(packed QKV copies, gather/scatter, l2norm, exp, cast, clamp; ~78 µs in total) are
candidates for elimination via upstream improvements, e.g. stride-aware kernels that
avoid the contiguity copies (see the review discussion below).


kaixih commented Apr 16, 2026

This PR is ready for review.


hlu1 commented Apr 16, 2026

The CuTe DSL kernel performance is limited by low parallelism when the batch size and number of heads are small, which is clearly shown by the kernel benchmark in flashinfer-ai/flashinfer#3001.

Depending on how the prefill benchmark is configured, the e2e speedup will vary a lot. For example, with 1k or 8k ISL, --chunked-prefill-size 163840, and TP4, you get effective batch sizes of 160 and 20 respectively and hit the higher end of the speedup range. But if you set --chunked-prefill-size 8192, the effective batch size will be smaller and you will hit the lower end. In practice, the real speedup will depend on the actual ISL of the workload, and we likely won't see much speedup for long-ISL workloads.
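
The quoted effective batch sizes fall out of the chunked-prefill token budget divided by the input length:

```python
# Effective prefill batch size = chunked-prefill token budget // input sequence length
chunked_prefill_size = 163_840
for isl in (1_024, 8_192):
    print(f"ISL {isl}: effective batch size {chunked_prefill_size // isl}")
# -> 160 sequences per step at 1k ISL, 20 at 8k ISL
```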

@kaixih kaixih force-pushed the add_flashinfer_gdn_prefill branch from 23b04c0 to b6c0d39 on April 17, 2026 18:00
Comment on lines 194 to 195 in python/sglang/srt/layers/attention/linear/kernels/gdn_flashinfer.py
q_fi = l2norm_fwd(q[0].contiguous())
k_fi = l2norm_fwd(k[0].contiguous())
Collaborator

We can modify the Triton l2norm_fwd kernel to make it support strided inputs, eliminating the contiguous calls.
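
A minimal sketch of what a stride-aware l2norm_fwd could look like (function names, 2-D shapes, and the eps handling are illustrative assumptions, not the actual kernel being referenced):

```python
# Illustrative sketch: a row-wise L2 normalization whose kernel takes the row
# stride explicitly, so strided views (e.g. slices of packed QKV) need no
# .contiguous() copy before the call.
import torch
import triton
import triton.language as tl


@triton.jit
def _l2norm_fwd_strided(x_ptr, y_ptr, stride_row, D: tl.constexpr, BLOCK: tl.constexpr, eps: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < D
    # Load one row through its (possibly non-unit) row stride.
    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)
    inv_norm = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) + eps)
    tl.store(y_ptr + row * D + cols, (x * inv_norm).to(y_ptr.dtype.element_ty), mask=mask)


def l2norm_fwd_strided(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (N, D); rows may be strided as long as the last dim is contiguous.
    assert x.ndim == 2 and x.stride(-1) == 1
    n, d = x.shape
    y = torch.empty((n, d), device=x.device, dtype=x.dtype)
    _l2norm_fwd_strided[(n,)](x, y, x.stride(0), D=d, BLOCK=triton.next_power_of_2(d), eps=eps)
    return y
```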

@ispobock
Collaborator

/tag-and-rerun-ci

@yuan-luo
Collaborator

/rerun-failed-ci

In python/sglang/srt/server_args.py:

# SM100+ FlashInfer GDN prefill requires CUDA 13+ (CuTe DSL kernel)
# for correctness and best performance.
prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
if (
@yuan-luo yuan-luo (Collaborator) commented Apr 23, 2026

We'd better add bf16 state-dtype validation for the SM100+ FlashInfer prefill backend, just like the SM100+ FlashInfer decode backend does:

        if (
            decode == "flashinfer"
            and self.mamba_ssm_dtype != "bfloat16"
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 10
        ):

Otherwise, a user can run SM100+ FlashInfer prefill with a float32 state, which is unsupported (the module docstring states "SM100+: decode and prefill with bf16 state"), likely causing kernel errors or incorrect results at runtime.

@kaixih (Collaborator Author)

The FlashInfer prefill kernel actually supports fp32, so the current status from FlashInfer is:

prefill: fp32/bf16
decode: bf16

Note: here I am talking about the "fast" kernels that we recommend for Blackwell (there are some "legacy" kernels that are not the focus of this PR).

So, below is what happens with the current code:

# if users use fp32 states
prefill works but decode will complain
# if users use bf16 states
both prefill and decode work

@yuan-luo
Collaborator

/rerun-failed-ci

@yuan-luo
Collaborator

/rerun-failed-ci
