Commit 4bffd3a

[GPT-OSS] support fp8 online quantization for gpt-oss bf16 (#18988)
Merging as all required CI checks passed.
1 parent 96bae23 commit 4bffd3a
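
The commit enables fp8 online quantization for the bf16 GPT-OSS checkpoints, i.e. weights are quantized to fp8 at load time rather than requiring an fp8-serialized checkpoint. As a rough illustration of the idea only (not the code in this diff, which is block/per-channel aware), a minimal per-tensor sketch, assuming a PyTorch build that provides torch.float8_e4m3fn:

import torch

# Minimal sketch: quantize a bf16 weight to fp8 with a single scale at load time.
def quantize_fp8_per_tensor(w_bf16: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = w_bf16.abs().max().clamp(min=1e-12).float()
    scale = amax / finfo.max
    w_fp8 = (w_bf16.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_fp8, scale  # dequantize as w_fp8.float() * scale

w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fp8, scale = quantize_fp8_per_tensor(w)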

2 files changed: 31 additions, 1 deletion

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 26 additions & 0 deletions
@@ -677,6 +677,7 @@ def __init__(self, quant_config: Fp8Config):
         self.block_quant = (
             self.use_mxfp8 or self.quant_config.weight_block_size is not None
         )
+        self.with_bias = False
         if get_moe_runner_backend().is_cutlass():
             assert (
                 cutlass_fp8_supported()
@@ -706,8 +707,10 @@ def create_weights(
         hidden_size: int,
         intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
@@ -782,6 +785,27 @@ def create_weights(
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        # BIAS (optional, e.g. GPT-OSS)
+        if self.with_bias:
+            w13_up_dim = (
+                2 * intermediate_size_per_partition
+                if layer.moe_runner_config.is_gated
+                else intermediate_size_per_partition
+            )
+            w13_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, w13_up_dim, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_bias", w13_weight_bias)
+            set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
+            w2_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_bias", w2_weight_bias)
+            set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
         # WEIGHT_SCALES
         if self.block_quant:
             scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32
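
A shape sketch of the two bias parameters created by this hunk, with made-up placeholder sizes (not GPT-OSS config values): the w13 bias doubles when the MLP is gated because gate and up projections are packed into one tensor, and both biases stay in float32.

import torch

num_experts = 8
hidden_size = 1024
intermediate_size_per_partition = 2048
is_gated = True  # a gated MLP packs gate and up projections into w13

w13_up_dim = (
    2 * intermediate_size_per_partition if is_gated else intermediate_size_per_partition
)
w13_weight_bias = torch.zeros(num_experts, w13_up_dim, dtype=torch.float32)  # (8, 4096)
w2_weight_bias = torch.zeros(num_experts, hidden_size, dtype=torch.float32)  # (8, 1024)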
@@ -1507,6 +1531,8 @@ def apply(
         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
             w2_weight=layer.w2_weight,
+            b13=getattr(layer, "w13_weight_bias", None),
+            b2=getattr(layer, "w2_weight_bias", None),
             use_fp8_w8a8=True,
             w13_scale=(
                 layer.w13_weight_scale_inv
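
The getattr calls keep bias-free models unaffected: b13/b2 are simply None when the parameters were never registered. As reference math only (not the Triton kernel, whose internals are not shown in this diff), a sketch of how a per-expert bias folds into an fp8 w8a8 GEMM once it reaches the kernel:

import torch

def expert_gemm_w8a8(x_fp8, w_fp8, x_scale, w_scale, bias=None):
    # The fp8 product is rescaled by the activation and weight scales,
    # then the optional float32 per-expert bias (e.g. GPT-OSS) is added.
    y = (x_fp8.float() @ w_fp8.float().t()) * (x_scale * w_scale)
    if bias is not None:
        y = y + bias
    return y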

python/sglang/srt/server_args.py

Lines changed: 5 additions & 1 deletion
@@ -1390,7 +1390,11 @@ def _handle_model_specific_adjustments(self):
             logger.warning(
                 "Detected ROCm with SGLANG_USE_AITER for GPT-OSS bf16 model, using triton MOE kernel."
            )
-        elif self.ep_size == 1 and is_triton_kernels_available():
+        elif (
+            self.ep_size == 1
+            and is_triton_kernels_available()
+            and self.quantization is None
+        ):
            self.moe_runner_backend = "triton_kernel"
            logger.warning(
                "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
