|
11 | 11 |
|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
| 14 | +from collections.abc import Callable |
14 | 15 | from types import SimpleNamespace |
15 | 16 | from unittest.mock import MagicMock, patch |
16 | 17 |
|
17 | 18 | import mlx.core as mx |
18 | 19 | import torch |
19 | 20 | from vllm.sampling_params import SamplingParams |
20 | 21 |
|
| 22 | +import vllm_metal.paged_attention_common as pac |
21 | 23 | import vllm_metal.v1.model_runner as mr |
22 | 24 |
|
23 | 25 |
|
@@ -289,6 +291,73 @@ def _make_cached_scheduler_output( |
289 | 291 | ) |
290 | 292 |
|
291 | 293 |
|
class TestMixedDecodeAndPrefixHitPrefill:
    """Verify a decode request and a prefix-hit prefill in the same unified step."""

    def test_decode_and_prefix_hit_prefill_produce_correct_state(self):
        """Batch one decode ("req-A") with one prefix-hit prefill ("req-B");
        each request must receive its own sampled token and bookkeeping."""
        runner = _make_paged_runner()

        # req-A is mid-generation: prompt of 3 tokens plus one generated token.
        decode_prompt = [10, 20, 30]
        runner._request_states["req-A"] = mr.RequestState(
            token_ids=decode_prompt + [99],
            prompt_len=len(decode_prompt),
            cache=[],
            sampling_params=_greedy_sp(),
            generator=None,
            generated_tokens=1,
            block_ids=[0, 1],
        )
        runner._paged_request_seq_lens["req-A"] = len(decode_prompt)

        # req-B is new, with its first 4 prompt tokens already computed via the
        # prefix cache, so only the 2-token suffix gets scheduled this step.
        prefill_prompt = [1, 2, 3, 4, 5, 6]
        computed = 4
        suffix = len(prefill_prompt) - computed

        # The unified batch is 1 decode token plus the prefill suffix.
        fake_logits = mx.zeros((1, 1 + suffix, 100))
        runner.model.return_value = MagicMock(logits=fake_logits)

        decode_token = 55
        prefill_token = 77
        # execute_model samples the decode request before the prefill, so the
        # side_effect sequence below must follow that order.
        sampled = [mx.array([decode_token]), mx.array([prefill_token])]

        sched_out = SimpleNamespace(
            scheduled_new_reqs=[
                _make_new_req("req-B", prefill_prompt, num_computed_tokens=computed)
            ],
            scheduled_cached_reqs=SimpleNamespace(
                req_ids=["req-A"],
                new_block_ids=[None],
                resumed_req_ids=set(),
                num_computed_tokens=[len(decode_prompt)],
            ),
            num_scheduled_tokens={"req-A": 1, "req-B": suffix},
            total_num_scheduled_tokens=1 + suffix,
            finished_req_ids=set(),
            preempted_req_ids=set(),
            grammar_bitmask=None,
        )

        with (
            patch.object(
                mr.MetalModelRunner, "_extract_logits", return_value=fake_logits
            ),
            patch(
                "vllm_metal.v1.model_runner._mlx_greedy_sample",
                side_effect=sampled,
            ),
            patch("vllm_metal.v1.model_runner.prepare_unified"),
            patch("vllm_metal.v1.model_runner.clear_context"),
        ):
            runner.execute_model(sched_out)

        # The decode request appended exactly one new token.
        state_a = runner._request_states["req-A"]
        assert state_a.token_ids[-1] == decode_token
        assert state_a.generated_tokens == 2

        # The prefill request was registered and sampled its first token.
        state_b = runner._request_states.get("req-B")
        assert state_b is not None
        assert state_b.token_ids == prefill_prompt + [prefill_token]
        assert state_b.prompt_len == len(prefill_prompt)
        assert state_b.generated_tokens == 1
        assert runner._paged_request_seq_lens.get("req-B") == len(prefill_prompt)
