Skip to content

Commit ce6ef50

Browse files
authored
feat: add sglang chunked prefill support (#470)
1 parent 3888eed commit ce6ef50

11 files changed

Lines changed: 576 additions & 113 deletions

src/parallax/server/executor/base_executor.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,7 @@ def __init__(
9494
# Weight Refit
9595
enable_weight_refit: Optional[bool] = False,
9696
weight_refit_mode: Optional[str] = "disk",
97+
chunked_prefill_size: Optional[int] = None,
9798
# Pipe communication
9899
conn: Optional[List[Any]] = [],
99100
):
@@ -108,6 +109,7 @@ def __init__(
108109
self.finished_batch = []
109110
self.start_layer = start_layer
110111
self.end_layer = end_layer
112+
self.max_num_tokens_per_batch = max_num_tokens_per_batch
111113
self._should_stop = False # Flag to gracefully stop the executor
112114
# Reference to shared state for layer reallocation detection (when in subprocess mode)
113115
if shared_state is not None:
@@ -181,6 +183,7 @@ def __init__(
181183
cache_manager=self.cache_manager if self.device == "mlx" else None,
182184
request_timeout_s=request_timeout_s,
183185
shared_state=self.shared_state,
186+
chunked_prefill_size=chunked_prefill_size,
184187
)
185188
logger.debug(
186189
f"Scheduler initialized (max_batch_size={max_batch_size}, max_tokens={max_num_tokens_per_batch}, wait_ms={scheduler_wait_ms})"

src/parallax/server/executor/mlx_executor.py

Lines changed: 16 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,10 @@
2020
)
2121
from parallax.server.sampling.sampler import SamplingBatchInfo
2222
from parallax.server.shard_loader import MLXModelLoader
23+
from parallax.utils.chunked_prefill import (
24+
complete_local_middle_chunk,
25+
filter_middle_chunk_next_batch,
26+
)
2327
from parallax.utils.mac_prefill_adder import AddReqResult, MACPrefillAdder
2428
from parallax.utils.utils import (
2529
combine_padding_and_causal_masks,
@@ -253,6 +257,7 @@ def __init__(
253257
shared_state=shared_state,
254258
enable_weight_refit=enable_weight_refit,
255259
weight_refit_mode=weight_refit_mode,
260+
chunked_prefill_size=self.chunked_prefill_size,
256261
conn=conn,
257262
)
258263

@@ -299,23 +304,11 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj):
299304

300305
def _complete_local_middle_chunk(self, request_id: str):
301306
"""Finish local bookkeeping after this peer runs a non-final prefill chunk."""
302-
if self.chunked_req is None or self.chunked_req.rid != request_id:
303-
return
304-
305-
self.chunked_req.is_chunked = False
306-
self.cache_manager.release_request(request_id)
307-
308-
if self.is_first_peer:
309-
original_req = self.scheduler.get_running_request(request_id)
310-
if original_req is None:
311-
logger.warning(
312-
f"Completed local chunk for {request_id}, but no running request was found."
313-
)
314-
return
315-
original_req.status = RequestStatus.PREFILLING
316-
self.scheduler.enque_request(original_req)
317-
else:
318-
self.scheduler.evict_request(request_id)
307+
complete_local_middle_chunk(
308+
self,
309+
request_id,
310+
release_callback=self.cache_manager.release_request,
311+
)
319312

320313
def handle_input_requests(self, requests: List[Request]):
321314
"""Update requests states and status in scheduler and cache manager."""
@@ -401,24 +394,12 @@ def prepare_next_batch_requests(
401394
self, requests: List[Request], batch_output: Any, context_lengths: Any
402395
) -> List[Request]:
403396
next_batch = super().prepare_next_batch_requests(requests, batch_output, context_lengths)
404-
if (
405-
self.chunked_req is None
406-
or not self.chunked_req.is_chunked
407-
or self.chunked_req.rid not in [req.request_id for req in requests]
408-
):
409-
return next_batch
410-
411-
chunked_rid = self.chunked_req.rid
412-
filtered_next_batch = []
413-
for req in next_batch:
414-
if req.request_id == chunked_rid:
415-
req.status = RequestStatus.PREFILLING
416-
if self.is_last_peer:
417-
continue
418-
filtered_next_batch.append(req)
419-
420-
self._complete_local_middle_chunk(chunked_rid)
421-
return filtered_next_batch
397+
return filter_middle_chunk_next_batch(
398+
self,
399+
requests,
400+
next_batch,
401+
release_callback=self.cache_manager.release_request,
402+
)
422403

423404
def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
424405
"""Process a batch of requests in MLX."""

0 commit comments

Comments
 (0)