Skip to content

Commit e4f04cd

Browse files
author
Shy Huang
committed
Replace flaky EP relative comparison with hardcoded absolute baseline
Signed-off-by: Shy Huang <shyhuang@google.com>
1 parent 799d14f commit e4f04cd

File tree

1 file changed

+66
-79
lines changed

1 file changed

+66
-79
lines changed

tests/e2e/test_expert_parallel.py

Lines changed: 66 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -8,38 +8,36 @@
88
import pytest
99
from vllm import LLM, EngineArgs, SamplingParams
1010

11+
# Hardcoded baseline times (seconds) for 512 prompts on TPU v7x-8.
12+
# Measured on 2026-03-20 with vllm LKG aa84e43c + tpu-inference main.
13+
# Tests pass if EP time exceeds its baseline by no more than REGRESSION_THRESHOLD.
14+
EP_FUSED_BASELINE_TIME = 3.40
15+
EP_GMM_BASELINE_TIME = 2.07
16+
REGRESSION_THRESHOLD = 0.15 # 15% regression tolerance
17+
1118

1219
@dataclass
1320
class TestConfig:
1421
"""Configuration for EP test runs."""
1522
max_model_len: int = 512
16-
max_num_batched_tokens: int = 128
17-
max_num_seqs: int = 16
18-
num_prompts: int = 16
19-
20-
@classmethod
21-
def for_performance(cls) -> "TestConfig":
22-
return cls(
23-
max_model_len=512,
24-
max_num_batched_tokens=512,
25-
max_num_seqs=512,
26-
num_prompts=512,
27-
)
23+
max_num_batched_tokens: int = 512
24+
max_num_seqs: int = 512
25+
num_prompts: int = 512
2826

2927

3028
@dataclass
3129
class InferenceConfig:
3230
"""Configuration for a single inference run."""
3331
model_name: str
34-
tensor_parallel_size: int = 1
35-
enable_expert_parallel: bool = False
32+
tensor_parallel_size: int = 4
33+
enable_expert_parallel: bool = True
3634
max_model_len: int = 512
37-
max_num_batched_tokens: int = 128
38-
max_num_seqs: int = 16
35+
max_num_batched_tokens: int = 512
36+
max_num_seqs: int = 512
3937
gpu_memory_utilization: float = 0.95
4038

4139

