2020)
2121from parallax .server .sampling .sampler import SamplingBatchInfo
2222from parallax .server .shard_loader import MLXModelLoader
23+ from parallax .utils .mac_prefill_adder import AddReqResult , MACPrefillAdder
2324from parallax .utils .utils import (
2425 combine_padding_and_causal_masks ,
2526 create_causal_mask ,
@@ -91,6 +92,7 @@ def __init__(
9192 # Weight Refit
9293 enable_weight_refit : Optional [bool ] = False ,
9394 weight_refit_mode : Optional [str ] = "disk" ,
95+ chunked_prefill_size : Optional [int ] = None ,
9496 # Pipe communication
9597 conn : Optional [List [Any ]] = [],
9698 ):
@@ -187,6 +189,15 @@ def __init__(
187189 kv_block_size ,
188190 self .num_shard_layers ,
189191 )
192+ if chunked_prefill_size is not None :
193+ if chunked_prefill_size < 0 :
194+ raise ValueError ("chunked_prefill_size must be non-negative" )
195+ if chunked_prefill_size == 0 :
196+ chunked_prefill_size = None
197+ elif not enable_prefix_cache :
198+ enable_prefix_cache = True
199+ logger .info ("Prefix cache enabled automatically for MLX chunked prefill" )
200+
190201 self .cache_manager = CacheManager (
191202 num_layers = self .num_shard_layers ,
192203 num_kv_heads = num_key_value_heads // tp_size ,
@@ -208,6 +219,18 @@ def __init__(
208219 enable_prefix_cache = enable_prefix_cache ,
209220 sliding_window = sliding_window ,
210221 )
222+
223+ if chunked_prefill_size is not None :
224+ self .chunked_prefill_size = (
225+ (chunked_prefill_size + self .cache_manager .block_size - 1 )
226+ // self .cache_manager .block_size
227+ * self .cache_manager .block_size
228+ )
229+ else :
230+ self .chunked_prefill_size = None
231+ self .cache_manager .chunked_prefill_size = self .chunked_prefill_size
232+ self .cache_manager .defer_prefill_allocation = self .chunked_prefill_size is not None
233+
211234 super ().__init__ (
212235 start_layer = start_layer ,
213236 end_layer = end_layer ,
@@ -235,6 +258,7 @@ def __init__(
235258
236259 # Prefix Cache Manager
237260 self .enable_prefix_cache = enable_prefix_cache
261+ self .chunked_req = None
238262 # self.prefix_cache = RadixCache(
239263 # num_kv_heads=num_key_value_heads,
240264 # head_dim=head_dim,
@@ -273,6 +297,26 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj):
273297 data = pickle .loads (np .array (data_arr ).tobytes ())
274298 return data
275299
def _complete_local_middle_chunk(self, request_id: str):
    """Wrap up this peer's state once a non-final prefill chunk has run.

    Nothing happens unless ``request_id`` names the request currently held
    in ``self.chunked_req``. When it does, that request's ``is_chunked``
    flag is cleared and the cache manager is asked to release the request.
    The first peer then fetches the request through
    ``scheduler.get_running_request``, sets its status back to
    ``PREFILLING`` and enqueues it again; a warning is logged if the
    scheduler returns nothing. Every other peer evicts the request from
    its scheduler instead.

    Args:
        request_id: ID of the request whose chunk just finished locally.
    """
    tracked = self.chunked_req
    if tracked is None:
        return
    if tracked.rid != request_id:
        return

    tracked.is_chunked = False
    self.cache_manager.release_request(request_id)

    # Non-first peers only drop the request from their scheduler.
    if not self.is_first_peer:
        self.scheduler.evict_request(request_id)
        return

    running = self.scheduler.get_running_request(request_id)
    if running is not None:
        # Re-queue the request with PREFILLING status.
        running.status = RequestStatus.PREFILLING
        self.scheduler.enque_request(running)
        return

    logger.warning(
        f"Completed local chunk for { request_id } , but no running request was found."
    )
319+
276320 def handle_input_requests (self , requests : List [Request ]):
277321 """Update requests states and status in scheduler and cache manager."""
278322 if self .tp_size > 1 :
@@ -341,11 +385,6 @@ def handle_input_requests(self, requests: List[Request]):
341385 req , IntermediateRequest
342386 ), "Non-first peers must receive IntermediateRequests."
343387 if req .is_finished or req .hidden_states is None :
344- if self .enable_prefix_cache :
345- keys , values = self .cache_manager .gather_kv_cache (req .request_id )
346- self .prefix_cache .cache_finished_request (req , keys , values )
347- self .prefix_cache .evict_request (req .request_id )
348-
349388 self .cache_manager .release_request (req .request_id )
350389 logger .debug (
351390 f"Released resources for finished request { req .request_id } , "
@@ -358,6 +397,29 @@ def handle_input_requests(self, requests: List[Request]):
358397 # This is an active request, add it to the scheduler queue to be processed.
359398 self .scheduler .enque_request (req )
360399
def prepare_next_batch_requests(
    self, requests: List[Request], batch_output: Any, context_lengths: Any
) -> List[Request]:
    """Build the next-batch request list, accounting for a chunked prefill.

    The parent implementation produces the base list. That list is returned
    unchanged unless this executor holds a ``chunked_req`` whose
    ``is_chunked`` flag is set and whose ``rid`` matches one of ``requests``.
    In that case the matching entry in the base list is marked
    ``PREFILLING``, it is left out of the result when this is the last
    peer, and ``_complete_local_middle_chunk`` is called for that request
    before returning.

    Args:
        requests: Requests that made up the batch just processed.
        batch_output: Passed through to the parent implementation.
        context_lengths: Passed through to the parent implementation.

    Returns:
        The parent's list, or a filtered copy of it as described above.
    """
    outgoing = super().prepare_next_batch_requests(requests, batch_output, context_lengths)

    pending = self.chunked_req
    if pending is None or not pending.is_chunked:
        return outgoing

    batch_ids = [entry.request_id for entry in requests]
    if pending.rid not in batch_ids:
        return outgoing

    target_rid = pending.rid
    survivors = []
    for item in outgoing:
        if item.request_id == target_rid:
            item.status = RequestStatus.PREFILLING
            # The last peer does not pass the chunked request on.
            forward = not self.is_last_peer
        else:
            forward = True
        if forward:
            survivors.append(item)

    self._complete_local_middle_chunk(target_rid)
    return survivors
422+
361423 def process_batch (self , prepared_inputs : Dict [str , Any ], return_decoded_tokens : bool = True ):
362424 """Process a batch of requests in MLX."""
363425 # Run model and get updated cache
@@ -380,6 +442,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
380442 prefix_lens = prepared_inputs .get ("prefix_lens" ), # For RoPE offset in prefix cache
381443 )
382444 mx .eval (hidden_states )
445+ if is_prefill_batch and self .chunked_prefill_size is not None :
446+ self .cache_manager .materialize_linear_caches ()
383447
384448 if logger .isEnabledFor (logging .DEBUG ):
385449 forward_time = (time .time () - start_time ) * 1000
@@ -507,6 +571,46 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
507571 if batch_size == 0 :
508572 return None
509573
574+ if self .chunked_prefill_size is not None :
575+ original_batched_requests = batched_requests
576+ adder = MACPrefillAdder (
577+ self .cache_manager .block_size , self .chunked_prefill_size , self .cache_manager
578+ )
579+ chunked_rid = self .chunked_req .rid if self .chunked_req is not None else None
580+
581+ if self .chunked_req is not None and chunked_rid in [
582+ req .request_id for req in original_batched_requests
583+ ]:
584+ self .chunked_req = [
585+ req for req in original_batched_requests if req .request_id == chunked_rid
586+ ][0 ]
587+ self .chunked_req = adder .add_chunked_req (self .chunked_req )
588+
589+ for old_req in original_batched_requests :
590+ if chunked_rid is not None and old_req .request_id == chunked_rid :
591+ continue
592+ res = adder .add_one_req (old_req )
593+ if res != AddReqResult .CONTINUE :
594+ break
595+
596+ if adder .new_chunked_req is not None :
597+ self .chunked_req = adder .new_chunked_req
598+
599+ if self .chunked_req is not None and self .chunked_req .rid in [
600+ req .request_id for req in original_batched_requests
601+ ]:
602+ self .chunked_req .is_chunked = True
603+
604+ can_run_by_id = {req .request_id : req for req in adder .can_run_list }
605+ batched_requests = [
606+ can_run_by_id [req .request_id ]
607+ for req in original_batched_requests
608+ if req .request_id in can_run_by_id
609+ ]
610+ batch_size = len (batched_requests )
611+ if batch_size == 0 :
612+ return None
613+
510614 h_or_tokens_list = []
511615 block_tables_list = []
512616 context_lengths_list = []
@@ -555,14 +659,14 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
555659 actual_processed_lengths_list .append (len (req .input_ids ))
556660 else :
557661 if matched_tokens > 0 and self .enable_prefix_cache :
558- # Skip the prefix hidden states that correspond to cached tokens
559- new_hidden = req .hidden_states [matched_tokens :]
560- if new_hidden .shape [0 ] == 0 :
662+ keep_len = req .total_length - matched_tokens
663+ if keep_len <= 0 :
561664 # All tokens cached - keep the last hidden state
562665 new_hidden = req .hidden_states [- 1 :]
563666 prefix_lens_list [- 1 ] = matched_tokens - 1
564667 actual_processed_lengths_list .append (1 )
565668 else :
669+ new_hidden = req .hidden_states [- keep_len :]
566670 actual_processed_lengths_list .append (new_hidden .shape [0 ])
567671 h_or_tokens_list .append (new_hidden )
568672 else :
0 commit comments