4 files changed: +34, -39 lines

File 1 of 4:

mode_concentrated = IS_CI or (os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated")

-if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
-    configs = [
-        [
-            768 * 8,
-            2048,
-            128,
-            48,
-            fp8_type_,
-            dict(
-                column_major_scales=True,
-                scale_tma_aligned=True,
-                scale_ue8m0=True,
-                fuse_silu_and_mul=True,
-                # masked_layout_mode=None,
-                masked_layout_mode="balanced",
-                # masked_layout_mode="extreme",
-            ),
-        ]
-    ]
-elif mode_concentrated:
+if mode_concentrated:
    configs = list(
        itertools.product(
            [768],
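
For reference, a minimal sketch (not part of the diff) of how the surviving itertools.product branch expands into benchmark configs; only the [768] axis is visible in the hunk above, so the other axes and values below are assumptions for illustration:

import itertools

num_tokens = [768]          # visible in the hunk above
hidden_dims = [2048, 4096]  # assumed axis, for illustration only
group_sizes = [128]         # assumed axis, for illustration only

# Each combination becomes one benchmark configuration.
configs = list(itertools.product(num_tokens, hidden_dims, group_sizes))
# -> [(768, 2048, 128), (768, 4096, 128)]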

File 2 of 4:

if TYPE_CHECKING:
    from tvm_ffi.module import Module

-_OUTPUT_DTYPE_MAP = {
-    torch.float8_e4m3fn: "fp8_e4m3_t",
-    torch.int8: "int8_t",
-}
+from sglang.jit_kernel.utils import CPP_DTYPE_MAP as OUTPUT_DTYPE_MAP


@cache_once
def _jit_per_token_group_quant_8bit_module(
    dtype: torch.dtype, output_type: torch.dtype
) -> Module:
    input_args = make_cpp_args(dtype)
-    out_cpp = _OUTPUT_DTYPE_MAP[output_type]
+    out_cpp = OUTPUT_DTYPE_MAP[output_type]
    return load_jit(
        "per_token_group_quant_8bit",
        cuda_files=["gemm/per_token_group_quant_8bit.cuh"],
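
Usage note, as a minimal sketch assuming an sglang checkout with this change applied: the module-local _OUTPUT_DTYPE_MAP is gone, and both quantization output dtypes now resolve through the shared table imported above.

import torch
from sglang.jit_kernel.utils import CPP_DTYPE_MAP as OUTPUT_DTYPE_MAP  # import as in the hunk above

out_cpp = OUTPUT_DTYPE_MAP[torch.float8_e4m3fn]  # "fp8_e4m3_t"
int8_cpp = OUTPUT_DTYPE_MAP[torch.int8]          # "int8_t"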

File 3 of 4:

if TYPE_CHECKING:
    from tvm_ffi import Module

-
F = TypeVar("F", bound=Callable[..., Any])


@@ -73,7 +72,9 @@ def __str__(self) -> str:
CPP_DTYPE_MAP = {
    torch.float: "fp32_t",
    torch.float16: "fp16_t",
+    torch.float8_e4m3fn: "fp8_e4m3_t",
    torch.bfloat16: "bf16_t",
+    torch.int8: "int8_t",
}

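For illustration, a rough self-contained sketch of why the two new entries matter to the JIT module builder in sglang.jit_kernel.per_token_group_quant_8bit: the cached factory resolves the C++ output type through this shared map, so an unsupported output dtype now fails at lookup time. Here cache_once is approximated with functools.lru_cache and the JIT loader is stubbed out; both are assumptions of this sketch, not the real APIs.

import functools

import torch

# Subset of CPP_DTYPE_MAP relevant to quantization outputs (see the hunk above).
OUTPUT_CPP_NAMES = {
    torch.float8_e4m3fn: "fp8_e4m3_t",
    torch.int8: "int8_t",
}

def load_jit_stub(kernel_name: str, out_cpp: str) -> str:
    # Stand-in for the real JIT loader; returns a tag instead of a compiled module.
    return f"{kernel_name}<{out_cpp}>"

@functools.lru_cache(maxsize=None)  # rough stand-in for @cache_once
def quant_module(output_type: torch.dtype) -> str:
    out_cpp = OUTPUT_CPP_NAMES[output_type]  # KeyError for unsupported output dtypes
    return load_jit_stub("per_token_group_quant_8bit", out_cpp)

print(quant_module(torch.float8_e4m3fn))  # built once, then served from the cache
print(quant_module(torch.int8))           # separate specialization per output dtype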

File 4 of 4:

    enable_sgl_per_token_group_quant_8bit = False

+from sglang.jit_kernel.per_token_group_quant_8bit import (
+    per_token_group_quant_8bit as sgl_per_token_group_quant_8bit_jit,
+)
+
if _is_hip:
    _has_vllm = False
    if _use_aiter:
@@ -501,19 +505,31 @@ def sglang_per_token_group_quant_fp8(
    if x.shape[0] > 0:
        # Temporary
        if enable_sgl_per_token_group_quant_8bit:
-            sgl_per_token_group_quant_8bit(
-                x,
-                x_q,
-                x_s,
-                group_size,
-                eps,
-                fp8_min,
-                fp8_max,
-                scale_ue8m0,
-                fuse_silu_and_mul,
-                masked_m,
-                enable_v2=enable_v2,
-            )
+            if enable_v2:
+                sgl_per_token_group_quant_8bit(
+                    x,
+                    x_q,
+                    x_s,
+                    group_size,
+                    eps,
+                    fp8_min,
+                    fp8_max,
+                    scale_ue8m0,
+                    fuse_silu_and_mul,
+                    masked_m,
+                    enable_v2=True,
+                )
+            else:
+                sgl_per_token_group_quant_8bit_jit(
+                    input=x,
+                    output_q=x_q,
+                    output_s=x_s,
+                    group_size=group_size,
+                    eps=eps,
+                    fp8_min=fp8_min,
+                    fp8_max=fp8_max,
+                    scale_ue8m0=scale_ue8m0,
+                )
        else:
            assert not enable_v2
            sgl_per_token_group_quant_fp8(
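
To make the new control flow easier to scan, here is a condensed self-contained sketch of the dispatch this hunk introduces, with both kernels replaced by stubs; the keyword names on the JIT path are taken from the hunk, while the stub bodies and sample arguments are schematic assumptions, not the real signatures.

# Stubs standing in for the two kernels, only to show the shape of the dispatch.
def sgl_per_token_group_quant_8bit(*args, enable_v2):
    return f"v2 kernel path (enable_v2={enable_v2})"

def sgl_per_token_group_quant_8bit_jit(**kwargs):
    return "JIT kernel path: " + ", ".join(sorted(kwargs))

def quantize_group_8bit(x, x_q, x_s, group_size, eps, fp8_min, fp8_max,
                        scale_ue8m0, fuse_silu_and_mul, masked_m, enable_v2):
    if enable_v2:
        # v2 keeps the original positional call, including fuse_silu_and_mul and masked_m.
        return sgl_per_token_group_quant_8bit(
            x, x_q, x_s, group_size, eps, fp8_min, fp8_max,
            scale_ue8m0, fuse_silu_and_mul, masked_m, enable_v2=True,
        )
    # Non-v2 now routes to the JIT kernel, which is called with keyword arguments
    # and (per the hunk) without fuse_silu_and_mul / masked_m.
    return sgl_per_token_group_quant_8bit_jit(
        input=x, output_q=x_q, output_s=x_s, group_size=group_size,
        eps=eps, fp8_min=fp8_min, fp8_max=fp8_max, scale_ue8m0=scale_ue8m0,
    )

print(quantize_group_8bit("x", "x_q", "x_s", 128, 1e-10, -448.0, 448.0,
                          False, False, None, enable_v2=False))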