@@ -132,7 +132,12 @@ def test_verify_quantization_supported(self) -> None:
     def test_check_and_update_config_disables_chunked_prefill(
         self, monkeypatch: pytest.MonkeyPatch
     ) -> None:
-        """Metal should disable chunked prefill until the runner supports it."""
+        """Metal should disable chunked prefill until the runner supports it.
+
+        When chunked prefill is disabled, max_num_batched_tokens must be at
+        least max_model_len so the scheduler can schedule the entire prompt
+        in a single step.
+        """
         import vllm_metal.stt.config as stt_config
         import vllm_metal.utils as metal_utils
 
@@ -150,22 +155,157 @@ def test_check_and_update_config_disables_chunked_prefill(
                 model="test-model",
                 disable_cascade_attn=False,
                 tokenizer=None,
+                max_model_len=32768,
             ),
             scheduler_config=SimpleNamespace(
                 async_scheduling=True,
                 enable_chunked_prefill=True,
+                max_num_batched_tokens=2048,
+                max_num_scheduled_tokens=None,
             ),
         )
 
         MetalPlatform.check_and_update_config(vllm_config)
 
         assert vllm_config.scheduler_config.enable_chunked_prefill is False
+        assert vllm_config.scheduler_config.max_num_batched_tokens == 32768
         assert (
             vllm_config.parallel_config.worker_cls == "vllm_metal.v1.worker.MetalWorker"
         )
         assert vllm_config.parallel_config.distributed_executor_backend == "uni"
         assert vllm_config.parallel_config.disable_custom_all_reduce is True
 
+    def test_check_and_update_config_increases_max_num_scheduled_tokens_below_max_model_len(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """max_num_scheduled_tokens below max_model_len should be bumped up to max_model_len.
+
+        When max_num_scheduled_tokens is explicitly set to a value smaller
+        than max_model_len, it must be raised to match max_model_len so that
+        the scheduler can schedule the full prompt in a single step.
+        """
+        import vllm_metal.stt.config as stt_config
+        import vllm_metal.utils as metal_utils
+
+        monkeypatch.setattr(metal_utils, "get_model_download_path", lambda model: model)
+        monkeypatch.setattr(stt_config, "is_stt_model", lambda _model: False)
+
+        vllm_config = SimpleNamespace(
+            parallel_config=SimpleNamespace(
+                worker_cls="auto",
+                distributed_executor_backend="auto",
+                disable_custom_all_reduce=False,
+            ),
+            cache_config=SimpleNamespace(block_size=None),
+            model_config=SimpleNamespace(
+                model="test-model",
+                disable_cascade_attn=False,
+                tokenizer=None,
+                max_model_len=32768,
+            ),
+            scheduler_config=SimpleNamespace(
+                async_scheduling=True,
+                enable_chunked_prefill=True,
+                max_num_batched_tokens=2048,
+                max_num_scheduled_tokens=2048,
+            ),
+        )
+
+        MetalPlatform.check_and_update_config(vllm_config)
+
+        assert vllm_config.scheduler_config.enable_chunked_prefill is False
+        assert vllm_config.scheduler_config.max_num_batched_tokens == 32768
+        assert vllm_config.scheduler_config.max_num_scheduled_tokens == 32768
+
+    def test_check_and_update_config_does_not_reduce_large_max_num_batched_tokens(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """max_num_batched_tokens must not be lowered when already >= max_model_len.
+
+        If the user has explicitly set a token budget larger than max_model_len,
+        that setting must be preserved.
+        """
+        import vllm_metal.stt.config as stt_config
+        import vllm_metal.utils as metal_utils
+
+        monkeypatch.setattr(metal_utils, "get_model_download_path", lambda model: model)
+        monkeypatch.setattr(stt_config, "is_stt_model", lambda _model: False)
+
+        vllm_config = SimpleNamespace(
+            parallel_config=SimpleNamespace(
+                worker_cls="auto",
+                distributed_executor_backend="auto",
+                disable_custom_all_reduce=False,
+            ),
+            cache_config=SimpleNamespace(block_size=None),
+            model_config=SimpleNamespace(
+                model="test-model",
+                disable_cascade_attn=False,
+                tokenizer=None,
+                max_model_len=32768,
+            ),
+            scheduler_config=SimpleNamespace(
+                async_scheduling=True,
+                enable_chunked_prefill=True,
+                max_num_batched_tokens=65536,
+                max_num_scheduled_tokens=None,
+            ),
+        )
+
+        MetalPlatform.check_and_update_config(vllm_config)
+
+        assert vllm_config.scheduler_config.enable_chunked_prefill is False
+        # 65536 > 32768, so the value must stay at 65536
+        assert vllm_config.scheduler_config.max_num_batched_tokens == 65536
+
+    @pytest.mark.parametrize("max_num_scheduled_tokens", [32768, 65536])
+    def test_check_and_update_config_does_not_reduce_max_num_scheduled_tokens_when_at_least_max_model_len(
+        self,
+        monkeypatch: pytest.MonkeyPatch,
+        max_num_scheduled_tokens: int,
+    ) -> None:
+        """max_num_scheduled_tokens must not be lowered when already >= max_model_len.
+
+        If the user has explicitly set a scheduled-token budget of at least
+        max_model_len, that setting must be preserved (only values strictly
+        below max_model_len are bumped up).
+        """
+        import vllm_metal.stt.config as stt_config
+        import vllm_metal.utils as metal_utils
+
+        monkeypatch.setattr(metal_utils, "get_model_download_path", lambda model: model)
+        monkeypatch.setattr(stt_config, "is_stt_model", lambda _model: False)
+
+        vllm_config = SimpleNamespace(
+            parallel_config=SimpleNamespace(
+                worker_cls="auto",
+                distributed_executor_backend="auto",
+                disable_custom_all_reduce=False,
+            ),
+            cache_config=SimpleNamespace(block_size=None),
+            model_config=SimpleNamespace(
+                model="test-model",
+                disable_cascade_attn=False,
+                tokenizer=None,
+                max_model_len=32768,
+            ),
+            scheduler_config=SimpleNamespace(
+                async_scheduling=True,
+                enable_chunked_prefill=True,
+                max_num_batched_tokens=65536,
+                max_num_scheduled_tokens=max_num_scheduled_tokens,
+            ),
+        )
+
+        MetalPlatform.check_and_update_config(vllm_config)
+
+        assert vllm_config.scheduler_config.enable_chunked_prefill is False
+        assert vllm_config.scheduler_config.max_num_batched_tokens == 65536
+        assert (
+            vllm_config.scheduler_config.max_num_scheduled_tokens
+            == max_num_scheduled_tokens
+        )
+
     def test_check_and_update_config_applies_stt_scheduler_policy(
         self, monkeypatch: pytest.MonkeyPatch
     ) -> None:
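Read together, the four tests pin down one invariant of MetalPlatform.check_and_update_config: chunked prefill is switched off, and both scheduler token budgets get a floor of max_model_len but are never lowered. The sketch below restates that clamping logic as implied by the assertions above; it is a minimal reconstruction assuming plain attribute access on the config objects, and _clamp_scheduler_limits is a hypothetical name, not the plugin's actual code.

def _clamp_scheduler_limits(vllm_config) -> None:
    # Hypothetical restatement of the behavior the tests assert; the real
    # logic lives inside MetalPlatform.check_and_update_config.
    scheduler = vllm_config.scheduler_config
    max_model_len = vllm_config.model_config.max_model_len

    # The Metal runner does not support chunked prefill yet.
    scheduler.enable_chunked_prefill = False

    # Without chunked prefill the whole prompt must fit in one scheduler
    # step: raise the batched-token budget to max_model_len, never lower it.
    scheduler.max_num_batched_tokens = max(
        scheduler.max_num_batched_tokens, max_model_len
    )

    # Apply the same floor to max_num_scheduled_tokens only when it was set
    # explicitly; None means it is derived from max_num_batched_tokens.
    if scheduler.max_num_scheduled_tokens is not None:
        scheduler.max_num_scheduled_tokens = max(
            scheduler.max_num_scheduled_tokens, max_model_len
        )

Using max() for both fields makes every case above fall out directly: 2048 is raised to 32768, while 32768 and 65536 pass through unchanged.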