Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/ray/llm/_internal/serve/core/configs/llm_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Expand Down Expand Up @@ -50,6 +51,11 @@
from ray.llm._internal.serve.observability.logging import get_logger
from ray.serve._private.config import DeploymentConfig, handle_num_replicas_auto

if TYPE_CHECKING:
from ray.llm._internal.serve.engines.vllm.kv_transfer.base import (
BaseConnectorBackend,
)

transformers = try_import("transformers")

ModelT = TypeVar("ModelT", bound=BaseModel)
Expand Down Expand Up @@ -255,6 +261,7 @@ def validate_server_cls(cls, value):
_model_architecture: str = PrivateAttr("UNSPECIFIED")
_engine_config: EngineConfigType = PrivateAttr(None)
_callback_instance: Optional[CallbackBase] = PrivateAttr(None)
_kv_connector_backend: Optional["BaseConnectorBackend"] = PrivateAttr(None)

def _load_hf_config(self, model_id_or_path: str, trust_remote_code: bool = False):
"""Load the HuggingFace config for a model.
Expand Down Expand Up @@ -626,6 +633,20 @@ def _setup_kv_connector_backend(self):
kv_connector, self
)
kv_connector_backend.setup()
# 3. Stash the instance so the P/D orchestrator can reach the connector's
# coordination protocol (request shaping, peer binding, handoff
# discipline) without re-creating it. May be None on configs that never
# call setup_engine_backend(); the orchestrator falls back to the factory.
self._kv_connector_backend = kv_connector_backend

@property
def kv_connector_backend(self) -> Optional["BaseConnectorBackend"]:
    """Return the KV-connector backend cached on this config, if any.

    The underlying private attribute defaults to None and is populated by
    ``_setup_kv_connector_backend`` once the backend's ``setup()`` has run.
    A None result therefore means that no KV transfer connector is
    configured, or that the backend has not been set up on this config copy.
    """
    backend = self._kv_connector_backend
    return backend


class DiskMultiplexConfig(BaseModelExtended):
Expand Down
154 changes: 153 additions & 1 deletion python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,52 @@
import abc
import random
import string
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from ray import serve

if TYPE_CHECKING:
from ray.llm._internal.serve.core.configs.llm_config import LLMConfig
from ray.llm._internal.serve.core.configs.openai_api_models import (
ChatCompletionRequest,
CompletionRequest,
)

# The two OpenAI request models the P/D orchestrator shapes. Defined under
# TYPE_CHECKING (and used as a string annotation) to avoid an import cycle
# between this module and the config/openai-models modules.
RequestType = Union[ChatCompletionRequest, CompletionRequest]


class BaseConnectorBackend(abc.ABC):
# ---- P/D coordination protocol ----
#
# These class attributes and methods let the P/D orchestrator
# (``PDOrchestratorMixin``) delegate request shaping, peer addressing, and
# handoff discipline to the connector. They are connector-agnostic: a
# connector picks a quadrant of (``requires_peer_binding``,
# ``concurrent_handoff``) and implements ``prepare_prefill_request`` /
# ``prepare_decode_request`` accordingly.
#
# ``requires_peer_binding``:
# * False -> the orchestrator dispatches prefill via the standard handle
# path; the peer (if any) is resolved post-hoc from the prefill response.
# * True -> the orchestrator selects the prefill replica first
# (``choose_replica``) and passes its ``replica_metadata`` to the backend
# as ``peer`` (pre-dispatch addressing).
#
# ``concurrent_handoff``:
# * False -> prefill runs to its first chunk before local decode starts
# (sequential handoff).
# * True -> prefill dispatch and local decode run concurrently.
#
# The two flags are independent. For example: a standard (pull-on-response)
# connector is (False, False); a push-based, request-id-addressed connector
# is (True, True); a pull-based request-id-addressed connector is
# (True, False).
requires_peer_binding: bool = False
concurrent_handoff: bool = False
Comment thread
kouroshHakha marked this conversation as resolved.

def __init__(self, llm_config: "LLMConfig"):
"""Base class for connector backends.

Expand Down Expand Up @@ -67,9 +104,124 @@ def _compute_port_offset(self) -> int:
num_devices = engine_config.num_devices
return rc.rank.rank * num_devices

@abc.abstractmethod
def prepare_prefill_request(
    self, *, request: "RequestType", peer: Optional[Dict[str, Any]]
) -> "RequestType":
    """Build the request that is dispatched to the remote prefill engine.

    Abstract hook: concrete connector backends must provide an
    implementation.

    Args:
        request: The incoming chat/completion request.
        peer: The ``replica_metadata`` dict of the selected prefill replica
            when the connector opted into pre-dispatch peer binding
            (``requires_peer_binding=True``); otherwise None.

    Returns:
        A new request object to dispatch to the prefill engine.
    """

