@@ -2748,6 +2748,10 @@ def run_moe_test(
         permute_info, scores = routing_reference_sigmoid_renorm(
             expert_logits, top_k, num_experts, padding, norm_topk_prob=norm_topk_prob
         )
+    elif routing_method_type == RoutingMethodType.Sigmoid:
+        permute_info, scores = routing_reference_sigmoid_renorm(
+            expert_logits, top_k, num_experts, padding, norm_topk_prob=False
+        )
     elif routing_method_type == RoutingMethodType.Llama4:
         permute_info, scores = routing_reference_no_aux(
             expert_logits,
@@ -3242,6 +3246,119 @@ def test_sigmoid_renorm_routing(
     )


+# Test: Sigmoid routing (Sigmoid -> TopK, no renormalization)
+@pytest.mark.parametrize("num_tokens", [8, 768, 3072])
+@pytest.mark.parametrize("hidden_size", [1024])
+@pytest.mark.parametrize("intermediate_size", [1024, 768, 512, 384])
+@pytest.mark.parametrize(
+    "moe_impl",
+    [
+        pytest.param(BF16Moe(), id="BF16xBF16"),
+        pytest.param(
+            FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK),
+            id="FP8_Block_DeepSeek",
+        ),
+        pytest.param(
+            FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8),
+            id="FP8_Block_MxFp8",
+        ),
+        pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"),
+        pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"),
+        pytest.param(MxInt4BlockScaleMoe(), id="MxInt4xBf16"),
+    ],
+)
+@pytest.mark.parametrize(
+    "routing_config",
+    [
+        pytest.param(
+            {
+                "num_experts": 128,
+                "top_k": 8,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.Sigmoid,
+                "compatible_moe_impls": [
+                    FP8PerTensorMoe,
+                    FP8BlockScaleMoe,
+                    FP4Moe,
+                    BF16Moe,
+                    MxInt4BlockScaleMoe,
+                ],
+                "compatible_intermediate_size": [384, 768, 1024],
+                "enable_autotune": True,
+            },
+            id="Sigmoid_128e_top8",
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "weight_processing",
+    [
+        pytest.param(
+            {
+                "use_shuffled_weight": False,
+                "layout": WeightLayout.MajorK,
+                "compatible_moe_impls": [FP8BlockScaleMoe],
+            },
+            id="NoShuffle_MajorK",
+        ),
+        pytest.param(
+            {
+                "use_shuffled_weight": True,
+                "layout": WeightLayout.MajorK,
+                "compatible_moe_impls": [FP4Moe, FP8PerTensorMoe, FP8BlockScaleMoe],
+            },
+            id="Shuffled_MajorK",
+        ),
+        pytest.param(
+            {
+                "use_shuffled_weight": True,
+                "layout": WeightLayout.BlockMajorK,
+                "compatible_moe_impls": [
+                    FP8BlockScaleMoe,
+                    BF16Moe,
+                    MxInt4BlockScaleMoe,
+                ],
+            },
+            id="Shuffled_BlockMajorK",
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "activation_type",
+    [
+        pytest.param(ActivationType.Swiglu.value, id="Swiglu"),
+        pytest.param(ActivationType.Geglu.value, id="Geglu"),
+    ],
+)
+def test_sigmoid_routing(
+    num_tokens,
+    hidden_size,
+    intermediate_size,
+    moe_impl,
+    routing_config,
+    weight_processing,
+    activation_type,
+    cache_permute_indices,
+):
+    """Test Sigmoid routing configurations (Sigmoid -> TopK, no renormalization)."""
+    run_moe_test(
+        num_tokens,
+        hidden_size,
+        intermediate_size,
+        moe_impl,
+        routing_config,
+        weight_processing,
+        activation_type,
+        cache_permute_indices,
+    )
+
+
 # Test: DeepSeekV3 routing
 @pytest.mark.parametrize("num_tokens", [8, 768, 3072])
 @pytest.mark.parametrize("hidden_size", [1024])
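
For context on what the new `RoutingMethodType.Sigmoid` branch exercises: a minimal sketch, assuming that calling `routing_reference_sigmoid_renorm` with `norm_topk_prob=False` reduces to sigmoid-then-top-k with the raw sigmoid values kept as expert scores (no renormalization over the selected k). The function name and shapes below are illustrative only, not the test file's actual reference helper.

```python
import torch


def sigmoid_topk_routing_sketch(expert_logits: torch.Tensor, top_k: int):
    # expert_logits: [num_tokens, num_experts]
    scores = torch.sigmoid(expert_logits.float())
    topk_scores, topk_indices = torch.topk(scores, top_k, dim=-1)
    # No renormalization: the selected sigmoid scores are returned as-is.
    # The renorm variant would instead divide by
    # topk_scores.sum(dim=-1, keepdim=True) before returning.
    return topk_scores, topk_indices


# Example usage with shapes matching the smallest test parameters above
# (8 tokens, 128 experts, top_k=8).
logits = torch.randn(8, 128)
vals, idx = sigmoid_topk_routing_sketch(logits, top_k=8)
```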