5 changes: 4 additions & 1 deletion docker/Dockerfile
@@ -527,7 +527,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \
"runai-model-streamer[s3,gcs,azure]>=0.15.7"

RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --force-reinstall --no-deps;
python3 -m pip install "nvidia-cutlass-dsl>=4.4.1" "nvidia-cutlass-dsl-libs-base>=4.4.1" --force-reinstall --no-deps; \
if [ "${CUDA_VERSION%%.*}" = "13" ]; then \
python3 -m pip install "nvidia-cutlass-dsl-libs-cu13>=4.4.2" --no-deps ; \
fi

# Patching packages for CUDA 12/13 compatibility
# TODO: Remove when torch version covers these packages
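For reference, the `${CUDA_VERSION%%.*}` expansion in the conditional above strips everything from the first dot onward, leaving only the CUDA major version. A minimal Python sketch of the same check (the version value is hypothetical):

    # Hypothetical value of $CUDA_VERSION inside the image.
    cuda_version = "13.0.1"

    # Shell ${CUDA_VERSION%%.*}: drop the longest suffix matching ".*",
    # keeping only the major version.
    cuda_major = cuda_version.split(".", 1)[0]

    if cuda_major == "13":
        # Corresponds to the branch installing nvidia-cutlass-dsl-libs-cu13.
        print("install nvidia-cutlass-dsl-libs-cu13>=4.4.2")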
@@ -3,7 +3,7 @@
Both SM90 and SM100+ use the same pool layout: [pool, HV, V, K] (K-last).

SM90 (Hopper): full support — decode, prefill, MTP. State dtype: fp32.
-SM100+ (Blackwell+): decode-only with bf16 state. More support on the way.
+SM100+ (Blackwell+): decode and prefill with bf16 state. MTP verify on the way.

Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
"""
@@ -74,8 +74,7 @@ class FlashInferGDNKernel(LinearAttnKernelBase):
"""FlashInfer kernel for GDN with K-last SSM state layout.

SM90 (Hopper): decode uses gather/scatter; prefill and MTP verify supported.
SM100+ (Blackwell+): decode uses pool API (initial_state_indices); prefill
and MTP verify are not supported (use Triton backend for those).
SM100+ (Blackwell+): decode and prefill supported; MTP verify not yet supported.

Requires flashinfer >= 0.6.4 (SM90) or >= 0.6.5 (SM100+).
"""
@@ -97,7 +96,7 @@ def __init__(self):
            raise RuntimeError("FlashInfer GDN decode kernel is unavailable.")

        sm_major = torch.cuda.get_device_capability()[0]
-        self.use_state_pool = sm_major != 9
+        self.is_sm100plus = sm_major != 9

        if sm_major == 9:
            if self._prefill_fn is None:
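The `sm_major != 9` test above separates Hopper from newer architectures; a small sketch of the capability check it relies on:

    import torch

    # (major, minor) compute capability, e.g. (9, 0) on Hopper, (10, 0) on Blackwell.
    sm_major, sm_minor = torch.cuda.get_device_capability()

    # Hopper is SM90; anything newer that reaches this kernel is SM100+.
    is_sm100plus = sm_major != 9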
@@ -136,7 +135,7 @@ def decode(
        a_fi = a.view(batch_size, 1, num_v_heads)
        b_fi = b.view(batch_size, 1, num_v_heads)

-        if self.use_state_pool:
+        if self.is_sm100plus:
            output_fi, _ = self._decode_fn(
                q=query_fi,
                k=key_fi,
@@ -186,13 +185,6 @@ def extend(
        query_start_loc: torch.Tensor,
        **kwargs,
    ) -> tuple:
-        if self.use_state_pool:
-            raise NotImplementedError(
-                "FlashInfer GDN prefill is not supported on SM100+. "
-                "Use --linear-attn-prefill-backend triton."
-            )
-
-        # SM90: chunked prefill using FlashInfer GDN prefill kernel.
        from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd

        total_seq_len = q.shape[1]
@@ -207,30 +199,57 @@
        alpha_fi = torch.exp(g[0].to(torch.float32))
        beta_fi = beta[0].to(torch.float32)

-        cu_seqlens_fi = query_start_loc.to(torch.int64)
-
-        # Remap negative padding indices to sentinel slot
-        ssm_cache_indices = torch.where(
-            cache_indices >= 0,
-            cache_indices,
-            ssm_states.shape[0] - 1,
-        ).to(torch.int64)
-
-        # FlashInfer requires float32 initial state, K-last layout [B, HV, V, K]
-        initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)
-
-        output_fi, output_state_fi = self._prefill_fn(
-            q=q_fi,
-            k=k_fi,
-            v=v_fi,
-            g=alpha_fi,
-            beta=beta_fi,
-            scale=None,
-            initial_state=initial_state_fi,
-            output_final_state=True,
-            cu_seqlens=cu_seqlens_fi,
-            use_qk_l2norm_in_kernel=False,
-        )
+        if self.is_sm100plus:
+            # Negative indices (e.g. -1) are padding markers for slots not yet
+            # assigned to a real sequence; clamp them to 0 (the reserved dummy
+            # slot) so the FlashInfer kernel never reads out-of-bounds state.
+            ssm_cache_indices = cache_indices.clamp(min=0).to(torch.int64)
+            num_seqs = ssm_cache_indices.shape[0]
+            num_sab_heads = max(q.shape[2], num_v_heads)
+            head_k_dim = q.shape[3]
+            # Pre-allocate bf16 output_state so the kernel compiles and writes the
+            # bf16 state path directly, avoiding a fp32 allocation and a subsequent
+            # fp32->bf16 conversion in the scatter step.
+            output_state_fi = torch.empty(
+                (num_seqs, num_sab_heads, head_v_dim, head_k_dim),
+                dtype=ssm_states.dtype,
+                device=ssm_states.device,
+            )
+            initial_state_fi = ssm_states[ssm_cache_indices].contiguous()
+            output_fi, output_state_fi = self._prefill_fn(
+                q=q_fi,
+                k=k_fi,
+                v=v_fi,
+                g=alpha_fi,
+                beta=beta_fi,
+                scale=None,
+                initial_state=initial_state_fi,
+                output_final_state=True,
+                cu_seqlens=query_start_loc,  # already int32
+                use_qk_l2norm_in_kernel=False,
+                output_state=output_state_fi,
+            )
+        else:
+            # SM90: preserve original negative-index handling (remap to last slot).
+            ssm_cache_indices = torch.where(
+                cache_indices >= 0,
+                cache_indices,
+                ssm_states.shape[0] - 1,
+            ).to(torch.int64)
+            # State must be float32; kernel requires int64 cu_seqlens.
+            initial_state_fi = ssm_states[ssm_cache_indices].to(torch.float32)
+            output_fi, output_state_fi = self._prefill_fn(
+                q=q_fi,
+                k=k_fi,
+                v=v_fi,
+                g=alpha_fi,
+                beta=beta_fi,
+                scale=None,
+                initial_state=initial_state_fi,
+                output_final_state=True,
+                cu_seqlens=query_start_loc.to(torch.int64),
+                use_qk_l2norm_in_kernel=False,
+            )

        # Write back state to pool
        ssm_states.index_copy_(
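As a toy-sized sketch (shapes and values made up) of the index remapping and pool write-back pattern used in extend() above:

    import torch

    pool_size, HV, V, K = 8, 2, 4, 4
    ssm_states = torch.zeros(pool_size, HV, V, K, dtype=torch.bfloat16)
    cache_indices = torch.tensor([3, -1])  # -1 marks an unassigned padding slot

    # SM100+: clamp padding to the reserved dummy slot 0.
    idx_sm100 = cache_indices.clamp(min=0).to(torch.int64)  # -> [3, 0]
    # SM90: remap padding to the pool's last slot instead.
    idx_sm90 = torch.where(
        cache_indices >= 0, cache_indices, pool_size - 1
    ).to(torch.int64)  # -> [3, 7]

    initial_state = ssm_states[idx_sm100]  # gather: [B, HV, V, K]
    final_state = torch.randn_like(initial_state)  # stand-in for kernel output
    ssm_states.index_copy_(0, idx_sm100, final_state)  # scatter back to the pool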
@@ -267,7 +286,7 @@ def target_verify(
        retrieve_parent_token: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
-        if self.use_state_pool:
+        if self.is_sm100plus:
            raise NotImplementedError(
                "FlashInfer GDN MTP verify is not yet supported on SM100+."
            )
14 changes: 14 additions & 0 deletions python/sglang/srt/server_args.py
@@ -2722,6 +2722,20 @@ def _handle_linear_attn_backend(self):
f"got {self.mamba_ssm_dtype!r}"
)

# SM100+ FlashInfer GDN prefill requires CUDA 13+ (CuTe DSL kernel)
# for correctness and best performance.
prefill = self.linear_attn_prefill_backend or self.linear_attn_backend
if (
@yuan-luo (Collaborator) commented on Apr 23, 2026:

We'd better add bf16 state dtype validation for SM100+ FlashInfer prefill backend, just like how SM100+ FlashInfer decode backend does:

        if (
            decode == "flashinfer"
            and self.mamba_ssm_dtype != "bfloat16"
            and torch.cuda.is_available()
            and torch.cuda.get_device_capability()[0] >= 10
        ):

Otherwise, the user can then run SM100+ FlashInfer prefill with float32 state, which is unsupported (the module docstring states "SM100+: decode and prefill with bf16 state"), likely causing kernel errors or incorrect results at runtime.

The PR author (Collaborator) replied:
The flashinfer prefill kernel actually supports fp32, so the current status from flashinfer is:

prefill: fp32/bf16
decode: bf16

Note, here I am talking about the "fast" kernels that we recommend for Blackwell (there are some "legacy" kernels that are not the focus of this PR).

So, here is what happens with the current code:

# if users use fp32 states
prefill works but decode will complain
# if users use bf16 states
both prefill and decode work

prefill == "flashinfer"
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 10
and int(torch.version.cuda.split(".")[0]) < 13
):
raise ValueError(
"--linear-attn-prefill-backend flashinfer on SM100+ requires CUDA 13+, "
f"got CUDA {torch.version.cuda}"
)

def _handle_context_parallelism(self):
if self.attn_cp_size > 1:
# The tp_size is the world size, not the real tensor parallel size
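For clarity, the `or`-fallback at the top of the new validation resolves the effective prefill backend; a sketch with hypothetical values:

    # The dedicated prefill flag wins; otherwise fall back to the shared
    # --linear-attn-backend value.
    linear_attn_backend = "flashinfer"
    linear_attn_prefill_backend = None

    prefill = linear_attn_prefill_backend or linear_attn_backend
    assert prefill == "flashinfer"  # the SM100+/CUDA-13 check then applies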
81 changes: 81 additions & 0 deletions test/registered/4-gpu-models/test_qwen35_fp4_flashinfer.py
@@ -0,0 +1,81 @@
import unittest

import torch

from sglang.test.accuracy_test_runner import AccuracyTestParams
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_combined_tests import run_combined_tests
from sglang.test.test_utils import (
    CustomTestCase,
    ModelLaunchSettings,
)

register_cuda_ci(est_time=720, suite="stage-c-test-4-gpu-b200")

QWEN35_FP4_MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"
ACC_THRESHOLDS = {QWEN35_FP4_MODEL: {"gsm8k": 0.95}}

_is_sm100_cuda13 = (
    torch.cuda.is_available()
    and torch.cuda.get_device_capability()[0] >= 10
    and int(torch.version.cuda.split(".")[0]) >= 13
)


@unittest.skipUnless(_is_sm100_cuda13, "requires SM100+ GPU and CUDA 13+")
class TestQwen35FP4FlashInfer(CustomTestCase):
    def test_gsm8k(self):
        base_args = [
            "--tp-size",
            "4",
            "--chunked-prefill-size",
            "2048",
            "--mamba-scheduler-strategy",
            "extra_buffer",
            "--mamba-track-interval",
            "128",
            "--mamba-ssm-dtype",
            "bfloat16",
            "--max-running-requests",
            "128",
            "--reasoning-parser",
            "qwen3",
            "--attention-backend",
            "trtllm_mha",
            "--quantization",
            "modelopt_fp4",
            "--model-loader-extra-config",
            '{"enable_multithread_load": true,"num_threads": 64}',
            "--linear-attn-decode-backend",
            "flashinfer",
            "--linear-attn-prefill-backend",
            "flashinfer",
        ]

        variants = [
            ModelLaunchSettings(
                QWEN35_FP4_MODEL,
                extra_args=base_args,
                variant="FlashInfer",
            ),
        ]

        run_combined_tests(
            models=variants,
            test_name="Qwen3.5-397B-A17B-NVFP4",
            accuracy_params=AccuracyTestParams(
                dataset="gsm8k",
                baseline_accuracy=ACC_THRESHOLDS[QWEN35_FP4_MODEL]["gsm8k"],
                num_examples=200,
                num_threads=128,
                max_tokens=16000,
                thinking_mode="qwen3",
                temperature=0.6,
                top_p=0.95,
                top_k=20,
            ),
        )


if __name__ == "__main__":
    unittest.main()
6 changes: 0 additions & 6 deletions test/registered/4-gpu-models/test_qwen35_fp4_triton.py
@@ -48,12 +48,6 @@ def test_gsm8k(self):
                extra_args=base_args,
                variant="Triton",
            ),
-            # TODO: Fix this and re-enable it
-            # ModelLaunchSettings(
-            #     QWEN35_FP4_MODEL,
-            #     extra_args=base_args + ["--linear-attn-decode-backend", "flashinfer"],
-            #     variant="FlashInfer",
-            # ),
        ]

        run_combined_tests(