Commit 7aa5332

Add ModelOpt MXFP8 MoE support (PR vllm-project#35986)

Apply upstream patch from vllm-project#35986 to add support for ModelOpt MXFP8 MoE models using FlashInfer 0.6.4's new TRTLLM MoE kernel.

Key changes:
- Add MXFP8 MoE backend selection in oracle/mxfp8.py
- Update modelopt quantization layer to support MXFP8 MoE
- Fix weight scale loading to distinguish block scales from per-tensor scales
- Add comprehensive unit tests for MXFP8 MoE functionality

This enables a ~1.6x throughput improvement for MoE models with MXFP8 quantization while maintaining accuracy.

Made-with: Cursor
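For readers unfamiliar with the format: MXFP8 is the OCP microscaling variant of FP8, in which each contiguous block of 32 elements stores FP8 (E4M3) values plus one shared power-of-two (E8M0) scale. A rough, purely illustrative sketch of the round trip follows; real kernels store the E8M0 exponents as packed uint8 and use swizzled scale layouts, so this is not the code path the commit adds.

```python
import torch

BLOCK = 32  # OCP MX block size: 32 elements share one scale

def mxfp8_quantize_ref(x: torch.Tensor):
    # Assumes x.numel() is divisible by BLOCK; illustration only.
    xb = x.float().reshape(-1, BLOCK)
    # Shared power-of-two scale per block, chosen so the block max fits
    # FP8 E4M3 (max representable value 448).
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
    scale = torch.pow(2.0, torch.ceil(torch.log2(amax / 448.0)))
    q = (xb / scale).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.squeeze(-1)

def mxfp8_dequantize_ref(q: torch.Tensor, scale: torch.Tensor):
    # Inverse: widen to float32 and multiply each block by its scale.
    qb = q.float().reshape(-1, BLOCK)
    return (qb * scale.reshape(-1, 1)).reshape(q.shape)
```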
1 parent e7556ac commit 7aa5332

5 files changed

Lines changed: 939 additions & 18 deletions

File tree

tests/kernels/moe/test_ocp_mx_moe.py

Lines changed: 185 additions & 2 deletions

```diff
@@ -20,6 +20,8 @@
     current_platform.is_cuda() and current_platform.is_device_capability_family(100)
 )
 
+TRTLLM_GEN_MXFP8_AVAILABLE = TRTLLM_GEN_MXFP4_AVAILABLE
+
 HOPPER_MXFP4_BF16_AVAILABLE = (
     current_platform.is_cuda()
     and current_platform.is_device_capability(90)
```
```diff
@@ -34,9 +36,15 @@
         shuffle_matrix_a,
         shuffle_matrix_sf_a,
         trtllm_fp4_block_scale_moe,
+        trtllm_fp8_block_scale_moe,
     )
     from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
+
+    if TRTLLM_GEN_MXFP8_AVAILABLE:
+        from flashinfer.fused_moe.core import (
+            Fp8QuantizationType,
+            get_w2_permute_indices_with_cache,
+        )
 
 
 @dataclass
```

```diff
@@ -160,6 +168,7 @@ def reference_moe(
     beta,
     limit,
     act_type,
+    is_gated,
 ):
     # renormalize routing
     experts = torch.topk(routing_logits, k=topk, dim=-1, sorted=True)
```
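
The reference routes with plain top-k followed by a softmax over only the selected logits ("renormalize" routing, matching `routing_method_type=1` in the kernel call further down). In isolation, that pattern looks roughly like this:

```python
import torch

logits = torch.randn(4, 32)                  # [num_tokens, num_experts]
vals, idx = torch.topk(logits, k=4, dim=-1)  # choose top-k experts per token
weights = torch.softmax(vals, dim=-1)        # renormalize over only the k picks
```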

```diff
@@ -170,7 +179,12 @@
     mlp1_weight = w13[expert_indices, ...]
     mlp1_bias = bias13[expert_indices, ...]
     t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
-    t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
+    if is_gated:
+        t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
+    else:
+        # RELU2_NO_MUL: relu(x)^2
+        t = torch.relu(t)
+        t = t * t
 
     if act_type == "mxfp8":
         t_quantized, t_scale = mxfp8_quantize(
```
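
The `swiglu` helper is defined elsewhere in the test file and not shown in this diff. As a sketch of the two activation paths the reference now supports: the clamped-SwiGLU form below is an assumption about that helper, while the `relu(x)^2` branch is verbatim from the hunk.

```python
import torch

def act_gated(t, alpha=1.0, beta=0.0, limit=None):
    # Assumed SwiGLU shape: split into gate/up halves, clamp if a limit is
    # set, then a silu-style gate times the linear branch. Details may
    # differ from the file's actual `swiglu` helper.
    gate, up = t.chunk(2, dim=-1)
    if limit is not None:
        gate = gate.clamp(max=limit)
        up = up.clamp(min=-limit, max=limit)
    return gate * torch.sigmoid(alpha * gate) * (up + beta)

def act_relu2(t):
    # RELU2_NO_MUL, as in the diff: relu(x) squared, no gating.
    t = torch.relu(t)
    return t * t
```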

```diff
@@ -569,6 +583,7 @@ def test_trtllm_gen_mxfp4_fused_moe(
             beta,
             limit,
             act_type,
+            is_gated=True,
         )
         ref_result[start_idx:end_idx].copy_(chunk_result)
```

```diff
@@ -705,6 +720,7 @@ def test_flashinfer_cutlass_mxfp4_fused_moe(
         beta,
         limit,
         "bf16",
+        is_gated=True,
     )
 
     from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
```

```diff
@@ -890,6 +906,7 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
         beta,
         limit,
         "mxfp8",
+        is_gated=True,
     )
 
     # Prepare inputs for FlashInfer CUTLASS fused MoE
```

```diff
@@ -965,3 +982,169 @@ def dequant_mxfp4_batches(mat_fp4: torch.Tensor, scale_tensor: torch.Tensor):
 
     # Allow some mismatch due to MXFP4 quantization
     check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("is_gated", [True], ids=["gated"])
+@pytest.mark.skipif(
+    not TRTLLM_GEN_MXFP8_AVAILABLE,
+    reason="nvidia gpu and compute capability sm100 is required for this test",
+)
+def test_trtllm_gen_mxfp8_block_scale_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    is_gated: bool,
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    inter_size = intermediate_size * (2 if is_gated else 1)
+
+    hidden_states = (
+        torch.randn(num_tokens, hidden_size, device=device, dtype=torch.bfloat16) / 20
+    )
+    w13 = (
+        torch.randn(
+            num_experts,
+            inter_size,
+            hidden_size,
+            device=device,
+            dtype=torch.bfloat16,
+        )
+        / 20
+    )
+    w2 = (
+        torch.randn(
+            num_experts,
+            hidden_size,
+            intermediate_size,
+            device=device,
+            dtype=torch.bfloat16,
+        )
+        / 20
+    )
+    router_logits = torch.rand(
+        num_tokens, num_experts, dtype=torch.float32, device=device
+    )
+    router_logits_kernel = router_logits.to(torch.bfloat16)
+
+    # Quantize weights to MXFP8 and normalize scales to [E, M, K//32].
+    w13_q, w13_scale = mxfp8_quantize(w13, is_sf_swizzled_layout=False)
+    w2_q, w2_scale = mxfp8_quantize(w2, is_sf_swizzled_layout=False)
+    if w13_scale.ndim == 1:
+        w13_scale = w13_scale.view(
+            num_experts,
+            inter_size,
+            hidden_size // 32,
+        )
+    if w2_scale.ndim == 1:
+        w2_scale = w2_scale.view(num_experts, hidden_size, intermediate_size // 32)
+
+    # Quantize activations to MXFP8.
+    hidden_states_q, hidden_states_scale = mxfp8_quantize(
+        hidden_states, is_sf_swizzled_layout=False
+    )
+    if hidden_states_scale.ndim == 1:
+        hidden_states_scale = hidden_states_scale.view(num_tokens, hidden_size // 32)
```
# Reference output using dequantized tensors + MXFP8 intermediate quantization.
1057+
w13_ref = mxfp8_dequantize(w13_q, w13_scale).to(torch.float32)
1058+
w2_ref = mxfp8_dequantize(w2_q, w2_scale).to(torch.float32)
1059+
hidden_states_ref = mxfp8_dequantize(hidden_states_q, hidden_states_scale).to(
1060+
torch.float32
1061+
)
1062+
bias13 = torch.zeros(
1063+
num_experts,
1064+
intermediate_size * (2 if is_gated else 1),
1065+
device=device,
1066+
)
1067+
bias2 = torch.zeros(num_experts, hidden_size, device=device)
1068+
ref = reference_moe(
1069+
router_logits_kernel.to(torch.float32),
1070+
topk,
1071+
num_experts,
1072+
hidden_states_ref,
1073+
w13_ref,
1074+
bias13,
1075+
w2_ref,
1076+
bias2,
1077+
alpha=1.0,
1078+
beta=0.0,
1079+
limit=None,
1080+
act_type="mxfp8",
1081+
is_gated=is_gated,
1082+
)
1083+
1084+
# Shuffle weights/scales with the same indexed layout used by TRTLLM kernels.
1085+
epilogue_tile_m = 128
1086+
gemm1_weights_shuffled = []
1087+
gemm1_scales_shuffled = []
1088+
gemm2_weights_shuffled = []
1089+
gemm2_scales_shuffled = []
1090+
for i in range(num_experts):
1091+
w13_rows = intermediate_size * (2 if is_gated else 1)
1092+
w13_interleaved = w13_q[i].clone().reshape(w13_rows, -1)
1093+
w13_scale_interleaved = w13_scale[i].clone().reshape(w13_rows, -1)
1094+
if is_gated:
1095+
w13_interleaved = reorder_rows_for_gated_act_gemm(w13_interleaved)
1096+
w13_scale_interleaved = reorder_rows_for_gated_act_gemm(
1097+
w13_scale_interleaved
1098+
)
1099+
gemm1_weights_shuffled.append(
1100+
shuffle_matrix_a(w13_interleaved.view(torch.uint8), epilogue_tile_m)
1101+
.contiguous()
1102+
.view(w13_q.dtype)
1103+
)
1104+
gemm2_weights_shuffled.append(
1105+
shuffle_matrix_a(w2_q[i].view(torch.uint8), epilogue_tile_m)
1106+
.contiguous()
1107+
.view(w2_q.dtype)
1108+
)
1109+
1110+
gemm1_scales_shuffled.append(
1111+
shuffle_matrix_sf_a(
1112+
w13_scale_interleaved.view(torch.uint8).reshape(w13_rows, -1),
1113+
epilogue_tile_m,
1114+
)
1115+
.contiguous()
1116+
.view(w13_scale.dtype)
1117+
)
1118+
gemm2_scales_shuffled.append(
1119+
shuffle_matrix_sf_a(
1120+
w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), epilogue_tile_m
1121+
)
1122+
.contiguous()
1123+
.view(w2_scale.dtype)
1124+
)
1125+
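
`reorder_rows_for_gated_act_gemm` comes from FlashInfer; a minimal sketch of the gate/up row interleave it is used for here follows. The exact permutation is FlashInfer's, so treat this as an assumption about its effect rather than its implementation.

```python
import torch

def interleave_gate_up(w13: torch.Tensor) -> torch.Tensor:
    # Assumed effect: rows [g0..gN-1, u0..uN-1] -> [g0, u0, g1, u1, ...],
    # pairing each gate row with its up row for the fused gated-act GEMM.
    gate, up = w13.chunk(2, dim=0)
    return torch.stack([gate, up], dim=1).reshape(-1, w13.shape[-1])
```

The hunk then finishes with the kernel invocation and the accuracy check: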

```diff
+    out = trtllm_fp8_block_scale_moe(
+        routing_logits=router_logits_kernel,
+        routing_bias=None,
+        hidden_states=hidden_states_q,
+        hidden_states_scale=hidden_states_scale,
+        gemm1_weights=torch.stack(gemm1_weights_shuffled),
+        gemm1_weights_scale=torch.stack(gemm1_scales_shuffled),
+        gemm2_weights=torch.stack(gemm2_weights_shuffled),
+        gemm2_weights_scale=torch.stack(gemm2_scales_shuffled),
+        num_experts=num_experts,
+        top_k=topk,
+        n_group=None,
+        topk_group=None,
+        intermediate_size=intermediate_size,
+        local_expert_offset=0,
+        local_num_experts=num_experts,
+        routed_scaling_factor=None,
+        routing_method_type=1,  # renormalize routing
+        use_shuffled_weight=True,
+        weight_layout=0,  # MajorK
+        fp8_quantization_type=Fp8QuantizationType.MxFp8,
+    )
+
+    # Block-scale MXFP8 kernels are approximate; require majority close.
+    check_accuracy(ref, out, atol=0.1, rtol=0.85, percent=0.8)
```
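
`check_accuracy` is the test file's own helper and is not shown in this diff; its signature suggests a pass criterion of "at least `percent` of elements within `atol + rtol * |ref|`". A sketch under that assumption:

```python
import torch

def check_accuracy_ref(ref, out, atol, rtol, percent):
    # Assumed semantics: an element passes if |ref - out| <= atol + rtol * |ref|;
    # the check passes if at least `percent` of elements pass.
    ok = (ref.float() - out.float()).abs() <= atol + rtol * ref.float().abs()
    frac = ok.float().mean().item()
    assert frac >= percent, f"only {frac:.2%} of elements within tolerance"
```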