@abc.abstractmethod
def prepare_decode_request(
    self,
    *,
    request: "RequestType",
    peer: Optional[Dict[str, Any]],
    prefill_response: Optional[Any],
) -> "RequestType":
    """Build the request that is run on the local decode engine.

    Abstract hook: concrete connector backends must provide an
    implementation.

    Args:
        request: The incoming chat/completion request.
        peer: The ``replica_metadata`` dict of the selected prefill replica
            when the connector opted into pre-dispatch peer binding;
            otherwise None.
        prefill_response: The captured prefill response chunk whose
            ``kv_transfer_params`` may be forwarded. None when no chunk is
            captured before decode starts (concurrent-handoff mode).

    Returns:
        A new request object to run on the local decode engine.
    """

def setup(self) -> None:
"""Setup the connector backend.

This method is called to setup the connector backend.
"""
pass


class DefaultPDProtocolMixin:
    """The default P/D protocol policy: no peer binding, sequential handoff.

    Implements ``prepare_prefill_request`` / ``prepare_decode_request`` for
    connectors that follow the standard policy: the prefill engine is told to
    produce KV for a remote decode (clamped to a single non-streaming token),
    and the decode engine forwards the ``kv_transfer_params`` that the prefill
    engine returned on its first response chunk.

    Mix this in *before* ``BaseConnectorBackend`` in a backend's bases so its
    concrete methods satisfy the abstract methods.
    """

    def prepare_prefill_request(
        self, *, request: "RequestType", peer: Optional[Dict[str, Any]]
    ) -> "RequestType":
        """Shape the prefill request under the default P/D protocol policy.

        Deep-copies the request, stamps the standard ``kv_transfer_params`` that
        tell the prefill engine to produce KV for a remote decode, and clamps it
        to a single, non-streaming token. The incoming ``request`` is not
        mutated.

        Args:
            request: The incoming chat/completion request. It must not carry
                ``kv_transfer_params`` yet.
            peer: Ignored by the default policy.

        Returns:
            A deep copy of ``request`` shaped for the prefill engine.

        Raises:
            AssertionError: If ``request`` already has non-None
                ``kv_transfer_params``.
        """
        # Raised explicitly instead of via an ``assert`` statement so the guard
        # is still enforced when Python runs with ``-O`` (which strips
        # asserts). The exception type and message are kept unchanged.
        if getattr(request, "kv_transfer_params", None) is not None:
            raise AssertionError(
                "kv_transfer_params should be empty before orchestrator"
            )
        prefill_request = request.model_copy(deep=True)
        prefill_request.kv_transfer_params = {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": None,
            "remote_port": None,
        }
        # Prefill only needs to produce the KV cache: one token, no streaming.
        prefill_request.max_tokens = 1
        if hasattr(prefill_request, "max_completion_tokens"):
            prefill_request.max_completion_tokens = 1
        prefill_request.stream = False
        if hasattr(prefill_request, "stream_options"):
            prefill_request.stream_options = None
        return prefill_request

    def prepare_decode_request(
        self,
        *,
        request: "RequestType",
        peer: Optional[Dict[str, Any]],
        prefill_response: Optional[Any],
    ) -> "RequestType":
        """Shape the decode request under the default P/D protocol policy.

        Deep-copies the request and, only when a prefill response chunk was
        captured, forwards its ``kv_transfer_params`` so the decode engine
        pulls/receives the KV produced by prefill.

        Args:
            request: The incoming chat/completion request.
            peer: Ignored by the default policy.
            prefill_response: The captured prefill response chunk, or None in
                concurrent-handoff mode, in which case the copy is returned
                unmodified.

        Returns:
            A deep copy of ``request`` to run on the local decode engine.
        """
        decode_request = request.model_copy(deep=True)
        if prefill_response is not None:
            decode_request.kv_transfer_params = prefill_response.kv_transfer_params
        return decode_request
Comment thread
kouroshHakha marked this conversation as resolved.


class DefaultConnectorBackend(DefaultPDProtocolMixin, BaseConnectorBackend):
    """Concrete connector backend that applies the default P/D protocol policy.

    Serves as the factory fallback for connectors without a dedicated,
    registered backend class: ``setup()`` stays a no-op and request shaping
    comes from ``DefaultPDProtocolMixin``. Because ``BaseConnectorBackend`` is
    abstract, the factory needs a concrete class such as this one to return.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from ray.llm._internal.serve.engines.vllm.kv_transfer.base import (
BaseConnectorBackend,
DefaultConnectorBackend,
)
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.utils.registry import get_registry
Expand Down Expand Up @@ -56,9 +57,11 @@ def get_backend_class(cls, name: str) -> Type["BaseConnectorBackend"]:
"""Get the connector backend class by name.

