Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/parallax/server/executor/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
# Weight Refit
enable_weight_refit: Optional[bool] = False,
weight_refit_mode: Optional[str] = "disk",
chunked_prefill_size: Optional[int] = None,
# Pipe communication
conn: Optional[List[Any]] = [],
):
Expand All @@ -108,6 +109,7 @@ def __init__(
self.finished_batch = []
self.start_layer = start_layer
self.end_layer = end_layer
self.max_num_tokens_per_batch = max_num_tokens_per_batch
self._should_stop = False # Flag to gracefully stop the executor
# Reference to shared state for layer reallocation detection (when in subprocess mode)
if shared_state is not None:
Expand Down Expand Up @@ -181,6 +183,7 @@ def __init__(
cache_manager=self.cache_manager if self.device == "mlx" else None,
request_timeout_s=request_timeout_s,
shared_state=self.shared_state,
chunked_prefill_size=chunked_prefill_size,
)
logger.debug(
f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})"
Expand Down
51 changes: 16 additions & 35 deletions src/parallax/server/executor/mlx_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
)
from parallax.server.sampling.sampler import SamplingBatchInfo
from parallax.server.shard_loader import MLXModelLoader
from parallax.utils.chunked_prefill import (
complete_local_middle_chunk,
filter_middle_chunk_next_batch,
)
from parallax.utils.mac_prefill_adder import AddReqResult, MACPrefillAdder
from parallax.utils.utils import (
combine_padding_and_causal_masks,
Expand Down Expand Up @@ -253,6 +257,7 @@ def __init__(
shared_state=shared_state,
enable_weight_refit=enable_weight_refit,
weight_refit_mode=weight_refit_mode,
chunked_prefill_size=self.chunked_prefill_size,
conn=conn,
)

Expand Down Expand Up @@ -299,23 +304,11 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj):

def _complete_local_middle_chunk(self, request_id: str):
"""Finish local bookkeeping after this peer runs a non-final prefill chunk."""
if self.chunked_req is None or self.chunked_req.rid != request_id:
return

self.chunked_req.is_chunked = False
self.cache_manager.release_request(request_id)

if self.is_first_peer:
original_req = self.scheduler.get_running_request(request_id)
if original_req is None:
logger.warning(
f"Completed local chunk for {request_id}, but no running request was found."
)
return
original_req.status = RequestStatus.PREFILLING
self.scheduler.enque_request(original_req)
else:
self.scheduler.evict_request(request_id)
complete_local_middle_chunk(
self,
request_id,
release_callback=self.cache_manager.release_request,
)

def handle_input_requests(self, requests: List[Request]):
"""Update requests states and status in scheduler and cache manager."""
Expand Down Expand Up @@ -401,24 +394,12 @@ def prepare_next_batch_requests(
self, requests: List[Request], batch_output: Any, context_lengths: Any
) -> List[Request]:
next_batch = super().prepare_next_batch_requests(requests, batch_output, context_lengths)
if (
self.chunked_req is None
or not self.chunked_req.is_chunked
or self.chunked_req.rid not in [req.request_id for req in requests]
):
return next_batch

chunked_rid = self.chunked_req.rid
filtered_next_batch = []
for req in next_batch:
if req.request_id == chunked_rid:
req.status = RequestStatus.PREFILLING
if self.is_last_peer:
continue
filtered_next_batch.append(req)

self._complete_local_middle_chunk(chunked_rid)
return filtered_next_batch
return filter_middle_chunk_next_batch(
self,
requests,
next_batch,
release_callback=self.cache_manager.release_request,
)

def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
"""Process a batch of requests in MLX."""
Expand Down
Loading
Loading