Commit 78477c9 (parent: b4f1995)

Author: luoyuan.luo
Commit message: Address review comments

File tree: 4 files changed (+34 −39 lines)


python/sglang/jit_kernel/benchmark/bench_per_token_group_quant_8bit.py

Lines changed: 1 addition & 20 deletions
@@ -29,26 +29,7 @@
 
 mode_concentrated = IS_CI or (os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated")
 
-if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
-    configs = [
-        [
-            768 * 8,
-            2048,
-            128,
-            48,
-            fp8_type_,
-            dict(
-                column_major_scales=True,
-                scale_tma_aligned=True,
-                scale_ue8m0=True,
-                fuse_silu_and_mul=True,
-                # masked_layout_mode=None,
-                masked_layout_mode="balanced",
-                # masked_layout_mode="extreme",
-            ),
-        ]
-    ]
-elif mode_concentrated:
+if mode_concentrated:
     configs = list(
         itertools.product(
             [768],
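The surviving branch builds its benchmark grid with itertools.product. As a hedged illustration of that pattern (the axis names and every value other than [768] below are hypothetical placeholders, not taken from this diff), the product expands to one config tuple per combination:

import itertools

# Hypothetical sweep axes; only the [768] axis appears in this hunk.
num_tokens = [768]
hidden_dims = [2048, 4096]
group_sizes = [128]

configs = list(itertools.product(num_tokens, hidden_dims, group_sizes))
print(configs)  # [(768, 2048, 128), (768, 4096, 128)]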

python/sglang/jit_kernel/per_token_group_quant_8bit.py

Lines changed: 2 additions & 5 deletions
@@ -9,18 +9,15 @@
 if TYPE_CHECKING:
     from tvm_ffi.module import Module
 
-_OUTPUT_DTYPE_MAP = {
-    torch.float8_e4m3fn: "fp8_e4m3_t",
-    torch.int8: "int8_t",
-}
+from sglang.jit_kernel.utils import CPP_DTYPE_MAP as OUTPUT_DTYPE_MAP
 
 
 @cache_once
 def _jit_per_token_group_quant_8bit_module(
     dtype: torch.dtype, output_type: torch.dtype
 ) -> Module:
     input_args = make_cpp_args(dtype)
-    out_cpp = _OUTPUT_DTYPE_MAP[output_type]
+    out_cpp = OUTPUT_DTYPE_MAP[output_type]
     return load_jit(
         "per_token_group_quant_8bit",
         cuda_files=["gemm/per_token_group_quant_8bit.cuh"],
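A minimal sketch of the lookup this change switches to, with the shared map inlined so the snippet is self-contained (the map contents mirror CPP_DTYPE_MAP from sglang.jit_kernel.utils after this commit; the helper function is illustrative, not part of the codebase):

import torch

# Mirrors CPP_DTYPE_MAP in sglang.jit_kernel.utils after this commit.
CPP_DTYPE_MAP = {
    torch.float: "fp32_t",
    torch.float16: "fp16_t",
    torch.float8_e4m3fn: "fp8_e4m3_t",
    torch.bfloat16: "bf16_t",
    torch.int8: "int8_t",
}

def cpp_dtype_token(output_type: torch.dtype) -> str:
    # Fail loudly on unsupported dtypes instead of emitting a bad C++ type name.
    if output_type not in CPP_DTYPE_MAP:
        raise ValueError(f"no C++ dtype token registered for {output_type}")
    return CPP_DTYPE_MAP[output_type]

assert cpp_dtype_token(torch.float8_e4m3fn) == "fp8_e4m3_t"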

python/sglang/jit_kernel/utils.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,6 @@
 if TYPE_CHECKING:
     from tvm_ffi import Module
 
-
 F = TypeVar("F", bound=Callable[..., Any])
 
 
@@ -73,7 +72,9 @@ def __str__(self) -> str:
 CPP_DTYPE_MAP = {
     torch.float: "fp32_t",
     torch.float16: "fp16_t",
+    torch.float8_e4m3fn: "fp8_e4m3_t",
     torch.bfloat16: "bf16_t",
+    torch.int8: "int8_t",
 }
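utils.py also provides the cache_once decorator applied to the JIT module builder above. Its implementation is not shown in this diff; a plausible minimal sketch, assuming it simply memoizes by call arguments so each (dtype, output_type) pair compiles at most once:

import functools
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

def cache_once(fn: F) -> F:
    # Assumed behavior: memoize results keyed by the (hashable) arguments,
    # e.g. torch.dtype values, so repeated calls reuse the compiled module.
    return functools.lru_cache(maxsize=None)(fn)  # type: ignore[return-value]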

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 29 additions & 13 deletions
@@ -63,6 +63,10 @@
 
 enable_sgl_per_token_group_quant_8bit = False
 
+from sglang.jit_kernel.per_token_group_quant_8bit import (
+    per_token_group_quant_8bit as sgl_per_token_group_quant_8bit_jit,
+)
+
 if _is_hip:
     _has_vllm = False
     if _use_aiter:
@@ -501,19 +505,31 @@ def sglang_per_token_group_quant_fp8(
     if x.shape[0] > 0:
         # Temporary
         if enable_sgl_per_token_group_quant_8bit:
-            sgl_per_token_group_quant_8bit(
-                x,
-                x_q,
-                x_s,
-                group_size,
-                eps,
-                fp8_min,
-                fp8_max,
-                scale_ue8m0,
-                fuse_silu_and_mul,
-                masked_m,
-                enable_v2=enable_v2,
-            )
+            if enable_v2:
+                sgl_per_token_group_quant_8bit(
+                    x,
+                    x_q,
+                    x_s,
+                    group_size,
+                    eps,
+                    fp8_min,
+                    fp8_max,
+                    scale_ue8m0,
+                    fuse_silu_and_mul,
+                    masked_m,
+                    enable_v2=True,
+                )
+            else:
+                sgl_per_token_group_quant_8bit_jit(
+                    input=x,
+                    output_q=x_q,
+                    output_s=x_s,
+                    group_size=group_size,
+                    eps=eps,
+                    fp8_min=fp8_min,
+                    fp8_max=fp8_max,
+                    scale_ue8m0=scale_ue8m0,
+                )
         else:
             assert not enable_v2
             sgl_per_token_group_quant_fp8(
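The new control flow routes the 8-bit path to the v2 kernel only when enable_v2 is set, and otherwise to the tvm-ffi JIT kernel, which takes keyword arguments and omits fuse_silu_and_mul and masked_m. A standalone sketch of that dispatch with the three kernel entry points stubbed out (the stubs are placeholders for sgl_per_token_group_quant_8bit, sgl_per_token_group_quant_8bit_jit, and sgl_per_token_group_quant_fp8; the example keyword values are made up):

# Stubs standing in for the three real kernel entry points.
def quant_8bit_v2(*args, enable_v2: bool) -> None:
    print("v2 kernel (positional args), enable_v2 =", enable_v2)

def quant_8bit_jit(**kwargs) -> None:
    print("JIT kernel (keyword args):", sorted(kwargs))

def quant_fp8_fallback() -> None:
    print("fallback kernel")

def dispatch(enable_8bit: bool, enable_v2: bool) -> None:
    # Mirrors the branch structure introduced by this commit.
    if enable_8bit:
        if enable_v2:
            quant_8bit_v2(enable_v2=True)
        else:
            quant_8bit_jit(group_size=128, scale_ue8m0=False)
    else:
        assert not enable_v2  # v2 is only valid on the 8-bit path
        quant_fp8_fallback()

dispatch(enable_8bit=True, enable_v2=False)  # -> JIT kernel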