For registered connectors, returns the registered backend class.
For unregistered connectors, returns BaseConnectorBackend which has
a no-op setup() method, allowing connectors that don't require
Ray Serve orchestration to work without registration.
For unregistered connectors, returns DefaultConnectorBackend (a concrete
backend with a no-op setup() and the default P/D protocol policy),
allowing connectors that don't require Ray Serve orchestration to work
without registration. (BaseConnectorBackend itself is abstract and
cannot be instantiated.)

Args:
name: The name of the connector backend
Expand All @@ -74,9 +77,9 @@ def get_backend_class(cls, name: str) -> Type["BaseConnectorBackend"]:
except ValueError:
logger.warning(
f"Unsupported connector backend: {name}. "
f"Using default: {BaseConnectorBackend.__name__}."
f"Using default: {DefaultConnectorBackend.__name__}."
)
return BaseConnectorBackend
return DefaultConnectorBackend
except Exception as e:
raise ImportError(
f"Failed to load connector backend '{name}': {type(e).__name__}: {e}"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ray.llm._internal.serve.engines.vllm.kv_transfer.base import (
BaseConnectorBackend,
DefaultPDProtocolMixin,
)
from ray.llm._internal.serve.observability.logging import get_logger

Expand All @@ -15,7 +16,7 @@ def _check_lmcache_installed():
)


class LMCacheConnectorV1Backend(BaseConnectorBackend):
class LMCacheConnectorV1Backend(DefaultPDProtocolMixin, BaseConnectorBackend):

KV_CONNECTOR_EXTRA_CONFIG_FIELD_NAME = "kv_connector_extra_config"
LMCACHE_RPC_PORT_FIELD_NAME = "lmcache_rpc_port"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List

from ray.llm._internal.serve.engines.vllm.kv_transfer.base import (
BaseConnectorBackend,
Expand All @@ -13,8 +13,18 @@


class MultiConnectorBackend(BaseConnectorBackend):
"""Wraps multiple sub-connectors.

The P/D protocol (``prepare_prefill_request`` / ``prepare_decode_request`` and
the ``requires_peer_binding`` / ``concurrent_handoff`` policy) is delegated to
the *first* (top-most) sub-connector listed in ``connectors`` — that
connector's policy governs request shaping and handoff for the group. Each
sub-connector's ``setup()`` still runs.
"""

def __init__(self, llm_config: "LLMConfig"):
super().__init__(llm_config)
self._connector_backends: List[BaseConnectorBackend] = []

def setup(self) -> None:
"""Setup all connectors listed in the kv_transfer_config."""
Expand Down Expand Up @@ -49,3 +59,29 @@ def setup(self) -> None:
connector_backend_str, sub_llm_config
)
connector_backend.setup()
self._connector_backends.append(connector_backend)

@property
def _primary(self) -> BaseConnectorBackend:
    """Return the first sub-connector, whose protocol governs the group.

    Raises:
        ValueError: If the list of sub-connector backends is empty.
    """
    backends = self._connector_backends
    if backends:
        return backends[0]
    raise ValueError(
        "MultiConnectorBackend has no sub-connectors; was setup() called?"
    )

@property
def requires_peer_binding(self) -> bool:
    """Report the primary sub-connector's peer-binding policy.

    Returns False while no sub-connector backends are registered.
    """
    if not self._connector_backends:
        return False
    return self._primary.requires_peer_binding

@property
def concurrent_handoff(self) -> bool:
    """Report the primary sub-connector's handoff policy.

    Returns False while no sub-connector backends are registered.
    """
    if not self._connector_backends:
        return False
    return self._primary.concurrent_handoff

def prepare_prefill_request(self, *, request, peer):
    """Delegate prefill request shaping to the primary sub-connector.

    Accessing ``_primary`` raises ValueError when no sub-connector backends
    are registered, so this call fails in that case.
    """
    primary = self._primary
    return primary.prepare_prefill_request(request=request, peer=peer)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Empty MultiConnector crashes prepare calls

Medium Severity

If setup() runs with an empty connectors list, requires_peer_binding and concurrent_handoff read as false, yet prepare_prefill_request / prepare_decode_request still call _primary and raise ValueError. Previously the orchestrator applied default P/D shaping regardless of Multi configuration.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit e3ffc5a. Configure here.


def prepare_decode_request(self, *, request, peer, prefill_response):
    """Delegate decode request shaping to the primary sub-connector.

    Accessing ``_primary`` raises ValueError when no sub-connector backends
    are registered, so this call fails in that case.
    """
    primary = self._primary
    return primary.prepare_decode_request(
        request=request,
        peer=peer,
        prefill_response=prefill_response,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import ray
from ray.llm._internal.serve.engines.vllm.kv_transfer.base import (
BaseConnectorBackend,
DefaultPDProtocolMixin,
)


class NixlConnectorBackend(BaseConnectorBackend):
class NixlConnectorBackend(DefaultPDProtocolMixin, BaseConnectorBackend):
def _set_side_channel_port(self):
from vllm import envs as vllm_envs

Expand Down
Loading
Loading