diff --git a/tests/test_batching.py b/tests/test_batching.py index 7dc050ee..1b1191fd 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -7,8 +7,10 @@ """ import asyncio +import importlib import pytest from unittest.mock import MagicMock +import mlx.core as mx from vllm_mlx.request import ( Request, @@ -20,8 +22,11 @@ Scheduler, SchedulerConfig, SchedulingPolicy, + _install_chunked_prefill, ) +mlx_generate = importlib.import_module("mlx_lm.generate") + class TestRequest: """Tests for Request class.""" @@ -211,6 +216,168 @@ def mock_model(self): """Create a mock model.""" return MagicMock() + def test_chunked_prefill_accepts_prompt_checkpoints(self, monkeypatch): + """Chunked prefill must match mlx-lm's 7-field prompt tuples.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + + class FakeBatchGenerator: + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3, 4, 5], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = None + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=4) + + responses = batch_gen._next() + + assert responses == [] + assert batch_gen._partial is not None + assert batch_gen._partial["prompt_checkpoint"] == 3 + assert batch_gen._partial["processed"] == 2 + + def test_chunked_prefill_invokes_checkpoint_callback(self, monkeypatch): + """prompt_checkpoint_callback must fire after finalization.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + callback_payloads = [] + + from collections import namedtuple + _Response = namedtuple("Response", ["uid", "token", "logprobs", "finish_reason", "cache"]) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = ( + lambda entries: callback_payloads.extend(entries) + ) + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + return mx.array([99]), mx.array([-1.0]) + + def _generation_step(self): + if self.active_batch is not None: + self.active_batch = None + return [] + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + batch_gen.stop_tokens = {99} + _install_chunked_prefill(batch_gen, budget=1) + + # First _next: starts partial prefill (processes 1 token) + batch_gen._next() + assert batch_gen._partial is not None + + # Second _next: finishes prefill, fires checkpoint callback, + # then runs generation step which completes (stop token). + batch_gen._next() + + assert len(callback_payloads) == 1 + uid, checkpoint, _cache_gen = callback_payloads[0] + assert uid == 7 + assert checkpoint == 1 + def test_scheduler_creation(self, mock_model, mock_tokenizer): """Test scheduler creation.""" scheduler = Scheduler( diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 88d144cb..76cb0295 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -148,6 +148,7 @@ def _install_chunked_prefill( from mlx_lm.generate import ( Batch, + _lazy_extract_cache, _left_pad_prompts, _make_cache, _merge_caches, @@ -267,8 +268,13 @@ def _chunked_next(self=batch_gen): # noqa: C901 inputs = partial["inputs"] prompt_cache = partial["cache"] remaining = inputs.shape[1] + prompt_checkpoint = max(1, int(partial.get("prompt_checkpoint", 1))) - n_to_process = min(budget, remaining - 1) if remaining > 1 else 0 + n_to_process = ( + min(budget, remaining - prompt_checkpoint) + if remaining > prompt_checkpoint + else 0 + ) if n_to_process > 0: self.model(mx.contiguous(inputs[:, :n_to_process]), cache=prompt_cache) @@ -293,8 +299,8 @@ def _chunked_next(self=batch_gen): # noqa: C901 if partial.get("is_cached"): mx.clear_cache() - # Check if prefill is done (only 1 token left or 0) - if inputs.shape[1] <= 1: + # Check if prefill is done once only the checkpoint tail remains. + if inputs.shape[1] <= prompt_checkpoint: # Finalize if partial.get("is_cached"): mx.eval([c.state for c in prompt_cache]) @@ -302,8 +308,24 @@ def _chunked_next(self=batch_gen): # noqa: C901 for c in prompt_cache: c.finalize() + + if self.prompt_checkpoint_callback is not None: + self.prompt_checkpoint_callback( + [ + (uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i)) + for i, uid in enumerate(partial["uids"]) + ] + ) mx.clear_cache() + if prompt_checkpoint > 1: + self.model( + mx.contiguous(inputs[:, : prompt_checkpoint - 1]), + cache=prompt_cache, + ) + mx.eval([c.state for c in prompt_cache]) + mx.clear_cache() + y, logprobs = self._step( inputs, prompt_cache, @@ -392,25 +414,37 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, + prompt_checkpoints, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) padding = [max_length - ln for ln in lengths] tokens = [mx.array(inp) for inp in inputs_raw] + checkpoint_offsets = [ + (ln - pc if pc > 0 else -pc) + for ln, pc in zip(lengths, prompt_checkpoints) + ] + prompt_checkpoint = max(1, max(checkpoint_offsets)) is_cached = not all(c[0].empty() for c in caches) self._stats.prompt_tokens += sum(lengths) if not is_cached: padded = _left_pad_prompts(inputs_raw, max_length=max_length) - prompt_cache = _make_cache(self.model, padding) + prompt_cache = _make_cache( + self.model, padding, self.max_kv_size + ) else: - last_inputs = mx.array([p[-1:] for p in inputs_raw]) + last_inputs = mx.array( + [p[-prompt_checkpoint:] for p in inputs_raw] + ) padded = _right_pad_prompts(inputs_raw, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: c.prepare( - lengths=[ln - 1 for ln in lengths], + lengths=[ + ln - prompt_checkpoint for ln in lengths + ], right_padding=padding, ) @@ -433,9 +467,9 @@ def _chunked_next(self=batch_gen): # noqa: C901 _pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0 _cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0 _adjusted_pb = _pb - _cached - if 0 < _adjusted_pb < padded.shape[1]: + if 0 < _adjusted_pb < padded.shape[1] - prompt_checkpoint + 1: _first_chunk = _adjusted_pb - n_to_process = min(_first_chunk, padded.shape[1] - 1) + n_to_process = min(_first_chunk, padded.shape[1] - prompt_checkpoint) if n_to_process > 0: self.model( mx.contiguous(padded[:, :n_to_process]), @@ -454,6 +488,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 "max_tokens": list(max_tokens_list), "samplers": list(samplers), "logits_processors": list(logits_processors), + "prompt_checkpoint": prompt_checkpoint, "processed": n_to_process, "total": max_length, "is_cached": is_cached,