Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/parallax/server/cache/linear_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import mlx.core as mx

Expand Down Expand Up @@ -52,6 +52,14 @@ def __init__(
def get_cache(self) -> Tuple[Optional[mx.array], Optional[mx.array]]:
return self.conv_state_cache, self.linear_state_cache

def get_state_cache_arrays(self) -> List[mx.array]:
arrays = []
if self.conv_state_cache is not None:
arrays.append(self.conv_state_cache)
if self.linear_state_cache is not None:
arrays.append(self.linear_state_cache)
return arrays

def get_indexer_cache(self) -> Optional[mx.array]:
return None

Expand Down
8 changes: 8 additions & 0 deletions src/parallax/server/cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,14 @@ def get_caches(self) -> List[BaseCache]:
"""Returns the list of layer caches."""
return self.caches

def materialize_linear_caches(self):
arrays = []
for cache in self.caches:
if isinstance(cache, LinearCache):
arrays.extend(cache.get_state_cache_arrays())
if arrays:
mx.eval(*arrays)

def match_and_reuse_prefix(self, request_id: str, token_ids: List[int]) -> int:
"""
Match prefix before prefill and reuse existing blocks.
Expand Down
9 changes: 6 additions & 3 deletions src/parallax/server/executor/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,11 @@ def run_loop(self):
# 8. Dispatch to the appropriate destination
if self.is_last_peer and self.is_first_peer:
# Single node: handle locally
self.handle_input_requests(next_batch)
if next_batch:
self.handle_input_requests(next_batch)
elif self.tp_rank == 0:
if not next_batch:
continue
# Send output to next peer
self.send_to_peer_socket.send_multipart(
[
Expand Down Expand Up @@ -833,7 +836,7 @@ def _prepare_next_single_request(
request_id=request.request_id,
status=RequestStatus.DECODING,
current_position=request.total_length + 1,
input_ids=request.input_ids,
input_ids=request.origin_input_ids,
hidden_states=hidden_states,
next_token_id=next_token_id,
routing_table=request.routing_table,
Expand All @@ -852,7 +855,7 @@ def _prepare_next_single_request(
request_id=request.request_id,
status=RequestStatus.DECODING, # Last peer always changes status to DECODING
current_position=request.total_length,
input_ids=request.input_ids,
input_ids=request.origin_input_ids,
hidden_states=hidden_states,
next_token_id=next_token_id,
routing_table=request.routing_table,
Expand Down
1 change: 1 addition & 0 deletions src/parallax/server/executor/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def create_executor_config(args: argparse.Namespace, shared_state=None, conn=Non
"max_loaded_loras": args.max_loaded_loras,
"enable_weight_refit": args.enable_weight_refit,
"weight_refit_mode": args.weight_refit_mode,
"chunked_prefill_size": getattr(args, "chunked_prefill_size", None),
}

if args.gpu_backend == "sglang":
Expand Down
120 changes: 112 additions & 8 deletions src/parallax/server/executor/mlx_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from parallax.server.sampling.sampler import SamplingBatchInfo
from parallax.server.shard_loader import MLXModelLoader
from parallax.utils.mac_prefill_adder import AddReqResult, MACPrefillAdder
from parallax.utils.utils import (
combine_padding_and_causal_masks,
create_causal_mask,
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
# Weight Refit
enable_weight_refit: Optional[bool] = False,
weight_refit_mode: Optional[str] = "disk",
chunked_prefill_size: Optional[int] = None,
# Pipe communication
conn: Optional[List[Any]] = [],
):
Expand Down Expand Up @@ -187,6 +189,15 @@ def __init__(
kv_block_size,
self.num_shard_layers,
)
if chunked_prefill_size is not None:
if chunked_prefill_size < 0:
raise ValueError("chunked_prefill_size must be non-negative")
if chunked_prefill_size == 0:
chunked_prefill_size = None
elif not enable_prefix_cache:
enable_prefix_cache = True
logger.info("Prefix cache enabled automatically for MLX chunked prefill")

self.cache_manager = CacheManager(
num_layers=self.num_shard_layers,
num_kv_heads=num_key_value_heads // tp_size,
Expand All @@ -208,6 +219,18 @@ def __init__(
enable_prefix_cache=enable_prefix_cache,
sliding_window=sliding_window,
)

if chunked_prefill_size is not None:
self.chunked_prefill_size = (
(chunked_prefill_size + self.cache_manager.block_size - 1)
// self.cache_manager.block_size
* self.cache_manager.block_size
)
else:
self.chunked_prefill_size = None
self.cache_manager.chunked_prefill_size = self.chunked_prefill_size
self.cache_manager.defer_prefill_allocation = self.chunked_prefill_size is not None

super().__init__(
start_layer=start_layer,
end_layer=end_layer,
Expand Down Expand Up @@ -235,6 +258,7 @@ def __init__(

# Prefix Cache Manager
self.enable_prefix_cache = enable_prefix_cache
self.chunked_req = None
# self.prefix_cache = RadixCache(
# num_kv_heads=num_key_value_heads,
# head_dim=head_dim,
Expand Down Expand Up @@ -273,6 +297,26 @@ def _tensor_parallel_broadcast_pyobj(self, broadcast_obj):
data = pickle.loads(np.array(data_arr).tobytes())
return data

def _complete_local_middle_chunk(self, request_id: str):
"""Finish local bookkeeping after this peer runs a non-final prefill chunk."""
if self.chunked_req is None or self.chunked_req.rid != request_id:
return

self.chunked_req.is_chunked = False
self.cache_manager.release_request(request_id)

if self.is_first_peer:
original_req = self.scheduler.get_running_request(request_id)
if original_req is None:
logger.warning(
f"Completed local chunk for {request_id}, but no running request was found."
)
return
original_req.status = RequestStatus.PREFILLING
self.scheduler.enque_request(original_req)
else:
self.scheduler.evict_request(request_id)

def handle_input_requests(self, requests: List[Request]):
"""Update requests states and status in scheduler and cache manager."""
if self.tp_size > 1:
Expand Down Expand Up @@ -341,11 +385,6 @@ def handle_input_requests(self, requests: List[Request]):
req, IntermediateRequest
), "Non-first peers must receive IntermediateRequests."
if req.is_finished or req.hidden_states is None:
if self.enable_prefix_cache:
keys, values = self.cache_manager.gather_kv_cache(req.request_id)
self.prefix_cache.cache_finished_request(req, keys, values)
self.prefix_cache.evict_request(req.request_id)

self.cache_manager.release_request(req.request_id)
logger.debug(
f"Released resources for finished request {req.request_id}, "
Expand All @@ -358,6 +397,29 @@ def handle_input_requests(self, requests: List[Request]):
# This is an active request, add it to the scheduler queue to be processed.
self.scheduler.enque_request(req)

def prepare_next_batch_requests(
self, requests: List[Request], batch_output: Any, context_lengths: Any
) -> List[Request]:
next_batch = super().prepare_next_batch_requests(requests, batch_output, context_lengths)
if (
self.chunked_req is None
or not self.chunked_req.is_chunked
or self.chunked_req.rid not in [req.request_id for req in requests]
):
return next_batch

chunked_rid = self.chunked_req.rid
filtered_next_batch = []
for req in next_batch:
if req.request_id == chunked_rid:
req.status = RequestStatus.PREFILLING
if self.is_last_peer:
continue
filtered_next_batch.append(req)

self._complete_local_middle_chunk(chunked_rid)
return filtered_next_batch

def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens: bool = True):
"""Process a batch of requests in MLX."""
# Run model and get updated cache
Expand All @@ -380,6 +442,8 @@ def process_batch(self, prepared_inputs: Dict[str, Any], return_decoded_tokens:
prefix_lens=prepared_inputs.get("prefix_lens"), # For RoPE offset in prefix cache
)
mx.eval(hidden_states)
if is_prefill_batch and self.chunked_prefill_size is not None:
self.cache_manager.materialize_linear_caches()

if logger.isEnabledFor(logging.DEBUG):
forward_time = (time.time() - start_time) * 1000
Expand Down Expand Up @@ -507,6 +571,46 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
if batch_size == 0:
return None

if self.chunked_prefill_size is not None:
original_batched_requests = batched_requests
adder = MACPrefillAdder(
self.cache_manager.block_size, self.chunked_prefill_size, self.cache_manager
)
chunked_rid = self.chunked_req.rid if self.chunked_req is not None else None

if self.chunked_req is not None and chunked_rid in [
req.request_id for req in original_batched_requests
]:
self.chunked_req = [
req for req in original_batched_requests if req.request_id == chunked_rid
][0]
self.chunked_req = adder.add_chunked_req(self.chunked_req)

for old_req in original_batched_requests:
if chunked_rid is not None and old_req.request_id == chunked_rid:
continue
res = adder.add_one_req(old_req)
if res != AddReqResult.CONTINUE:
break

if adder.new_chunked_req is not None:
self.chunked_req = adder.new_chunked_req

if self.chunked_req is not None and self.chunked_req.rid in [
req.request_id for req in original_batched_requests
]:
self.chunked_req.is_chunked = True

can_run_by_id = {req.request_id: req for req in adder.can_run_list}
batched_requests = [
can_run_by_id[req.request_id]
for req in original_batched_requests
if req.request_id in can_run_by_id
]
batch_size = len(batched_requests)
if batch_size == 0:
return None

h_or_tokens_list = []
block_tables_list = []
context_lengths_list = []
Expand Down Expand Up @@ -555,14 +659,14 @@ def _prepare_prefill_batch(self, batched_requests: List[Request]) -> Dict[str, A
actual_processed_lengths_list.append(len(req.input_ids))
else:
if matched_tokens > 0 and self.enable_prefix_cache:
# Skip the prefix hidden states that correspond to cached tokens
new_hidden = req.hidden_states[matched_tokens:]
if new_hidden.shape[0] == 0:
keep_len = req.total_length - matched_tokens
if keep_len <= 0:
# All tokens cached - keep the last hidden state
new_hidden = req.hidden_states[-1:]
prefix_lens_list[-1] = matched_tokens - 1
actual_processed_lengths_list.append(1)
else:
new_hidden = req.hidden_states[-keep_len:]
actual_processed_lengths_list.append(new_hidden.shape[0])
h_or_tokens_list.append(new_hidden)
else:
Expand Down
1 change: 1 addition & 0 deletions src/parallax/server/executor/sglang_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
# Weight Refit
enable_weight_refit: Optional[bool] = False,
weight_refit_mode: Optional[str] = "disk",
chunked_prefill_size: Optional[int] = None,
# Pipe communication
conn: Optional[List[Any]] = [],
):
Expand Down
1 change: 1 addition & 0 deletions src/parallax/server/executor/vllm_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
# Weight Refit
enable_weight_refit: Optional[bool] = False,
weight_refit_mode: Optional[str] = "disk",
chunked_prefill_size: Optional[int] = None,
# Routed experts
enable_return_routed_experts: bool = False,
# Pipe communication
Expand Down
26 changes: 22 additions & 4 deletions src/parallax/server/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,16 @@ def __init__(
self.last_updated_time: Optional[float] = None
self.lora_id: Optional[str] = None
self.lora_path = lora_path
self.rid = self.request_id
self.is_chunked = False
# Full prompt tokens as received from the client. Chunked prefill may
# temporarily shorten input_ids to the prefix visible to the current
# chunk, so keep this immutable source for later chunks and decode.
self.origin_input_ids = list(input_ids) if input_ids is not None else None
# Optional logical sequence length override for chunked prefill. When
# set, total_length reports the end offset of the current prompt chunk
# instead of the final prompt+output length.
self._effective_total_length: Optional[int] = None

@property
def is_finished(self) -> bool:
Expand Down Expand Up @@ -201,7 +211,9 @@ def output_length(self) -> int:

@property
def total_length(self) -> int:
"""Total length of the sequence (input + output)."""
"""Logical sequence length for scheduling/cache allocation."""
if self._effective_total_length is not None:
return self._effective_total_length
return self.prompt_len + self.output_length

def get_model_input_for_first_peer(self) -> List[int]:
Expand All @@ -226,6 +238,10 @@ def commit_new_token(self, token_id: int):
)
return

self._effective_total_length = None
if self.origin_input_ids is not None:
self.input_ids = self.origin_input_ids

self.output_ids.append(token_id)

# Finishing condition checks are now handled by the Scheduler.
Expand Down Expand Up @@ -302,7 +318,9 @@ def input_length(self) -> int:

@property
def total_length(self) -> int:
"""Total length of the sequence (input + output)."""
"""Logical sequence length for scheduling/cache allocation."""
if self._effective_total_length is not None:
return self._effective_total_length
return self.current_position

@classmethod
Expand Down Expand Up @@ -337,7 +355,7 @@ def from_initial_request(
return IntermediateRequest(
request_id=initial_request.request_id,
status=initial_request.status,
input_ids=initial_request.input_ids,
input_ids=initial_request.origin_input_ids,
next_token_id=next_token_id,
current_position=initial_request.total_length,
hidden_states=hidden_states,
Expand All @@ -364,7 +382,7 @@ def from_intermediate_request(
request_id=old_request.request_id,
status=old_request.status,
current_position=old_request.total_length,
input_ids=old_request.input_ids,
input_ids=old_request.origin_input_ids,
next_token_id=old_request.next_token_id,
hidden_states=new_hidden_states,
routing_table=old_request.routing_table,
Expand Down
Loading
Loading