
Commit eb49908

Fix full_prompt derivation and add paged path tests (#208)

This PR is:

- To fix the full_prompt derivation bug in the paged prefill path, where a prefix-cache-hit request on its first chunk was silently dropped (no RequestState, not in new_reqs_by_id)
- To replace a brittle chained ternary with an explicit if/elif/else that fails loudly on impossible states
- To add a warning when a paged cached request is missing its RequestState (surfaces state tracking bugs instead of silently emitting a placeholder)
- To add test coverage for mixed decode + prefix-hit prefill in a single unified step, and for slot mapping / RoPE offset correctness under fresh and prefix-hit conditions

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>

1 parent 19a19c4 commit eb49908
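For context on the second bullet: a chained conditional expression nests its trailing if/else pairs inside the first else, so the impossible state (prefix hit, but the request tracked nowhere) surfaced only as a bare KeyError or a silent None. Below is a minimal standalone sketch of the pitfall and the replacement style, using simplified stand-in names (derive_full_prompt_*, state, new_reqs_by_id), not the runner's actual code:

def derive_full_prompt_ternary(start_pos, state, new_reqs_by_id, rid):
    # Parses as: A if c1 else (B if c2 else C). The "no state anywhere"
    # case falls into new_reqs_by_id[rid] and dies with a bare KeyError.
    return (
        list(state.token_ids[: state.prompt_len])
        if start_pos > 0 and state is not None
        else list(new_reqs_by_id[rid].prompt_token_ids)
        if start_pos > 0
        else None
    )

def derive_full_prompt_explicit(start_pos, state, new_reqs_by_id, rid):
    if start_pos == 0:
        return None  # fresh prefill: the scheduled tokens are the whole prompt
    if state is not None:
        return list(state.token_ids[: state.prompt_len])
    req = new_reqs_by_id.get(rid)
    if req is None:
        # Fail loudly on the impossible state instead of a cryptic KeyError.
        raise RuntimeError(f"prefix hit for {rid!r} with no tracked state")
    return list(req.prompt_token_ids)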

File tree

2 files changed: +178 -13


tests/test_paged_prefix_caching.py

Lines changed: 153 additions & 0 deletions
@@ -11,13 +11,15 @@
 
 from __future__ import annotations
 
+from collections.abc import Callable
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
 import mlx.core as mx
 import torch
 from vllm.sampling_params import SamplingParams
 
+import vllm_metal.paged_attention_common as pac
 import vllm_metal.v1.model_runner as mr
 
 
@@ -289,6 +291,73 @@ def _make_cached_scheduler_output(
     )
 
 
+class TestMixedDecodeAndPrefixHitPrefill:
+    """Verify a decode request and a prefix-hit prefill in the same unified step."""
+
+    def test_decode_and_prefix_hit_prefill_produce_correct_state(self):
+        runner = _make_paged_runner()
+        prompt_a = [10, 20, 30]
+        runner._request_states["req-A"] = mr.RequestState(
+            token_ids=prompt_a + [99],
+            prompt_len=len(prompt_a),
+            cache=[],
+            sampling_params=_greedy_sp(),
+            generator=None,
+            generated_tokens=1,
+            block_ids=[0, 1],
+        )
+        runner._paged_request_seq_lens["req-A"] = len(prompt_a)
+
+        prompt_b = [1, 2, 3, 4, 5, 6]
+        num_computed_b = 4
+        suffix_len_b = len(prompt_b) - num_computed_b
+        logits = mx.zeros((1, 1 + suffix_len_b, 100))
+        runner.model.return_value = MagicMock(logits=logits)
+
+        decode_token = 55
+        prefill_token = 77
+        # Decode is processed before prefill in execute_model; side_effect order matches.
+        greedy_tokens = [mx.array([decode_token]), mx.array([prefill_token])]
+
+        new_req_b = _make_new_req("req-B", prompt_b, num_computed_tokens=num_computed_b)
+        sched_out = SimpleNamespace(
+            scheduled_new_reqs=[new_req_b],
+            scheduled_cached_reqs=SimpleNamespace(
+                req_ids=["req-A"],
+                new_block_ids=[None],
+                resumed_req_ids=set(),
+                num_computed_tokens=[len(prompt_a)],
+            ),
+            num_scheduled_tokens={"req-A": 1, "req-B": suffix_len_b},
+            total_num_scheduled_tokens=1 + suffix_len_b,
+            finished_req_ids=set(),
+            preempted_req_ids=set(),
+            grammar_bitmask=None,
+        )
+
+        with (
+            patch.object(mr.MetalModelRunner, "_extract_logits", return_value=logits),
+            patch(
+                "vllm_metal.v1.model_runner._mlx_greedy_sample",
+                side_effect=greedy_tokens,
+            ),
+            patch("vllm_metal.v1.model_runner.prepare_unified"),
+            patch("vllm_metal.v1.model_runner.clear_context"),
+        ):
+            runner.execute_model(sched_out)
+
+        state_a = runner._request_states["req-A"]
+        assert state_a.token_ids[-1] == decode_token
+        assert state_a.generated_tokens == 2
+
+        state_b = runner._request_states.get("req-B")
+        assert state_b is not None
+        assert state_b.token_ids == prompt_b + [prefill_token]
+        assert state_b.prompt_len == len(prompt_b)
+        assert state_b.generated_tokens == 1
+        assert runner._paged_request_seq_lens.get("req-B") == len(prompt_b)
+
+
 class TestCachedRequestContinuation:
     """Verify the cached/intermediate-chunk path works with prefix offsets."""
 
@@ -348,3 +417,87 @@ def test_cached_intermediate_chunk_with_offset(self):
         assert state.generated_tokens == len(state.token_ids) - state.prompt_len
         # seq_lens must reflect full sequence
         assert runner._paged_request_seq_lens["req-1"] == len(prompt)
+
+
+def _make_paged_ctx_spy(
+    captured: list,
+) -> Callable[[pac.PagedAttentionContext], None]:
+    def spy(ctx: pac.PagedAttentionContext) -> None:
+        captured.append(ctx)
+        pac._thread_local.paged_ctx = ctx
+
+    return spy
+
+
+class TestPrepareUnifiedSlotMapping:
+    """Verify prepare_unified is called with correct slot mapping and RoPE offsets.
+
+    All other tests in this file patch prepare_unified out. These tests let it
+    run for real and spy on set_context to confirm the runner passes the right
+    block_ids, num_tokens, and start_pos arguments so that slot mapping and RoPE
+    offsets are exercised end-to-end.
+    """
+
+    def test_fresh_prefill_slot_mapping_and_rope_offset(self):
+        """start_pos == 0: slots cover positions 0..N-1, offset is 0."""
+        runner = _make_paged_runner()
+        prompt = [10, 20, 30, 40]
+        block_ids = [0]  # block_size=4, block 0 covers positions 0-3
+        logits = mx.zeros((1, len(prompt), 100))
+        runner.model.return_value = MagicMock(logits=logits)
+
+        captured: list[pac.PagedAttentionContext] = []
+
+        new_req = _make_new_req(
+            "req-1", prompt, num_computed_tokens=0, block_ids=block_ids
+        )
+        sched_out = _make_scheduler_output([new_req])
+
+        with (
+            patch.object(mr.MetalModelRunner, "_extract_logits", return_value=logits),
+            patch(
+                "vllm_metal.v1.model_runner._mlx_greedy_sample",
+                return_value=mx.array([0]),
+            ),
+            patch.object(pac, "set_context", side_effect=_make_paged_ctx_spy(captured)),
+        ):
+            runner.execute_model(sched_out)
+
+        assert len(captured) == 1
+        ctx = captured[0]
+        assert ctx.slot_mapping == [0, 1, 2, 3]
+        assert ctx.offsets == [0]
+        assert ctx.context_lens == [4]
+
+    def test_prefix_hit_slot_mapping_starts_at_start_pos(self):
+        """start_pos == 2: slots cover positions 2-3, RoPE offset is 2."""
+        runner = _make_paged_runner()
+        prompt = [10, 20, 30, 40]
+        num_computed = 2
+        block_ids = [0]  # block_size=4, block 0 covers positions 0-3
+        suffix_len = len(prompt) - num_computed
+        logits = mx.zeros((1, suffix_len, 100))
+        runner.model.return_value = MagicMock(logits=logits)
+
+        captured: list[pac.PagedAttentionContext] = []
+
+        new_req = _make_new_req(
+            "req-1", prompt, num_computed_tokens=num_computed, block_ids=block_ids
+        )
+        sched_out = _make_scheduler_output([new_req])
+
+        with (
+            patch.object(mr.MetalModelRunner, "_extract_logits", return_value=logits),
+            patch(
+                "vllm_metal.v1.model_runner._mlx_greedy_sample",
+                return_value=mx.array([0]),
+            ),
+            patch.object(pac, "set_context", side_effect=_make_paged_ctx_spy(captured)),
+        ):
+            runner.execute_model(sched_out)
+
+        assert len(captured) == 1
+        ctx = captured[0]
+        assert ctx.slot_mapping == [2, 3]  # positions 2-3 in block 0
+        assert ctx.offsets == [2]
+        assert ctx.context_lens == [4]  # start_pos + num_tokens = 2 + 2
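As a reading aid for the assertions above, here is a small self-contained sketch of the slot-mapping arithmetic these tests pin down, assuming the usual paged-KV layout (a token at position p lands in slot block_ids[p // block_size] * block_size + p % block_size); the slots_for helper is illustrative, not part of vllm_metal's API:

BLOCK_SIZE = 4  # the block_size both tests assume

def slots_for(block_ids: list[int], start_pos: int, num_tokens: int) -> list[int]:
    # Flat KV-cache slots for token positions start_pos .. start_pos+num_tokens-1.
    slots = []
    for pos in range(start_pos, start_pos + num_tokens):
        block = block_ids[pos // BLOCK_SIZE]
        slots.append(block * BLOCK_SIZE + pos % BLOCK_SIZE)
    return slots

# Fresh prefill: all four prompt positions land in block 0.
assert slots_for([0], start_pos=0, num_tokens=4) == [0, 1, 2, 3]
# Prefix hit at start_pos=2: only the suffix is written; the RoPE offset
# equals start_pos and context_len = start_pos + num_tokens = 4.
assert slots_for([0], start_pos=2, num_tokens=2) == [2, 3]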

vllm_metal/v1/model_runner.py

Lines changed: 25 additions & 13 deletions
@@ -1731,6 +1731,12 @@ def execute_model(
         for req_id in decode_req_ids:
             state = self._request_states.get(req_id)
             if state is None:
+                # Placeholder keeps the output tensor aligned; the warning surfaces the bug.
+                logger.warning(
+                    "Paged cached request %s has no RequestState; "
+                    "emitting placeholder token. This is a state tracking bug.",
+                    req_id,
+                )
                 req_ids.append(req_id)
                 req_id_to_index[req_id] = len(req_ids) - 1
                 sampled_tokens.append([0])
@@ -1812,14 +1818,22 @@ def execute_model(
         ) in paged_prefill_entries:
             # Full prompt for sampling metadata (needed when token_ids
             # is a suffix slice due to prefix cache hit).
+            # State exists for cached requests (intermediate chunks);
+            # new requests with a prefix hit look up from new_reqs_by_id.
             state = self._request_states.get(rid)
-            full_prompt = (
-                list(state.token_ids[: state.prompt_len])
-                if start_pos > 0 and state is not None
-                else list(new_reqs_by_id[rid].prompt_token_ids)
-                if start_pos > 0
-                else None
-            )
+            if start_pos == 0:
+                full_prompt = None
+            elif state is not None:
+                full_prompt = state.token_ids[: state.prompt_len]
+            else:
+                req = new_reqs_by_id.get(rid)
+                if req is None:
+                    raise RuntimeError(
+                        f"Prefix cache hit (start_pos={start_pos}) for request "
+                        f"{rid!r} but it has no RequestState and is not in "
+                        "new_reqs. This is a state tracking bug."
+                    )
+                full_prompt = list(req.prompt_token_ids)
             prefill_pack.append(
                 PrefillRequest(
                     req_id=rid,
@@ -1856,12 +1870,10 @@ def execute_model(
                sampled_tokens[idx] = []
            elif is_new:
                sampled_tokens[idx] = [nt]
-                # When prefix cache hits (start_pos > 0), tids is only
-                # the suffix slice. State needs the full prompt.
-                if _start_pos > 0:
-                    full_prompt = list(new_reqs_by_id[rid].prompt_token_ids)
-                else:
-                    full_prompt = list(tids)
+                cached_full_prompt = prefill_pack[i].full_prompt_token_ids
+                full_prompt = (
+                    cached_full_prompt if cached_full_prompt is not None else tids
+                )
                self._request_states[rid] = RequestState(
                    token_ids=full_prompt + [nt],
                    prompt_len=_prompt_len,
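Taken together, the two hunks above give full_prompt a single source of truth: it is derived once while building prefill_pack and reused at sampling time instead of a second new_reqs_by_id lookup. A compressed sketch of that hand-off, using a stand-in dataclass (the real PrefillRequest is defined elsewhere in the runner and carries more fields):

from dataclasses import dataclass
from typing import Optional

@dataclass
class PrefillRequest:  # stand-in: only the field relevant to the hand-off
    req_id: str
    full_prompt_token_ids: Optional[list[int]]  # None => fresh prefill

def finish_new_prefill(entry: PrefillRequest, tids: list[int], nt: int) -> list[int]:
    # On a fresh prefill tids is the whole prompt; on a prefix hit it is
    # only the suffix slice, so the cached full prompt must win.
    full_prompt = (
        entry.full_prompt_token_ids
        if entry.full_prompt_token_ids is not None
        else tids
    )
    return full_prompt + [nt]

assert finish_new_prefill(PrefillRequest("r", None), [1, 2, 3], 9) == [1, 2, 3, 9]
assert finish_new_prefill(PrefillRequest("r", [1, 2, 3, 4]), [3, 4], 9) == [1, 2, 3, 4, 9]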
