88import pytest
99from vllm import LLM , EngineArgs , SamplingParams
1010
# Reference wall-clock baselines (seconds) for 512 prompts on TPU v7x-8,
# measured 2026-03-20 against vllm LKG aa84e43c + tpu-inference main.
# A test passes when its EP run finishes within REGRESSION_THRESHOLD of
# the matching baseline below.
EP_FUSED_BASELINE_TIME = 3.40
EP_GMM_BASELINE_TIME = 2.07
REGRESSION_THRESHOLD = 0.15  # allow up to 15% slowdown before failing

1118
@dataclass
class TestConfig:
    """Configuration for EP test runs.

    Defaults size the run at 512 prompts / 512 max sequences, matching the
    conditions under which the hardcoded baselines were measured.
    """

    # Not a pytest test class despite the Test* name: without this flag
    # pytest tries to collect it and warns because the dataclass defines
    # __init__ (PytestCollectionWarning).
    __test__ = False

    max_model_len: int = 512
    max_num_batched_tokens: int = 512
    max_num_seqs: int = 512
    num_prompts: int = 512
2826
2927
@dataclass
class InferenceConfig:
    """Configuration for a single inference run.

    Defaults describe the expert-parallel setup used by these tests:
    tensor parallelism across 4 chips with expert parallelism enabled.
    """
    model_name: str
    tensor_parallel_size: int = 4
    enable_expert_parallel: bool = True
    max_model_len: int = 512
    max_num_batched_tokens: int = 512
    max_num_seqs: int = 512
    gpu_memory_utilization: float = 0.95
4038
4139
42- def generate_test_prompts (num_prompts : int = 256 ) -> list [str ]:
40+ def generate_test_prompts (num_prompts : int = 512 ) -> list [str ]:
4341 base_text = (
4442 "The rapid advancement of artificial intelligence has transformed "
4543 "numerous industries and continues to reshape our understanding of "
@@ -87,6 +85,9 @@ def _run_inference(
8785 engine_args_dict = asdict (engine_args )
8886 llm = LLM (** engine_args_dict )
8987
88+ # Warmup
89+ llm .generate (test_prompts [:8 ], sampling_params )
90+
9091 start_time = time .time ()
9192 outputs = llm .generate (test_prompts , sampling_params )
9293 elapsed_time = time .time () - start_time
@@ -96,92 +97,63 @@ def _run_inference(
9697 return outputs , elapsed_time
9798
9899
def _check_no_regression(
    test_name: str,
    baseline_time: float,
    actual_time: float,
    num_prompts: int,
    threshold: float = REGRESSION_THRESHOLD,
):
    """Assert the measured EP time has not regressed past the baseline.

    Prints a summary of the run, then fails if ``actual_time`` exceeds
    ``baseline_time`` by more than ``threshold`` (fractional, e.g. 0.15
    for 15%).
    """
    max_allowed = baseline_time * (1 + threshold)
    regression_pct = (actual_time - baseline_time) / baseline_time * 100

    summary = [
        f"\n{test_name} performance results:",
        f"  Number of prompts: {num_prompts}",
        f"  Baseline time: {baseline_time:.2f}s",
        f"  Actual time: {actual_time:.2f}s",
        f"  Max allowed: {max_allowed:.2f}s (baseline + {threshold * 100:.0f}%)",
        f"  Delta: {regression_pct:+.1f}%",
        f"  Throughput: {num_prompts / actual_time:.2f} prompts/s",
    ]
    for line in summary:
        print(line)

    assert actual_time <= max_allowed, (
        f"{test_name} regressed by {regression_pct:.1f}% "
        f"(actual {actual_time:.2f}s > allowed {max_allowed:.2f}s)")
121125
def test_ep_fused_performance(sampling_params: SamplingParams):
    """Test EP with fused MoE kernel does not regress vs hardcoded baseline.

    Runs TP=4 + expert-parallel inference over ``cfg.num_prompts`` prompts
    and asserts elapsed time stays within REGRESSION_THRESHOLD of
    EP_FUSED_BASELINE_TIME.
    """
    os.environ['SKIP_JAX_PRECOMPILE'] = '0'
    os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '1'

    # Save any pre-existing value so the finally block restores it instead of
    # unconditionally deleting it. The previous version set the flag several
    # statements before the try, so a failure in TestConfig()/prompt
    # generation leaked USE_MOE_EP_KERNEL=1 into subsequent tests, and the
    # unconditional `del` clobbered a caller-provided value on the way out.
    prior_kernel_flag = os.environ.get('USE_MOE_EP_KERNEL')
    os.environ['USE_MOE_EP_KERNEL'] = '1'
    try:
        model_name = os.environ.get("EP_MODEL_NAME", "Qwen/Qwen1.5-MoE-A2.7B")
        cfg = TestConfig()
        test_prompts = generate_test_prompts(cfg.num_prompts)

        ep_config = InferenceConfig(
            model_name=model_name,
            max_model_len=cfg.max_model_len,
            max_num_batched_tokens=cfg.max_num_batched_tokens,
            max_num_seqs=cfg.max_num_seqs,
        )
        _, ep_time = _run_inference(ep_config, test_prompts, sampling_params)

        _check_no_regression(
            "EP Fused",
            EP_FUSED_BASELINE_TIME,
            ep_time,
            cfg.num_prompts,
        )
    finally:
        if prior_kernel_flag is None:
            os.environ.pop('USE_MOE_EP_KERNEL', None)
        else:
            os.environ['USE_MOE_EP_KERNEL'] = prior_kernel_flag
181153
182154
183155def test_ep_gmm_performance (sampling_params : SamplingParams ):
184- """Test expert parallelism performance with GMM kernel.
156+ """Test EP with GMM kernel does not regress vs hardcoded baseline .
185157
186158 Uses OLMoE-1B-7B (64 experts, power-of-2) instead of Qwen2MoE
187159 (60 experts) because the GMM EP kernel requires num_tokens*topk
@@ -191,7 +163,22 @@ def test_ep_gmm_performance(sampling_params: SamplingParams):
191163 os .environ ['SKIP_JAX_PRECOMPILE' ] = '0'
192164 os .environ ['VLLM_XLA_CHECK_RECOMPILATION' ] = '1'
193165
194- gmm_model = os .environ .get ("EP_GMM_MODEL_NAME" , "allenai/OLMoE-1B-7B-0924" )
195- _test_expert_parallelism_performance (sampling_params ,
196- use_fused_kernel = False ,
197- model_name = gmm_model )
166+ model_name = os .environ .get ("EP_GMM_MODEL_NAME" ,
167+ "allenai/OLMoE-1B-7B-0924" )
168+ cfg = TestConfig ()
169+ test_prompts = generate_test_prompts (cfg .num_prompts )
170+
171+ ep_config = InferenceConfig (
172+ model_name = model_name ,
173+ max_model_len = cfg .max_model_len ,
174+ max_num_batched_tokens = cfg .max_num_batched_tokens ,
175+ max_num_seqs = cfg .max_num_seqs ,
176+ )
177+ _ , ep_time = _run_inference (ep_config , test_prompts , sampling_params )
178+
179+ _check_no_regression (
180+ "EP GMM" ,
181+ EP_GMM_BASELINE_TIME ,
182+ ep_time ,
183+ cfg .num_prompts ,
184+ )
0 commit comments