|
20 | 20 | ) |
21 | 21 | from parallax.server.sampling.sampler import SamplingBatchInfo |
22 | 22 | from parallax.server.shard_loader import MLXModelLoader |
| 23 | +from parallax.utils.chunked_prefill import ( |
| 24 | + complete_local_middle_chunk, |
| 25 | + filter_middle_chunk_next_batch, |
| 26 | +) |
23 | 27 | from parallax.utils.mac_prefill_adder import AddReqResult, MACPrefillAdder |
24 | 28 | from parallax.utils.utils import ( |
25 | 29 | combine_padding_and_causal_masks, |
@@ -253,6 +257,7 @@ def __init__( |
253 | 257 | shared_state=shared_state, |
254 | 258 | enable_weight_refit=enable_weight_refit, |
255 | 259 | weight_refit_mode=weight_refit_mode, |
| 260 | + chunked_prefill_size=self.chunked_prefill_size, |
256 | 261 | conn=conn, |
257 | 262 | ) |
258 | 263 |
|
@@ -299,23 +304,11 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj): |
299 | 304 |
|
300 | 305 | def _complete_local_middle_chunk(self, request_id: str): |
301 | 306 | """Finish local bookkeeping after this peer runs a non-final prefill chunk.""" |
302 | | - if self.chunked_req is None or self.chunked_req.rid != request_id: |
303 | | - return |
304 | | - |
305 | | - self.chunked_req.is_chunked = False |
306 | | - self.cache_manager.release_request(request_id) |
307 | | - |
308 | | - if self.is_first_peer: |
309 | | - original_req = self.scheduler.get_running_request(request_id) |
310 | | - if original_req is None: |
311 | | - logger.warning( |
312 | | - f"Completed local chunk for {request_id}, but no running request was found." |
313 | | - ) |
314 | | - return |
315 | | - original_req.status = RequestStatus.PREFILLING |
316 | | - self.scheduler.enque_request(original_req) |
317 | | - else: |
318 | | - self.scheduler.evict_request(request_id) |
| 307 | + complete_local_middle_chunk( |
| 308 | + self, |
| 309 | + request_id, |
| 310 | + release_callback=self.cache_manager.release_request, |
| 311 | + ) |
319 | 312 |
|
320 | 313 | def handle_input_requests(self, requests: List[Request]): |
321 | 314 | """Update requests states and status in scheduler and cache manager.""" |
@@ -401,24 +394,12 @@ def prepare_next_batch_requests( |
401 | 394 | self, requests: List[Request], batch_output: Any, context_lengths: Any |
402 | 395 | ) -> List[Request]: |
403 | 396 | next_batch = super().prepare_next_batch_requests(requests, batch_output, context_lengths) |
404 | | - if ( |
405 | | - self.chunked_req is None |
406 | | - or not self.chunked_req.is_chunked |
407 | | - or self.chunked_req.rid not in [req.request_id for req in requests] |
408 | | - ): |
409 | | - return next_batch |
410 | | - |
411 | | - chunked_rid = self.chunked_req.rid |
412 | | - filtered_next_batch = [] |
413 | | - for req in next_batch: |
414 | | - if req.request_id == chunked_rid: |
415 | | - req.status = RequestStatus.PREFILLING |
416 | | - if self.is_last_peer: |
417 | | - continue |
418 | | - filtered_next_batch.append(req) |
419 | | - |
420 | | - self._complete_local_middle_chunk(chunked_rid) |
421 | | - return filtered_next_batch |
| 397 | + return filter_middle_chunk_next_batch( |
| 398 | + self, |
| 399 | + requests, |
| 400 | + next_batch, |
| 401 | + release_callback=self.cache_manager.release_request, |
| 402 | + ) |
422 | 403 |
|
423 | 404 | def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True): |
424 | 405 | """Process a batch of requests in MLX.""" |
|
0 commit comments