Skip to content

Commit 3888eed

Browse files
authored
feat: add mlx chunked prefill support (#469)
1 parent 4e1f125 commit 3888eed

16 files changed

Lines changed: 573 additions & 22 deletions

src/parallax/server/cache/linear_cache.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Tuple
1+
from typing import List, Optional, Tuple
22

33
import mlx.core as mx
44

@@ -52,6 +52,14 @@ def __init__(
5252
def get_cache(self) -> Tuple[Optional[mx.array], Optional[mx.array]]:
5353
return self.conv_state_cache, self.linear_state_cache
5454

55+
def get_state_cache_arrays(self) -> List[mx.array]:
56+
arrays = []
57+
if self.conv_state_cache is not None:
58+
arrays.append(self.conv_state_cache)
59+
if self.linear_state_cache is not None:
60+
arrays.append(self.linear_state_cache)
61+
return arrays
62+
5563
def get_indexer_cache(self) -> Optional[mx.array]:
5664
return None
5765

src/parallax/server/cache_manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,14 @@ def get_caches(self) -> List[BaseCache]:
446446
"""Returns the list of layer caches."""
447447
return self.caches
448448

449+
def materialize_linear_caches(self):
450+
arrays = []
451+
for cache in self.caches:
452+
if isinstance(cache, LinearCache):
453+
arrays.extend(cache.get_state_cache_arrays())
454+
if arrays:
455+
mx.eval(*arrays)
456+
449457
def match_and_reuse_prefix(self, request_id: str, token_ids: List[int]) -> int:
450458
"""
451459
Match prefix before prefill and reuse existing blocks.

src/parallax/server/executor/base_executor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -741,8 +741,11 @@ def run_loop(self):
741741
# 8. Dispatch to the appropriate destination
742742
if self.is_last_peer and self.is_first_peer:
743743
# Single node: handle locally
744-
self.handle_input_requests(next_batch)
744+
if next_batch:
745+
self.handle_input_requests(next_batch)
745746
elif self.tp_rank == 0:
747+
if not next_batch:
748+
continue
746749
# Send output to next peer
747750
self.send_to_peer_socket.send_multipart(
748751
[
@@ -833,7 +836,7 @@ def _prepare_next_single_request(
833836
request_id=request.request_id,
834837
status=RequestStatus.DECODING,
835838
current_position=request.total_length + 1,
836-
input_ids=request.input_ids,
839+
input_ids=request.origin_input_ids,
837840
hidden_states=hidden_states,
838841
next_token_id=next_token_id,
839842
routing_table=request.routing_table,
@@ -852,7 +855,7 @@ def _prepare_next_single_request(
852855
request_id=request.request_id,
853856
status=RequestStatus.DECODING, # Last peer always changes status to DECODING
854857
current_position=request.total_length,
855-
input_ids=request.input_ids,
858+
input_ids=request.origin_input_ids,
856859
hidden_states=hidden_states,
857860
next_token_id=next_token_id,
858861
routing_table=request.routing_table,

src/parallax/server/executor/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None, conn=Non
5050
"max_loaded_loras": args.max_loaded_loras,
5151
"enable_weight_refit": args.enable_weight_refit,
5252
"weight_refit_mode": args.weight_refit_mode,
53+
"chunked_prefill_size": getattr(args, "chunked_prefill_size", None),
5354
}
5455

5556
if args.gpu_backend == "sglang":

src/parallax/server/executor/mlx_executor.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from parallax.server.sampling.sampler import SamplingBatchInfo
2222
from parallax.server.shard_loader import MLXModelLoader
23+
from parallax.utils.mac_prefill_adder import AddReqResult, MACPrefillAdder
2324
from parallax.utils.utils import (
2425
combine_padding_and_causal_masks,
2526
create_causal_mask,
@@ -91,6 +92,7 @@ def __init__(
9192
# Weight Refit
9293
enable_weight_refit: Optional[bool] = False,
9394
weight_refit_mode: Optional[str] = "disk",
95+
chunked_prefill_size: Optional[int] = None,
9496
# Pipe communication
9597
conn: Optional[List[Any]] = [],
9698
):
@@ -187,6 +189,15 @@ def __init__(
187189
kv_block_size,
188190
self.num_shard_layers,
189191
)
192+
if chunked_prefill_size is not None:
193+
if chunked_prefill_size < 0:
194+
raise ValueError("chunked_prefill_size must be non-negative")
195+
if chunked_prefill_size == 0:
196+
chunked_prefill_size = None
197+
elif not enable_prefix_cache:
198+
enable_prefix_cache = True
199+
logger.info("Prefix cache enabled automatically for MLX chunked prefill")
200+
190201
self.cache_manager = CacheManager(
191202
num_layers=self.num_shard_layers,
192203
num_kv_heads=num_key_value_heads // tp_size,
@@ -208,6 +219,18 @@ def __init__(
208219
enable_prefix_cache=enable_prefix_cache,
209220
sliding_window=sliding_window,
210221
)
222+
223+
if chunked_prefill_size is not None:
224+
self.chunked_prefill_size = (
225+
(chunked_prefill_size + self.cache_manager.block_size - 1)
226+
// self.cache_manager.block_size
227+
* self.cache_manager.block_size
228+
)
229+
else:
230+
self.chunked_prefill_size = None
231+
self.cache_manager.chunked_prefill_size = self.chunked_prefill_size
232+
self.cache_manager.defer_prefill_allocation = self.chunked_prefill_size is not None
233+
211234
super().__init__(
212235
start_layer=start_layer,
213236
end_layer=end_layer,
@@ -235,6 +258,7 @@ def __init__(
235258

236259
# Prefix Cache Manager
237260
self.enable_prefix_cache = enable_prefix_cache
261+
self.chunked_req = None
238262
# self.prefix_cache = RadixCache(
239263
# num_kv_heads=num_key_value_heads,
240264
# head_dim=head_dim,
@@ -273,6 +297,26 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj):
273297
data = pickle.loads(np.array(data_arr).tobytes())
274298
return data
275299

300+
def _complete_local_middle_chunk(self, request_id: str):
301+
"""Finish local bookkeeping after this peer runs a non-final prefill chunk."""
302+
if self.chunked_req is None or self.chunked_req.rid != request_id:
303+
return
304+
305+
self.chunked_req.is_chunked = False
306+
self.cache_manager.release_request(request_id)
307+
308+
if self.is_first_peer:
309+
original_req = self.scheduler.get_running_request(request_id)
310+
if original_req is None:
311+
logger.warning(
312+
f"Completed local chunk for {request_id}, but no running request was found."
313+
)
314+
return
315+
original_req.status = RequestStatus.PREFILLING
316+
self.scheduler.enque_request(original_req)
317+
else:
318+
self.scheduler.evict_request(request_id)
319+
276320
def handle_input_requests(self, requests: List[Request]):
277321
"""Update requests states and status in scheduler and cache manager."""
278322
if self.tp_size > 1:
@@ -341,11 +385,6 @@ def handle_input_requests(self, requests: List[Request]):
341385
req, IntermediateRequest
342386
), "Non-first peers must receive IntermediateRequests."
343387
if req.is_finished or req.hidden_states is None:
344-
if self.enable_prefix_cache:
345-
keys, values = self.cache_manager.gather_kv_cache(req.request_id)
346-
self.prefix_cache.cache_finished_request(req, keys, values)
347-
self.prefix_cache.evict_request(req.request_id)
348-
349388
self.cache_manager.release_request(req.request_id)
350389
logger.debug(
351390
f"Released resources for finished request {req.request_id}, "
@@ -358,6 +397,29 @@ def handle_input_requests(self, requests: List[Request]):
358397
# This is an active request, add it to the scheduler queue to be processed.
359398
self.scheduler.enque_request(req)
360399

400+
def prepare_next_batch_requests(
401+
self, requests: List[Request], batch_output: Any, context_lengths: Any
402+
) -> List[Request]:
403+
next_batch = super().prepare_next_batch_requests(requests, batch_output, context_lengths)
404+
if (
405+
self.chunked_req is None
406+
or not self.chunked_req.is_chunked
407+
or self.chunked_req.rid not in [req.request_id for req in requests]
408+
):
409+
return next_batch
410+
411+
chunked_rid = self.chunked_req.rid
412+
filtered_next_batch = []
413+
for req in next_batch:
414+
if req.request_id == chunked_rid:
415+
req.status = RequestStatus.PREFILLING
416+
if self.is_last_peer:
417+
continue
418+
filtered_next_batch.append(req)
419+
420+
self._complete_local_middle_chunk(chunked_rid)
421+
return filtered_next_batch
422+
361423
def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
362424
"""Process a batch of requests in MLX."""
363425
# Run model and get updated cache
@@ -380,6 +442,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
380442
prefix_lens=prepared_inputs.get("prefix_lens"), # For RoPE offset in prefix cache
381443
)
382444
mx.eval(hidden_states)
445+
if is_prefill_batch and self.chunked_prefill_size is not None:
446+
self.cache_manager.materialize_linear_caches()
383447

384448
if logger.isEnabledFor(logging.DEBUG):
385449
forward_time = (time.time() - start_time) * 1000
@@ -507,6 +571,46 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
507571
if batch_size == 0:
508572
return None
509573

574+
if self.chunked_prefill_size is not None:
575+
original_batched_requests = batched_requests
576+
adder = MACPrefillAdder(
577+
self.cache_manager.block_size, self.chunked_prefill_size, self.cache_manager
578+
)
579+
chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None
580+
581+
if self.chunked_req is not None and chunked_rid in [
582+
req.request_id for req in original_batched_requests
583+
]:
584+
self.chunked_req = [
585+
req for req in original_batched_requests if req.request_id == chunked_rid
586+
][0]
587+
self.chunked_req = adder.add_chunked_req(self.chunked_req)
588+
589+
for old_req in original_batched_requests:
590+
if chunked_rid is not None and old_req.request_id == chunked_rid:
591+
continue
592+
res = adder.add_one_req(old_req)
593+
if res != AddReqResult.CONTINUE:
594+
break
595+
596+
if adder.new_chunked_req is not None:
597+
self.chunked_req = adder.new_chunked_req
598+
599+
if self.chunked_req is not None and self.chunked_req.rid in [
600+
req.request_id for req in original_batched_requests
601+
]:
602+
self.chunked_req.is_chunked = True
603+
604+
can_run_by_id = {req.request_id: req for req in adder.can_run_list}
605+
batched_requests = [
606+
can_run_by_id[req.request_id]
607+
for req in original_batched_requests
608+
if req.request_id in can_run_by_id
609+
]
610+
batch_size = len(batched_requests)
611+
if batch_size == 0:
612+
return None
613+
510614
h_or_tokens_list = []
511615
block_tables_list = []
512616
context_lengths_list = []
@@ -555,14 +659,14 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
555659
actual_processed_lengths_list.append(len(req.input_ids))
556660
else:
557661
if matched_tokens > 0 and self.enable_prefix_cache:
558-
# Skip the prefix hidden states that correspond to cached tokens
559-
new_hidden = req.hidden_states[matched_tokens:]
560-
if new_hidden.shape[0] == 0:
662+
keep_len = req.total_length - matched_tokens
663+
if keep_len <= 0:
561664
# All tokens cached - keep the last hidden state
562665
new_hidden = req.hidden_states[-1:]
563666
prefix_lens_list[-1] = matched_tokens - 1
564667
actual_processed_lengths_list.append(1)
565668
else:
669+
new_hidden = req.hidden_states[-keep_len:]
566670
actual_processed_lengths_list.append(new_hidden.shape[0])
567671
h_or_tokens_list.append(new_hidden)
568672
else:

src/parallax/server/executor/sglang_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
# Weight Refit
9090
enable_weight_refit: Optional[bool] = False,
9191
weight_refit_mode: Optional[str] = "disk",
92+
chunked_prefill_size: Optional[int] = None,
9293
# Pipe communication
9394
conn: Optional[List[Any]] = [],
9495
):

src/parallax/server/executor/vllm_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
# Weight Refit
8989
enable_weight_refit: Optional[bool] = False,
9090
weight_refit_mode: Optional[str] = "disk",
91+
chunked_prefill_size: Optional[int] = None,
9192
# Routed experts
9293
enable_return_routed_experts: bool = False,
9394
# Pipe communication

src/parallax/server/request.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,16 @@ def __init__(
109109
self.last_updated_time: Optional[float] = None
110110
self.lora_id: Optional[str] = None
111111
self.lora_path = lora_path
112+
self.rid = self.request_id
113+
self.is_chunked = False
114+
# Full prompt tokens as received from the client. Chunked prefill may
115+
# temporarily shorten input_ids to the prefix visible to the current
116+
# chunk, so keep this immutable source for later chunks and decode.
117+
self.origin_input_ids = list(input_ids) if input_ids is not None else None
118+
# Optional logical sequence length override for chunked prefill. When
119+
# set, total_length reports the end offset of the current prompt chunk
120+
# instead of the final prompt+output length.
121+
self._effective_total_length: Optional[int] = None
112122

113123
@property
114124
def is_finished(self) -> bool:
@@ -201,7 +211,9 @@ def output_length(self) -> int:
201211

202212
@property
203213
def total_length(self) -> int:
204-
"""Total length of the sequence (input + output)."""
214+
"""Logical sequence length for scheduling/cache allocation."""
215+
if self._effective_total_length is not None:
216+
return self._effective_total_length
205217
return self.prompt_len + self.output_length
206218

207219
def get_model_input_for_first_peer(self) -> List[int]:
@@ -226,6 +238,10 @@ def commit_new_token(self, token_id: int):
226238
)
227239
return
228240

241+
self._effective_total_length = None
242+
if self.origin_input_ids is not None:
243+
self.input_ids = self.origin_input_ids
244+
229245
self.output_ids.append(token_id)
230246

231247
# Finishing condition checks are now handled by the Scheduler.
@@ -302,7 +318,9 @@ def input_length(self) -> int:
302318

303319
@property
304320
def total_length(self) -> int:
305-
"""Total length of the sequence (input + output)."""
321+
"""Logical sequence length for scheduling/cache allocation."""
322+
if self._effective_total_length is not None:
323+
return self._effective_total_length
306324
return self.current_position
307325

308326
@classmethod
@@ -337,7 +355,7 @@ def from_initial_request(
337355
return IntermediateRequest(
338356
request_id=initial_request.request_id,
339357
status=initial_request.status,
340-
input_ids=initial_request.input_ids,
358+
input_ids=initial_request.origin_input_ids,
341359
next_token_id=next_token_id,
342360
current_position=initial_request.total_length,
343361
hidden_states=hidden_states,
@@ -364,7 +382,7 @@ def from_intermediate_request(
364382
request_id=old_request.request_id,
365383
status=old_request.status,
366384
current_position=old_request.total_length,
367-
input_ids=old_request.input_ids,
385+
input_ids=old_request.origin_input_ids,
368386
next_token_id=old_request.next_token_id,
369387
hidden_states=new_hidden_states,
370388
routing_table=old_request.routing_table,

0 commit comments

Comments
 (0)