@@ -20,6 +20,8 @@
     current_platform.is_cuda() and current_platform.is_device_capability_family(100)
 )
 
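+# MXFP8 runs through the same TRTLLM-gen kernels as MXFP4, so it shares the
+# availability gate (CUDA device in the SM100 compute-capability family).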
+TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE
+
 HOPPER_MXFP4_BF16_AVAILABLE = (
     current_platform.is_cuda()
     and current_platform.is_device_capability(90)
@@ -34,9 +36,15 @@
         shuffle_matrix_a,
         shuffle_matrix_sf_a,
         trtllm_fp4_block_scale_moe,
+        trtllm_fp8_block_scale_moe,
     )
     from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
+
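+# Import the FP8-specific symbols only when the TRTLLM-gen FP8 path is available.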
+if TRTLLM_GEN_MXFP8_AVAILABLE:
+    from flashinfer.fused_moe.core import (
+        Fp8QuantizationType,
+        get_w2_permute_indices_with_cache,
+    )
 
 
 @dataclass
@@ -160,6 +168,7 @@ def reference_moe(
     beta,
     limit,
     act_type,
+    is_gated,
 ):
     # renormalize routing
     experts = torch.topk(routing_logits, k=topk, dim=-1, sorted=True)
@@ -170,7 +179,12 @@ def reference_moe(
     mlp1_weight = w13[expert_indices, ...]
     mlp1_bias = bias13[expert_indices, ...]
     t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
-    t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
+    if is_gated:
+        t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
+    else:
+        # RELU2_NO_MUL: relu(x)^2
+        t = torch.relu(t)
+        t = t * t
 
     if act_type == "mxfp8":
         t_quantized, t_scale = mxfp8_quantize(
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
             beta,
             limit,
             act_type,
+            is_gated=True,
         )
         ref_result[start_idx:end_idx].copy_(chunk_result)
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
         beta,
         limit,
         "bf16",
+        is_gated=True,
     )
 
     from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
@@ -890,6 +906,7 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
         beta,
         limit,
         "mxfp8",
+        is_gated=True,
     )
 
     # Prepare inputs for FlashInfer CUTLASS fused MoE
@@ -965,3 +982,169 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
 
     # Allow some mismatch due to MXFP4 quantization
     check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("is_gated", [True], ids=["gated"])
+@pytest.mark.skipif(
+    not TRTLLM_GEN_MXFP8_AVAILABLE,
+    reason="NVIDIA GPU with compute capability SM100 is required for this test",
+)
+def test_trtllm_gen_mxfp8_block_scale_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    is_gated: bool,
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
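+    # Gated MLPs stack the gate and up projections, so w13 carries twice the rows.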
+    inter_size = intermediate_size * (2 if is_gated else 1)
+
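+    # Scale random inputs down so values stay small and MXFP8 quantization
+    # error remains modest.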
+    hidden_states = (
+        torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20
+    )
+    w13 = (
+        torch.randn(
+            num_experts,
+            inter_size,
+            hidden_size,
+            device=device,
+            dtype=torch.bfloat16,
+        )
+        / 20
+    )
+    w2 = (
+        torch.randn(
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            device=device,
+            dtype=torch.bfloat16,
+        )
+        / 20
+    )
+    router_logits = torch.rand(
+        num_tokens, num_experts, dtype=torch.float32, device=device
+    )
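+    # The kernel takes bfloat16 routing logits; the reference path casts them
+    # back to float32 below.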
+    router_logits_kernel = router_logits.to(torch.bfloat16)
+
+    # Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
+    w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False)
+    w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False)
+    if w13_scale.ndim == 1:
+        w13_scale = w13_scale.view(
+            num_experts,
+            inter_size,
+            hidden_size // 32,
+        )
+    if w2_scale.ndim == 1:
+        w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32)
+
+    # Quantize activations to MXFP8.
+    hidden_states_q, hidden_states_scale = mxfp8_quantize(
+        hidden_states, is_sf_swizzled_layout=False
+    )
+    if hidden_states_scale.ndim == 1:
+        hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32)
+
+    # Reference output using dequantized tensors + MXFP8 intermediate quantization.
+    w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32)
+    w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32)
+    hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to(
+        torch.float32
+    )
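+    # The kernel path takes no bias inputs, so the reference runs with zero biases.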
+    bias13 = torch.zeros(
+        num_experts,
+        intermediate_size * (2 if is_gated else 1),
+        device=device,
+    )
+    bias2 = torch.zeros(num_experts, hidden_size, device=device)
+    ref = reference_moe(
+        router_logits_kernel.to(torch.float32),
+        topk,
+        num_experts,
+        hidden_states_ref,
+        w13_ref,
+        bias13,
+        w2_ref,
+        bias2,
+        alpha=1.0,
+        beta=0.0,
+        limit=None,
+        act_type="mxfp8",
+        is_gated=is_gated,
+    )
+
+    # Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
+    epilogue_tile_m = 128
+    gemm1_weights_shuffled = []
+    gemm1_scales_shuffled = []
+    gemm2_weights_shuffled = []
+    gemm2_scales_shuffled = []
+    for i in range(num_experts):
+        w13_rows = intermediate_size * (2 if is_gated else 1)
+        w13_interleaved = w13_q[i].clone().reshape(w13_rows, -1)
+        w13_scale_interleaved = w13_scale[i].clone().reshape(w13_rows, -1)
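+        # For gated activations, interleave the gate and up rows into the layout
+        # expected by the fused gated-activation GEMM.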
+        if is_gated:
+            w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved)
+            w13_scale_interleaved = reorder_rows_for_gated_act_gemm(
+                w13_scale_interleaved
+            )
+        gemm1_weights_shuffled.append(
+            shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m)
+            .contiguous()
+            .view(w13_q.dtype)
+        )
+        gemm2_weights_shuffled.append(
+            shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m)
+            .contiguous()
+            .view(w2_q.dtype)
+        )
+
+        gemm1_scales_shuffled.append(
+            shuffle_matrix_sf_a(
+                w13_scale_interleaved.view(torch.uint8).reshape(w13_rows, -1),
+                epilogue_tile_m,
+            )
+            .contiguous()
+            .view(w13_scale.dtype)
+        )
+        gemm2_scales_shuffled.append(
+            shuffle_matrix_sf_a(
+                w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m
+            )
+            .contiguous()
+            .view(w2_scale.dtype)
+        )
+
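+    # Run the TRTLLM-gen block-scale MoE kernel in MXFP8 mode with renormalize
+    # routing, mirroring the reference computation above.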
+    out = trtllm_fp8_block_scale_moe(
+        routing_logits=router_logits_kernel,
+        routing_bias=None,
+        hidden_states=hidden_states_q,
+        hidden_states_scale=hidden_states_scale,
+        gemm1_weights=torch.stack(gemm1_weights_shuffled),
+        gemm1_weights_scale=torch.stack(gemm1_scales_shuffled),
+        gemm2_weights=torch.stack(gemm2_weights_shuffled),
+        gemm2_weights_scale=torch.stack(gemm2_scales_shuffled),
+        num_experts=num_experts,
+        top_k=topk,
+        n_group=None,
+        topk_group=None,
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,
+        local_num_experts=num_experts,
+        routed_scaling_factor=None,
+        routing_method_type=1,  # renormalize routing
+        use_shuffled_weight=True,
+        weight_layout=0,  # MajorK
+        fp8_quantization_type=Fp8QuantizationType.MxFp8,
+    )
+
+    # Block-scale MXFP8 kernels are approximate; require majority close.
+    check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8)