Skip to content

Commit bf47320

Browse files
elvischenv authored and mtparet committed
[Perf] Eliminate padding and slicing op for GPT-OSS with Flashinfer MXFP4 MXFP8 MoE (vllm-project#30647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
1 parent 0f3d3e3 commit bf47320

6 files changed

Lines changed: 40 additions & 3 deletions

File tree

tests/compile/fusions_e2e/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def run(
8282
f"attention backend '{attn_backend.backend.name}'"
8383
)
8484

85+
# TODO: remove this after finishing migration from envs to model kwargs
86+
if model_name == "openai/gpt-oss-20b":
87+
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
88+
8589
# Disable compile cache to make sure custom passes run.
8690
# Otherwise, we can't verify fusion happened through the logs.
8791
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

tests/compile/fusions_e2e/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,12 @@
162162
# async_tp=n_layers * 2,
163163
),
164164
)
165+
166+
gpt_oss_20b = ModelFusionInfo(
167+
model_name="openai/gpt-oss-20b",
168+
matches=lambda n_layers: Matches(
169+
ar_rms_fusion=n_layers * 2 + 1,
170+
sequence_parallel=n_layers * 2 + 1,
171+
async_tp=n_layers * 2,
172+
),
173+
)

tests/compile/fusions_e2e/test_tp2_ar_rms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
FLASHINFER_MLA_ATTN,
2121
TRITON_ATTN,
2222
deepseek_v3_fp8,
23+
gpt_oss_20b,
2324
llama3_8b,
2425
llama3_8b_fp4,
2526
llama3_8b_fp8,
@@ -158,7 +159,7 @@ def test_tp2_ar_rms_fp4_fusions(
158159
@multi_gpu_test(num_gpus=2)
159160
@pytest.mark.parametrize(
160161
"model_name, matches_fn, model_kwargs, hf_overrides",
161-
[llama3_8b, qwen3_a3b],
162+
[llama3_8b, qwen3_a3b, gpt_oss_20b],
162163
)
163164
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN])
164165
@pytest.mark.parametrize("n_layers", [4])

vllm/model_executor/layers/fused_moe/fused_moe_method_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ def topk_indices_dtype(self) -> torch.dtype | None:
101101
return self.moe_kernel.prepare_finalize.topk_indices_dtype()
102102
return None
103103

104+
@property
105+
def skip_forward_padding(self) -> bool:
106+
"""Whether to skip the padding in the forward before applying the moe method."""
107+
return False
108+
104109
@property
105110
def supports_eplb(self) -> bool:
106111
return False

vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,10 @@ def forward(
415415

416416
# This is the dimension after transform (for routed expert output slicing)
417417
transformed_hidden_dim = hidden_states.shape[-1]
418-
if self.moe_config.hidden_dim != transformed_hidden_dim:
418+
if (
419+
not self.quant_method.skip_forward_padding
420+
and self.moe_config.hidden_dim != transformed_hidden_dim
421+
):
419422
hidden_states = F.pad(
420423
hidden_states,
421424
(0, self.moe_config.hidden_dim - transformed_hidden_dim),

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def __init__(self, moe: FusedMoEConfig):
294294
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
295295
self.moe_kernel: mk.FusedMoEKernel | None = None
296296

297+
@property
298+
def skip_forward_padding(self) -> bool:
299+
# SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant
300+
# so can skip the padding in the forward before applying the moe method
301+
return self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
302+
297303
def create_weights(
298304
self,
299305
layer: torch.nn.Module,
@@ -1130,9 +1136,17 @@ def apply_monolithic(
11301136
elif self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM:
11311137
from flashinfer import mxfp8_quantize
11321138

1133-
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
1139+
# x_quant is padded in hidden dimension with alignment=256
1140+
x_quant, x_scale = mxfp8_quantize(
1141+
x,
1142+
is_sf_swizzled_layout=False,
1143+
alignment=256,
1144+
)
11341145
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
11351146

1147+
# output with original unpadded hidden size
1148+
output = torch.empty_like(x)
1149+
11361150
trtllm_gen_output = trtllm_fp4_block_scale_moe(
11371151
routing_logits=router_logits.to(torch.bfloat16),
11381152
routing_bias=None,
@@ -1161,6 +1175,7 @@ def apply_monolithic(
11611175
routing_method_type=1 if layer.renormalize else 0,
11621176
do_finalize=True,
11631177
tune_max_num_tokens=max(self.max_capture_size, 1),
1178+
output=output,
11641179
)[0]
11651180
return trtllm_gen_output
11661181
elif self.mxfp4_backend == Mxfp4Backend.CK:

0 commit comments

Comments (0)