| 359 | + |
| 360 | + |
292 | 361 | class TestCachedRequestContinuation: |
293 | 362 | """Verify the cached/intermediate-chunk path works with prefix offsets.""" |
294 | 363 |
|
@@ -348,3 +417,87 @@ def test_cached_intermediate_chunk_with_offset(self): |
348 | 417 | assert state.generated_tokens == len(state.token_ids) - state.prompt_len |
349 | 418 | # seq_lens must reflect full sequence |
350 | 419 | assert runner._paged_request_seq_lens["req-1"] == len(prompt) |
| 420 | + |
| 421 | + |
def _make_paged_ctx_spy(
    captured: list,
) -> Callable[[pac.PagedAttentionContext], None]:
    """Build a ``set_context`` stand-in that records each context it receives.

    The spy also installs the context on the module's thread-local, mirroring
    what the real ``set_context`` does, so code under test can still read it.
    """

    def _record_and_install(ctx: pac.PagedAttentionContext) -> None:
        pac._thread_local.paged_ctx = ctx
        captured.append(ctx)

    return _record_and_install
| 430 | + |
| 431 | + |
class TestPrepareUnifiedSlotMapping:
    """Verify prepare_unified is called with correct slot mapping and RoPE offsets.

    Unlike the rest of this file, these tests do NOT patch prepare_unified out.
    They run it for real and spy on set_context, confirming the runner forwards
    the right block_ids, num_tokens, and start_pos so slot mapping and RoPE
    offsets are exercised end-to-end.
    """

    def test_fresh_prefill_slot_mapping_and_rope_offset(self):
        """start_pos == 0: slots cover positions 0..N-1, offset is 0."""
        runner = _make_paged_runner()
        tokens = [10, 20, 30, 40]
        fake_logits = mx.zeros((1, len(tokens), 100))
        runner.model.return_value = MagicMock(logits=fake_logits)

        seen_ctxs: list[pac.PagedAttentionContext] = []

        # block_size=4, so block 0 covers positions 0-3.
        sched_out = _make_scheduler_output(
            [_make_new_req("req-1", tokens, num_computed_tokens=0, block_ids=[0])]
        )

        with (
            patch.object(
                mr.MetalModelRunner, "_extract_logits", return_value=fake_logits
            ),
            patch(
                "vllm_metal.v1.model_runner._mlx_greedy_sample",
                return_value=mx.array([0]),
            ),
            patch.object(
                pac, "set_context", side_effect=_make_paged_ctx_spy(seen_ctxs)
            ),
        ):
            runner.execute_model(sched_out)

        assert len(seen_ctxs) == 1
        (ctx,) = seen_ctxs
        assert ctx.slot_mapping == [0, 1, 2, 3]
        assert ctx.offsets == [0]
        assert ctx.context_lens == [4]

    def test_prefix_hit_slot_mapping_starts_at_start_pos(self):
        """start_pos == 2: slots cover positions 2-3, RoPE offset is 2."""
        runner = _make_paged_runner()
        tokens = [10, 20, 30, 40]
        computed = 2
        remaining = len(tokens) - computed
        fake_logits = mx.zeros((1, remaining, 100))
        runner.model.return_value = MagicMock(logits=fake_logits)

        seen_ctxs: list[pac.PagedAttentionContext] = []

        # block_size=4, so block 0 covers positions 0-3.
        sched_out = _make_scheduler_output(
            [
                _make_new_req(
                    "req-1", tokens, num_computed_tokens=computed, block_ids=[0]
                )
            ]
        )

        with (
            patch.object(
                mr.MetalModelRunner, "_extract_logits", return_value=fake_logits
            ),
            patch(
                "vllm_metal.v1.model_runner._mlx_greedy_sample",
                return_value=mx.array([0]),
            ),
            patch.object(
                pac, "set_context", side_effect=_make_paged_ctx_spy(seen_ctxs)
            ),
        ):
            runner.execute_model(sched_out)

        assert len(seen_ctxs) == 1
        (ctx,) = seen_ctxs
        assert ctx.slot_mapping == [2, 3]  # positions 2-3 in block 0
        assert ctx.offsets == [2]
        assert ctx.context_lens == [4]  # start_pos + num_tokens = 2 + 2
0 commit comments