Skip to content

Commit 12793cf

Browse files
committed
more
1 parent 54e21bb commit 12793cf

6 files changed

Lines changed: 76 additions & 12 deletions

File tree

docs_new/docs/advanced_features/server_arguments.mdx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,9 +1196,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s
11961196
</tr>
11971197
<tr>
11981198
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--fp4-gemm-backend`</td>
1199-
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Choose the runner backend for NVFP4 GEMM operations. Options: 'flashinfer_cutlass' (default), 'auto' (auto-selects between flashinfer_cudnn/flashinfer_cutlass based on CUDA/cuDNN version), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All backends are from FlashInfer; when FlashInfer is unavailable, sgl-kernel CUTLASS is used as an automatic fallback.</td>
1200-
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}><code>flashinfer_cutlass</code></td>
1201-
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}><code>auto</code>, <code>flashinfer_cudnn</code>, <code>flashinfer_cutlass</code>, <code>flashinfer_trtllm</code></td>
1199+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}>Choose the runner backend for NVFP4 GEMM operations. Options: 'auto' (default; selects <code>flashinfer_cudnn</code> on SM120, <code>flashinfer_cutedsl</code> on SM100, <code>flashinfer_cutlass</code> otherwise), 'cutlass' (SGLang CUTLASS kernel), 'flashinfer_cutlass' (FlashInfer CUTLASS backend), 'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), 'flashinfer_cutedsl' (FlashInfer CuTe DSL backend), 'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). All FlashInfer backends fall back to sgl-kernel CUTLASS when FlashInfer is unavailable.</td>
1200+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.02)"}}><code>auto</code></td>
1201+
<td style={{padding: "9px 12px", backgroundColor: "rgba(255,255,255,0.05)"}}><code>auto</code>, <code>cutlass</code>, <code>flashinfer_cudnn</code>, <code>flashinfer_cutedsl</code>, <code>flashinfer_cutlass</code>, <code>flashinfer_trtllm</code></td>
12021202
</tr>
12031203
<tr>
12041204
<td style={{padding: "9px 12px", fontWeight: 500, backgroundColor: "rgba(255,255,255,0.02)"}}>`--disable-flashinfer-autotune`</td>

python/sglang/srt/layers/quantization/fp4_utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing import TYPE_CHECKING
66

7-
from sglang.srt.utils.common import is_sm120_supported
7+
from sglang.srt.utils.common import is_sm100_supported, is_sm120_supported
88

99
if TYPE_CHECKING:
1010
from sglang.srt.server_args import ServerArgs
@@ -18,6 +18,7 @@ class Fp4GemmRunnerBackend(Enum):
1818
AUTO = "auto"
1919
CUTLASS = "cutlass"
2020
FLASHINFER_CUDNN = "flashinfer_cudnn"
21+
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
2122
FLASHINFER_CUTLASS = "flashinfer_cutlass"
2223
FLASHINFER_TRTLLM = "flashinfer_trtllm"
2324

@@ -36,6 +37,9 @@ def is_flashinfer_cutlass(self) -> bool:
3637
def is_flashinfer_trtllm(self) -> bool:
3738
return self == Fp4GemmRunnerBackend.FLASHINFER_TRTLLM
3839

40+
def is_flashinfer_cutedsl(self) -> bool:
41+
return self == Fp4GemmRunnerBackend.FLASHINFER_CUTEDSL
42+
3943
def is_flashinfer(self) -> bool:
4044
return self.value.startswith("flashinfer_")
4145

@@ -47,7 +51,10 @@ def get_flashinfer_backend(self) -> str:
4751
'flashinfer_trtllm' -> 'trtllm'
4852
'flashinfer_cutlass' -> 'cutlass'
4953
'flashinfer_cudnn' -> 'cudnn'
54+
'flashinfer_cutedsl' -> 'cute-dsl'
5055
"""
56+
if self == Fp4GemmRunnerBackend.FLASHINFER_CUTEDSL:
57+
return "cute-dsl"
5158
if self.value.startswith("flashinfer_"):
5259
return self.value.removeprefix("flashinfer_")
5360
else:
@@ -68,10 +75,8 @@ def initialize_fp4_gemm_config(server_args: ServerArgs) -> None:
6875
# heterogeneous batches on SM120 (Blackwell). cudnn is stable.
6976
# See: https://github.com/sgl-project/sglang/issues/20043
7077
backend = "flashinfer_cudnn"
71-
logger.info(
72-
"SM120 (Blackwell) detected: auto-selecting "
73-
"fp4-gemm-backend=flashinfer_cudnn"
74-
)
78+
elif is_sm100_supported():
79+
backend = "flashinfer_cutedsl"
7580
else:
7681
backend = "flashinfer_cutlass"
7782

python/sglang/srt/server_args.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@
205205
"auto",
206206
"cutlass",
207207
"flashinfer_cudnn",
208+
"flashinfer_cutedsl",
208209
"flashinfer_cutlass",
209210
"flashinfer_trtllm",
210211
]
@@ -5196,10 +5197,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
51965197
default=ServerArgs.fp4_gemm_runner_backend,
51975198
dest="fp4_gemm_runner_backend",
51985199
help="Choose the runner backend for NVFP4 GEMM operations. "
5199-
"Options: 'auto' (default; selects flashinfer_cudnn on SM120, flashinfer_cutlass otherwise), "
5200+
"Options: 'auto' (default; selects flashinfer_cudnn on SM120, flashinfer_cutedsl on SM100, flashinfer_cutlass otherwise), "
52005201
"'cutlass' (SGLang CUTLASS kernel), "
52015202
"'flashinfer_cutlass' (FlashInfer CUTLASS backend), "
52025203
"'flashinfer_cudnn' (FlashInfer cuDNN backend, optimal on CUDA 13+ with cuDNN 9.15+), "
5204+
"'flashinfer_cutedsl' (FlashInfer CuTe DSL backend), "
52035205
"'flashinfer_trtllm' (FlashInfer TensorRT-LLM backend, requires different weight preparation with shuffling). ",
52045206
)
52055207
parser.add_argument(

python/sglang/srt/utils/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,11 @@ def configure_logger(server_args, prefix: str = ""):
12061206
for name in ("httpx", "httpcore"):
12071207
logging.getLogger(name).setLevel(logging.WARNING)
12081208

1209+
if is_flashinfer_available():
1210+
from flashinfer.jit.core import logger as flashinfer_logger
1211+
1212+
flashinfer_logger.setLevel(logging.ERROR)
1213+
12091214

12101215
# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
12111216
def replace_submodule(

sgl-kernel/benchmark/bench_fp4_gemm.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import argparse
22
import csv
3-
import os
3+
import logging
44
from functools import partial
55
from typing import List, Tuple
66

77
import torch
88
import triton
99
from flashinfer import mm_fp4
10+
from flashinfer.autotuner import autotune
11+
from flashinfer.jit.core import logger as flashinfer_logger
1012
from flashinfer.testing import bench_gpu_time
1113

14+
flashinfer_logger.setLevel(logging.ERROR)
15+
1216
from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm, scaled_fp4_quant
1317
from sglang.srt.utils import (
1418
get_device_capability,
@@ -150,23 +154,25 @@ def _run_mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend):
150154
x_log=False,
151155
line_arg="provider",
152156
line_vals=(
153-
["sglang_cutlass", "cutlass", "cudnn", "trtllm", "auto"]
157+
["sglang_cutlass", "cutlass", "cudnn", "trtllm", "cute-dsl", "auto"]
154158
if is_sm100_supported()
155-
else ["sglang_cutlass", "cutlass", "cudnn", "auto"]
159+
else ["sglang_cutlass", "cutlass", "cudnn", "cute-dsl", "auto"]
156160
),
157161
line_names=(
158162
[
159163
"sglang cutlass fp4",
160164
"flashinfer cutlass fp4",
161165
"cudnn fp4",
162166
"trtllm fp4",
167+
"cute-dsl fp4",
163168
"auto fp4 (cudnn/cutlass)",
164169
]
165170
if is_sm100_supported()
166171
else [
167172
"sglang cutlass fp4",
168173
"flashinfer cutlass fp4",
169174
"cudnn fp4",
175+
"cute-dsl fp4",
170176
"auto fp4",
171177
]
172178
),
@@ -176,13 +182,15 @@ def _run_mm_fp4(a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, dtype, res_fi, backend):
176182
("orange", "solid"),
177183
("blue", "solid"),
178184
("green", "solid"),
185+
("brown", "solid"),
179186
("purple", "solid"),
180187
]
181188
if is_sm100_supported()
182189
else [
183190
("red", "solid"),
184191
("orange", "solid"),
185192
("blue", "solid"),
193+
("brown", "solid"),
186194
("purple", "solid"),
187195
]
188196
),
@@ -224,6 +232,11 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
224232
use_cuda_graph=True,
225233
)
226234
elif provider == "cutlass":
235+
with autotune():
236+
_run_mm_fp4(
237+
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
238+
alpha, dtype, res_fi, backend="cutlass",
239+
)
227240
times_ms = bench_gpu_time(
228241
fn=partial(_run_mm_fp4, backend="cutlass"),
229242
input_args=(
@@ -238,6 +251,11 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
238251
use_cuda_graph=True,
239252
)
240253
elif provider == "cudnn":
254+
with autotune():
255+
_run_mm_fp4(
256+
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
257+
alpha, dtype, res_fi, backend="cudnn",
258+
)
241259
times_ms = bench_gpu_time(
242260
fn=partial(_run_mm_fp4, backend="cudnn"),
243261
input_args=(
@@ -254,12 +272,41 @@ def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
254272
elif provider == "trtllm":
255273
a_sf_u8 = a_scale_interleaved.to(torch.uint8)
256274
b_sf_u8_T = b_sf_T.to(torch.uint8)
275+
with autotune():
276+
_run_mm_fp4(
277+
a_fp4, b_fp4_T, a_sf_u8, b_sf_u8_T,
278+
alpha, dtype, res_fi, backend="trtllm",
279+
)
257280
times_ms = bench_gpu_time(
258281
fn=partial(_run_mm_fp4, backend="trtllm"),
259282
input_args=(a_fp4, b_fp4_T, a_sf_u8, b_sf_u8_T, alpha, dtype, res_fi),
260283
use_cuda_graph=True,
261284
)
285+
elif provider == "cute-dsl":
286+
with autotune():
287+
_run_mm_fp4(
288+
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
289+
alpha, dtype, res_fi, backend="cute-dsl",
290+
)
291+
times_ms = bench_gpu_time(
292+
fn=partial(_run_mm_fp4, backend="cute-dsl"),
293+
input_args=(
294+
a_fp4,
295+
b_fp4_T,
296+
a_scale_interleaved,
297+
b_sf_T,
298+
alpha,
299+
dtype,
300+
res_fi,
301+
),
302+
use_cuda_graph=True,
303+
)
262304
elif provider == "auto":
305+
with autotune():
306+
_run_mm_fp4(
307+
a_fp4, b_fp4_T, a_scale_interleaved, b_sf_T,
308+
alpha, dtype, res_fi, backend="auto",
309+
)
263310
times_ms = bench_gpu_time(
264311
fn=partial(_run_mm_fp4, backend="auto"),
265312
input_args=(

test/registered/quant/test_nvfp4_gemm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,5 +81,10 @@ class TestFP4GemmFlashinferTrtllm(FP4GemmBase, unittest.TestCase):
8181
backend = "flashinfer_trtllm"
8282

8383

84+
@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
85+
class TestFP4GemmFlashinferCutedsl(FP4GemmBase, unittest.TestCase):
86+
backend = "flashinfer_cutedsl"
87+
88+
8489
if __name__ == "__main__":
8590
unittest.main()

0 commit comments

Comments (0)