Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 297 additions & 0 deletions examples/experimental/pd_disaggregation/toy_proxy_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import itertools
import logging
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []

# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f"http://{host}:{port}/v1"
app.state.prefill_clients.append(
{
"client": httpx.AsyncClient(
timeout=None,
base_url=prefiller_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
}
)

# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f"http://{host}:{port}/v1"
app.state.decode_clients.append(
{
"client": httpx.AsyncClient(
timeout=None,
base_url=decoder_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
}
)

# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

print(
f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients."
)

yield

# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info["client"].aclose()

for client_info in app.state.decode_clients:
await client_info["client"].aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


def parse_args():
parser = argparse.ArgumentParser()

parser.add_argument("--port", type=int, default=8000)
# Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
parser.add_argument("--host", type=str, default="127.0.0.1")

# For prefiller instances
parser.add_argument(
"--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"],
)
parser.add_argument(
"--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100]
)

# For decoder instances
parser.add_argument(
"--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"]
)
parser.add_argument(
"--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200]
)

args = parser.parse_args()

# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports"
)

if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError("Number of decoder hosts must match number of decoder ports")

# Create tuples of (host, port) for each service type
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

return args


def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.

Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'

Returns:
The next client to use
"""
if service_type == "prefill":
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == "decode":
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")


async def send_request_to_service(
client_info: dict, endpoint: str, req_data: dict, request_id: str
):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}

response = await client_info["client"].post(
endpoint, json=req_data, headers=headers
)
response.raise_for_status()

# read/consume the response body to release the connection
# otherwise, it would http.ReadError
await response.aread()

return response


async def stream_service_response(
client_info: dict, endpoint: str, req_data: dict, request_id: str
):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}

async with client_info["client"].stream(
"POST", endpoint, json=req_data, headers=headers
) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk


async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())

# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, "prefill")

# Send request to prefill service
response = await send_request_to_service(
prefill_client_info, api, req_data, request_id
)

# Extract the needed fields
response_json = response.json()
await response.aclose() # CRITICAL: Release connection back to pool
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params

# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, "decode")

logger.debug("Using %s %s", prefill_client_info, decode_client_info)

# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(
decode_client_info, api, req_data, request_id=request_id
):
yield chunk

return StreamingResponse(generate_stream(), media_type="application/json")

except Exception as e:
import sys
import traceback

exc_info = sys.exc_info()
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise


@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients),
}


if __name__ == "__main__":
global global_args
global_args = parse_args()

import uvicorn

uvicorn.run(app, host=global_args.host, port=global_args.port)
Loading
Loading