42-
def generate_test_prompts(num_prompts: int = 256) -> list[str]:
40+
def generate_test_prompts(num_prompts: int = 512) -> list[str]:
4341
base_text = (
4442
"The rapid advancement of artificial intelligence has transformed "
4543
"numerous industries and continues to reshape our understanding of "
@@ -87,6 +85,9 @@ def _run_inference(
8785
engine_args_dict = asdict(engine_args)
8886
llm = LLM(**engine_args_dict)
8987

88+
# Warmup
89+
llm.generate(test_prompts[:8], sampling_params)
90+
9091
start_time = time.time()
9192
outputs = llm.generate(test_prompts, sampling_params)
9293
elapsed_time = time.time() - start_time
@@ -96,92 +97,63 @@ def _run_inference(
9697
return outputs, elapsed_time
9798

9899

99-
def _check_performance(
100+
def _check_no_regression(
100101
test_name: str,
101102
baseline_time: float,
102-
ep_time: float,
103+
actual_time: float,
103104
num_prompts: int,
104-
min_speedup: float,
105+
threshold: float = REGRESSION_THRESHOLD,
105106
):
106-
"""Verify expert parallelism provides expected speedup."""
107-
speedup = baseline_time / ep_time if ep_time > 0 else 0
107+
"""Verify EP time has not regressed beyond threshold vs hardcoded baseline."""
108+
max_allowed = baseline_time * (1 + threshold)
109+
regression_pct = ((actual_time - baseline_time) / baseline_time) * 100
108110

109-
print(f"{test_name} performance test results:")
111+
print(f"\n{test_name} performance results:")
110112
print(f" Number of prompts: {num_prompts}")
111113
print(f" Baseline time: {baseline_time:.2f}s")
112-
print(f" Expert parallel time: {ep_time:.2f}s")
113-
print(f" Speedup: {speedup:.2f}x")
114-
print(f" Baseline throughput: {num_prompts/baseline_time:.2f} prompts/s")
115-
print(f" Expert parallel throughput: {num_prompts/ep_time:.2f} prompts/s")
114+
print(f" Actual time: {actual_time:.2f}s")
115+
print(
116+
f" Max allowed: {max_allowed:.2f}s (baseline + {threshold*100:.0f}%)"
117+
)
118+
print(f" Delta: {regression_pct:+.1f}%")
119+
print(f" Throughput: {num_prompts/actual_time:.2f} prompts/s")
116120

117-
assert speedup >= min_speedup, (
118-
f"Expert parallelism did not provide expected speedup "
119-
f"({min_speedup:.2f}x): {speedup:.2f}x")
121+
assert actual_time <= max_allowed, (
122+
f"{test_name} regressed by {regression_pct:.1f}% "
123+
f"(actual {actual_time:.2f}s > allowed {max_allowed:.2f}s)")
120124

121125

122-
def _test_expert_parallelism_performance(
123-
sampling_params: SamplingParams,
124-
use_fused_kernel: bool,
125-
model_name: str | None = None,
126-
):
127-
"""Performance test for expert parallelism."""
128-
if model_name is None:
129-
model_name = os.environ.get("EP_MODEL_NAME", "Qwen/Qwen1.5-MoE-A2.7B")
126+
def test_ep_fused_performance(sampling_params: SamplingParams):
127+
"""Test EP with fused MoE kernel does not regress vs hardcoded baseline."""
128+
os.environ['SKIP_JAX_PRECOMPILE'] = '0'
129+
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
130+
os.environ['USE_MOE_EP_KERNEL'] = '1'
130131

131-
cfg = TestConfig.for_performance()
132+
model_name = os.environ.get("EP_MODEL_NAME", "Qwen/Qwen1.5-MoE-A2.7B")
133+
cfg = TestConfig()
132134
test_prompts = generate_test_prompts(cfg.num_prompts)
133135

134-
if use_fused_kernel:
135-
os.environ['USE_MOE_EP_KERNEL'] = '1'
136-
137136
try:
138-
# Run EP (TP=4 + EP)
139137
ep_config = InferenceConfig(
140138
model_name=model_name,
141-
tensor_parallel_size=4,
142-
enable_expert_parallel=True,
143139
max_model_len=cfg.max_model_len,
144140
max_num_batched_tokens=cfg.max_num_batched_tokens,
145141
max_num_seqs=cfg.max_num_seqs,
146142
)
147143
_, ep_time = _run_inference(ep_config, test_prompts, sampling_params)
148144

149-
# Run baseline (TP=1)
150-
baseline_config = InferenceConfig(
151-
model_name=model_name,
152-
tensor_parallel_size=1,
153-
enable_expert_parallel=False,
154-
max_model_len=cfg.max_model_len,
155-
max_num_batched_tokens=cfg.max_num_batched_tokens,
156-
max_num_seqs=cfg.max_num_seqs,
157-
)
158-
_, baseline_time = _run_inference(baseline_config, test_prompts,
159-
sampling_params)
160-
161-
kernel_name = "EP Fused" if use_fused_kernel else "EP GMM"
162-
_check_performance(
163-
f"Expert parallelism ({kernel_name})",
164-
baseline_time,
145+
_check_no_regression(
146+
"EP Fused",
147+
EP_FUSED_BASELINE_TIME,
165148
ep_time,
166-
len(test_prompts),
167-
min_speedup=0.6,
149+
cfg.num_prompts,
168150
)
169151
finally:
170-
if use_fused_kernel:
171-
del os.environ['USE_MOE_EP_KERNEL']
172-
173-
174-
def test_ep_fused_performance(sampling_params: SamplingParams):
175-
"""Test expert parallelism performance with fused MoE EP kernel."""
176-
os.environ['SKIP_JAX_PRECOMPILE'] = '0'
177-
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
178-
179-
_test_expert_parallelism_performance(sampling_params,
180-
use_fused_kernel=True)
152+
del os.environ['USE_MOE_EP_KERNEL']
181153

182154

183155
def test_ep_gmm_performance(sampling_params: SamplingParams):
184-
"""Test expert parallelism performance with GMM kernel.
156+
"""Test EP with GMM kernel does not regress vs hardcoded baseline.
185157
186158
Uses OLMoE-1B-7B (64 experts, power-of-2) instead of Qwen2MoE
187159
(60 experts) because the GMM EP kernel requires num_tokens*topk
@@ -191,7 +163,22 @@ def test_ep_gmm_performance(sampling_params: SamplingParams):
191163
os.environ['SKIP_JAX_PRECOMPILE'] = '0'
192164
os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'
193165

194-
gmm_model = os.environ.get("EP_GMM_MODEL_NAME", "allenai/OLMoE-1B-7B-0924")
195-
_test_expert_parallelism_performance(sampling_params,
196-
use_fused_kernel=False,
197-
model_name=gmm_model)
166+
model_name = os.environ.get("EP_GMM_MODEL_NAME",
167+
"allenai/OLMoE-1B-7B-0924")
168+
cfg = TestConfig()
169+
test_prompts = generate_test_prompts(cfg.num_prompts)
170+
171+
ep_config = InferenceConfig(
172+
model_name=model_name,
173+
max_model_len=cfg.max_model_len,
174+
max_num_batched_tokens=cfg.max_num_batched_tokens,
175+
max_num_seqs=cfg.max_num_seqs,
176+
)
177+
_, ep_time = _run_inference(ep_config, test_prompts, sampling_params)
178+
179+
_check_no_regression(
180+
"EP GMM",
181+
EP_GMM_BASELINE_TIME,
182+
ep_time,
183+
cfg.num_prompts,
184+
)

0 commit comments

Comments (0)