@@ -677,6 +677,7 @@ def __init__(self, quant_config: Fp8Config):
         self.block_quant = (
             self.use_mxfp8 or self.quant_config.weight_block_size is not None
         )
+        self.with_bias = False
         if get_moe_runner_backend().is_cutlass():
             assert (
                 cutlass_fp8_supported()
@@ -706,8 +707,10 @@ def create_weights(
         hidden_size: int,
         intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
+        self.with_bias = with_bias
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
@@ -782,6 +785,27 @@ def create_weights(
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        # BIAS (optional, e.g. GPT-OSS)
+        if self.with_bias:
+            w13_up_dim = (
+                2 * intermediate_size_per_partition
+                if layer.moe_runner_config.is_gated
+                else intermediate_size_per_partition
+            )
+            w13_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, w13_up_dim, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_bias", w13_weight_bias)
+            set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
+            w2_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_bias", w2_weight_bias)
+            set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
         # WEIGHT_SCALES
         if self.block_quant:
             scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32
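The hunk above sizes the w13 bias differently for gated and non-gated experts: a gated MoE fuses the gate and up projections into w13, so its bias spans twice the per-partition intermediate size, while w2 projects back to hidden_size. A minimal standalone sketch of the resulting shapes in plain torch; the sizes and the is_gated flag are illustrative stand-ins, not values from the source:

import torch

# Illustrative sizes only; real values come from the model config.
num_experts = 8
hidden_size = 1024
intermediate_size_per_partition = 2048
is_gated = True  # stands in for layer.moe_runner_config.is_gated

w13_up_dim = (
    2 * intermediate_size_per_partition if is_gated
    else intermediate_size_per_partition
)
w13_weight_bias = torch.nn.Parameter(
    torch.empty(num_experts, w13_up_dim, dtype=torch.float32),
    requires_grad=False,
)
w2_weight_bias = torch.nn.Parameter(
    torch.empty(num_experts, hidden_size, dtype=torch.float32),
    requires_grad=False,
)
print(w13_weight_bias.shape)  # torch.Size([8, 4096])
print(w2_weight_bias.shape)   # torch.Size([8, 1024])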
@@ -1507,6 +1531,8 @@ def apply(
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
                 w2_weight=layer.w2_weight,
+                b13=getattr(layer, "w13_weight_bias", None),
+                b2=getattr(layer, "w2_weight_bias", None),
                 use_fp8_w8a8=True,
                 w13_scale=(
                     layer.w13_weight_scale_inv
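In apply(), the biases are fetched with getattr defaults, so a layer created with with_bias=False never registered the parameters and the lookup falls back to None, letting the Triton runner skip the bias add. A short sketch of that pattern in plain torch; the _Layer class is a hypothetical stand-in for the real MoE layer module:

import torch

class _Layer(torch.nn.Module):  # hypothetical stand-in for the MoE layer
    pass

layer = _Layer()
# Nothing registered yet, so the defensive lookup returns None.
print(getattr(layer, "w13_weight_bias", None))  # None

layer.register_parameter(
    "w13_weight_bias",
    torch.nn.Parameter(torch.zeros(8, 4096), requires_grad=False),
)
print(getattr(layer, "w13_weight_bias", None).shape)  # torch.Size([8, 4096])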