Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
"""

import asyncio
import importlib
import pytest
from unittest.mock import MagicMock
import mlx.core as mx

from vllm_mlx.request import (
Request,
Expand All @@ -20,8 +22,11 @@
Scheduler,
SchedulerConfig,
SchedulingPolicy,
_install_chunked_prefill,
)

mlx_generate = importlib.import_module("mlx_lm.generate")


class TestRequest:
"""Tests for Request class."""
Expand Down Expand Up @@ -211,6 +216,168 @@ def mock_model(self):
"""Create a mock model."""
return MagicMock()

    def test_chunked_prefill_accepts_prompt_checkpoints(self, monkeypatch):
        """Chunked prefill must match mlx-lm's 7-field prompt tuples.

        Installs the chunked-prefill hook on a minimal fake BatchGenerator and
        verifies that one `_next()` call starts a *partial* prefill:

        * with prompt length 5 and a per-prompt checkpoint of 2, the installed
          code computes a checkpoint offset of ``len - checkpoint`` = 3, and
        * with ``budget=4`` it processes ``padded_len - checkpoint_offset`` =
          5 - 3 = 2 tokens before suspending into ``_partial``.
        """

        class FakeCacheEntry:
            # Stands in for an mlx-lm prompt-cache entry.  ``empty() -> True``
            # makes the installed code take the uncached path (fresh
            # ``_make_cache`` instead of merging existing caches).
            def empty(self):
                return True

        class FakePromptCache:
            # Minimal cache object: ``state`` is what the prefill code
            # mx.eval()s, and ``finalize()`` is called when prefill completes.
            def __init__(self):
                self.state = mx.array([0])

            def finalize(self):
                return None

        class FakeStats:
            # Attribute surface the installed code accumulates into.
            prompt_tokens = 0
            prompt_time = 0.0
            generation_time = 0.0

        class FakeBatchGenerator:
            # Mimics mlx-lm's BatchGenerator closely enough for
            # ``_install_chunked_prefill`` to patch and drive it.
            def __init__(self):
                self._stats = FakeStats()
                self._partial = None  # no prefill in flight yet
                self.active_batch = None  # no generation batch in flight
                # One queued prompt in mlx-lm's 7-field tuple layout:
                # (uid, tokens, max_tokens, caches, sampler,
                #  logits_processors, prompt_checkpoint).
                self.unprocessed_prompts = [
                    (
                        7,
                        [1, 2, 3, 4, 5],
                        16,
                        [FakeCacheEntry()],
                        None,
                        [None],
                        2,
                    )
                ]
                self.prefill_batch_size = 1
                self.completion_batch_size = 1
                self.max_kv_size = None
                self.stop_tokens = set()
                self.prompt_progress_callback = lambda _progress: None
                self.prompt_checkpoint_callback = None
                # ``_next`` is replaced by _install_chunked_prefill; the
                # remaining attributes are no-op stand-ins it may call.
                self._next = lambda: []
                self.remove = lambda _uids: None
                self._process_prompts = lambda _prompts: None
                self.model = lambda _inputs, cache=None: None

        # Patch the mlx_lm.generate helpers the installed code imports, so no
        # real model or cache machinery is needed.
        monkeypatch.setattr(
            mlx_generate,
            "_left_pad_prompts",
            lambda prompts, max_length=None: mx.array(prompts),
        )
        monkeypatch.setattr(
            mlx_generate,
            "_make_cache",
            lambda _model, _padding, _max_kv_size=None: [FakePromptCache()],
        )

        batch_gen = FakeBatchGenerator()
        _install_chunked_prefill(batch_gen, budget=4)

        responses = batch_gen._next()

        # First call must not yield tokens: prefill is still incomplete.
        assert responses == []
        assert batch_gen._partial is not None
        # checkpoint offset = len(prompt) - prompt_checkpoint = 5 - 2 = 3
        assert batch_gen._partial["prompt_checkpoint"] == 3
        # processed = min(budget, padded_len - checkpoint) = min(4, 5 - 3) = 2
        assert batch_gen._partial["processed"] == 2

    def test_chunked_prefill_invokes_checkpoint_callback(self, monkeypatch):
        """prompt_checkpoint_callback must fire after finalization.

        Drives the installed chunked prefill to completion over two
        ``_next()`` calls (budget=1 forces a suspend/resume cycle) and checks
        that exactly one ``(uid, checkpoint, cache)`` payload reaches the
        callback once the prompt cache has been finalized.
        """

        class FakeCacheEntry:
            # ``empty() -> True`` selects the uncached prefill path.
            def empty(self):
                return True

        class FakePromptCache:
            def __init__(self):
                self.state = mx.array([0])

            def finalize(self):
                return None

            def extract(self, idx):
                # Used by the lazy per-uid cache extraction that builds the
                # callback payload; returning self is sufficient here.
                return self

        class FakeStats:
            prompt_tokens = 0
            prompt_time = 0.0
            generation_time = 0.0
            generation_tokens = 0

        # Collects every entry the checkpoint callback receives.
        callback_payloads = []

        from collections import namedtuple
        _Response = namedtuple("Response", ["uid", "token", "logprobs", "finish_reason", "cache"])

        class FakeBatchGenerator:
            # Response type the generation step uses to report finished tokens.
            Response = _Response

            def __init__(self):
                self._stats = FakeStats()
                self._partial = None
                self.active_batch = None
                # Single prompt, 7-field mlx-lm tuple layout: (uid, tokens,
                # max_tokens, caches, sampler, logits_processors,
                # prompt_checkpoint).  len=3, checkpoint=2 -> offset 1.
                self.unprocessed_prompts = [
                    (
                        7,
                        [1, 2, 3],
                        16,
                        [FakeCacheEntry()],
                        None,
                        [None],
                        2,
                    )
                ]
                self.prefill_batch_size = 1
                self.completion_batch_size = 1
                self.max_kv_size = None
                self.stop_tokens = set()
                self.prompt_progress_callback = lambda _progress: None
                # Record every checkpoint entry handed to the callback.
                self.prompt_checkpoint_callback = (
                    lambda entries: callback_payloads.extend(entries)
                )
                self._next = lambda: []
                self.remove = lambda _uids: None
                self._process_prompts = lambda _prompts: None
                self.model = lambda _inputs, cache=None: None

            def _step(self, inputs, cache, samplers, logits_processors, tokens):
                # Fixed sampler output: token 99 (set as a stop token below),
                # so generation finishes immediately after prefill.
                return mx.array([99]), mx.array([-1.0])

            def _generation_step(self):
                # One-shot generation: clear the active batch and report done.
                if self.active_batch is not None:
                    self.active_batch = None
                return []

        monkeypatch.setattr(
            mlx_generate,
            "_left_pad_prompts",
            lambda prompts, max_length=None: mx.array(prompts),
        )
        monkeypatch.setattr(
            mlx_generate,
            "_make_cache",
            lambda _model, _padding, _max_kv_size=None: [FakePromptCache()],
        )

        batch_gen = FakeBatchGenerator()
        batch_gen.stop_tokens = {99}  # make _step's token 99 a terminator
        _install_chunked_prefill(batch_gen, budget=1)

        # First _next: starts partial prefill (processes 1 token)
        batch_gen._next()
        assert batch_gen._partial is not None

        # Second _next: finishes prefill, fires checkpoint callback,
        # then runs generation step which completes (stop token).
        batch_gen._next()

        assert len(callback_payloads) == 1
        uid, checkpoint, _cache_gen = callback_payloads[0]
        assert uid == 7
        # checkpoint offset = len(prompt) - prompt_checkpoint = 3 - 2 = 1
        assert checkpoint == 1

def test_scheduler_creation(self, mock_model, mock_tokenizer):
"""Test scheduler creation."""
scheduler = Scheduler(
Expand Down
51 changes: 43 additions & 8 deletions vllm_mlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def _install_chunked_prefill(

from mlx_lm.generate import (
Batch,
_lazy_extract_cache,
_left_pad_prompts,
_make_cache,
_merge_caches,
Expand Down Expand Up @@ -267,8 +268,13 @@ def _chunked_next(self=batch_gen): # noqa: C901
inputs = partial["inputs"]
prompt_cache = partial["cache"]
remaining = inputs.shape[1]
prompt_checkpoint = max(1, int(partial.get("prompt_checkpoint", 1)))

n_to_process = min(budget, remaining - 1) if remaining > 1 else 0
n_to_process = (
min(budget, remaining - prompt_checkpoint)
if remaining > prompt_checkpoint
else 0
)

if n_to_process > 0:
self.model(mx.contiguous(inputs[:, :n_to_process]), cache=prompt_cache)
Expand All @@ -293,17 +299,33 @@ def _chunked_next(self=batch_gen): # noqa: C901
if partial.get("is_cached"):
mx.clear_cache()

# Check if prefill is done (only 1 token left or 0)
if inputs.shape[1] <= 1:
# Check if prefill is done once only the checkpoint tail remains.
if inputs.shape[1] <= prompt_checkpoint:
# Finalize
if partial.get("is_cached"):
mx.eval([c.state for c in prompt_cache])
inputs = partial["last_inputs"]

for c in prompt_cache:
c.finalize()

if self.prompt_checkpoint_callback is not None:
self.prompt_checkpoint_callback(
[
(uid, prompt_checkpoint, _lazy_extract_cache(prompt_cache, i))
for i, uid in enumerate(partial["uids"])
]
)
mx.clear_cache()

if prompt_checkpoint > 1:
self.model(
mx.contiguous(inputs[:, : prompt_checkpoint - 1]),
cache=prompt_cache,
)
mx.eval([c.state for c in prompt_cache])
mx.clear_cache()

y, logprobs = self._step(
inputs,
prompt_cache,
Expand Down Expand Up @@ -392,25 +414,37 @@ def _chunked_next(self=batch_gen): # noqa: C901
caches,
samplers,
logits_processors,
prompt_checkpoints,
) = zip(*batch_prompts)
lengths = [len(p) for p in inputs_raw]
max_length = max(lengths)
padding = [max_length - ln for ln in lengths]
tokens = [mx.array(inp) for inp in inputs_raw]
checkpoint_offsets = [
(ln - pc if pc > 0 else -pc)
for ln, pc in zip(lengths, prompt_checkpoints)
]
prompt_checkpoint = max(1, max(checkpoint_offsets))
is_cached = not all(c[0].empty() for c in caches)

self._stats.prompt_tokens += sum(lengths)

if not is_cached:
padded = _left_pad_prompts(inputs_raw, max_length=max_length)
prompt_cache = _make_cache(self.model, padding)
prompt_cache = _make_cache(
self.model, padding, self.max_kv_size
)
else:
last_inputs = mx.array([p[-1:] for p in inputs_raw])
last_inputs = mx.array(
[p[-prompt_checkpoint:] for p in inputs_raw]
)
padded = _right_pad_prompts(inputs_raw, max_length=max_length)
prompt_cache = _merge_caches(caches)
for c in prompt_cache:
c.prepare(
lengths=[ln - 1 for ln in lengths],
lengths=[
ln - prompt_checkpoint for ln in lengths
],
right_padding=padding,
)

Expand All @@ -433,9 +467,9 @@ def _chunked_next(self=batch_gen): # noqa: C901
_pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0
_cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0
_adjusted_pb = _pb - _cached
if 0 < _adjusted_pb < padded.shape[1]:
if 0 < _adjusted_pb < padded.shape[1] - prompt_checkpoint + 1:
_first_chunk = _adjusted_pb
n_to_process = min(_first_chunk, padded.shape[1] - 1)
n_to_process = min(_first_chunk, padded.shape[1] - prompt_checkpoint)
if n_to_process > 0:
self.model(
mx.contiguous(padded[:, :n_to_process]),
Expand All @@ -454,6 +488,7 @@ def _chunked_next(self=batch_gen): # noqa: C901
"max_tokens": list(max_tokens_list),
"samplers": list(samplers),
"logits_processors": list(logits_processors),
"prompt_checkpoint": prompt_checkpoint,
"processed": n_to_process,
"total": max_length,
"is_cached": is_cached,
Expand Down