Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 78 additions & 5 deletions vllm_ascend/distributed/mooncake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
KVConnectorRole)
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size, get_pp_group,
Expand Down Expand Up @@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
te_rpc_port: int
kv_caches_base_addr: list[int]
num_blocks: int
local_ip: str = ""


@dataclass
Expand All @@ -75,6 +77,7 @@ class ReqMeta:
remote_engine_id: str
remote_pcp_size: int
remote_dcp_size: int
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]


@dataclass
Expand Down Expand Up @@ -685,6 +688,8 @@ def add_new_req(
remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
"remote_multi_nodes_meta_mapping", {}),
)


Expand Down Expand Up @@ -772,6 +777,30 @@ def wait_for_save(self):
"""MooncakeConnector does not save explicitly."""
pass

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """Return the handshake metadata held by the worker-side connector.

    This metadata is used for the out-of-band connector handshake
    between P/D workers.

    Returns:
        The worker's ``xfer_handshake_metadata`` value, or ``None`` if no
        handshake metadata is available.
    """
    worker = self.connector_worker
    # The worker-side connector must exist before its metadata is read.
    assert worker is not None
    return worker.xfer_handshake_metadata

def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Hand the KV connector handshake metadata to the scheduler side.

    Args:
        metadata (dict): the handshake metadata to set; it is forwarded
            unchanged to ``connector_scheduler``.
    """
    scheduler = self.connector_scheduler
    # The scheduler-side connector must exist to receive the metadata.
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)


class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
Expand Down Expand Up @@ -805,6 +834,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}

# Per-rank metadata (host, engine_id) for deployments spanning multiple nodes
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}

def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
Expand Down Expand Up @@ -928,8 +960,23 @@ def request_finished(
remote_pcp_size=self.pcp_size,
remote_dcp_size=self.dcp_size,
last_token_id=request.output_token_ids[-1],
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
)

def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Record each rank's host and engine id in ``multi_nodes_meta_mapping``.

    Args:
        metadata (dict): handshake metadata keyed by rank. Each value must
            expose ``local_ip`` and ``engine_id`` attributes.
    """
    for rank, agent_meta in metadata.items():
        entry = {
            "host": agent_meta.local_ip,
            "engine_id": agent_meta.engine_id,
        }
        # Keys are stored as strings; the lookup side also builds a
        # string rank key.
        self.multi_nodes_meta_mapping[str(rank)] = entry


class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
Expand Down Expand Up @@ -989,6 +1036,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.kv_send_thread: Optional[KVCacheSendingThread] = None
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None

# Handshake metadata of this worker
self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None

# kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
Expand Down Expand Up @@ -1118,7 +1168,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
te_rpc_port=self.te_rpc_port,
kv_caches_base_addr=kv_caches_base_addr,
num_blocks=self.num_blocks,
local_ip=get_ip(),
)
self.xfer_handshake_metadata = metadata

ready_event = threading.Event()
if self.kv_role == 'kv_producer':
Expand Down Expand Up @@ -1266,13 +1318,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
continue
for i in range(self.tp_num_need_pulls):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port,
remote_handshake_port_list[pcp_dcp_rank][i],
meta.remote_host, meta.remote_engine_id,
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=remote_block_ids_list[
pcp_dcp_rank],
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_engine_id=remote_engine_id,
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[
pcp_dcp_rank][i],
offset=i,
Expand All @@ -1287,12 +1344,16 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
for x in choosen_rank_list]
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
meta.remote_port, remote_handshake_port_list[i][0],
meta.remote_host, meta.remote_engine_id,
meta.remote_multi_nodes_meta_mapping)
self.kv_recv_thread.add_request(
request_id=req_id,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host,
remote_engine_id=remote_engine_id,
remote_host=remote_host,
remote_handshake_port=remote_handshake_port_list[i][0],
offset=i,
tp_num_need_pulls=self.tp_num_need_pulls,
Expand All @@ -1307,6 +1368,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
else:
self.kv_send_thread.add_not_transfer_request(req_id)

def _get_remote_host_info_by_port(
        self, base_port: int, remote_handshake_port: int,
        remote_host: str, remote_engine_id: str,
        remote_multi_nodes_meta_mapping: dict) -> tuple[str, str]:
    """Resolve the host and engine id to use for one remote handshake port.

    The remote rank is derived as ``remote_handshake_port - base_port`` and
    looked up (as a string key) in ``remote_multi_nodes_meta_mapping``.

    Args:
        base_port: port that corresponds to rank 0 of the remote side.
        remote_handshake_port: handshake port of the rank being contacted.
        remote_host: request-level host, used as the fallback.
        remote_engine_id: request-level engine id, used as the fallback.
        remote_multi_nodes_meta_mapping: optional mapping of
            ``str(rank) -> {"host": ..., "engine_id": ...}``; may be
            ``None`` or empty.

    Returns:
        A ``(host, engine_id)`` tuple. Any value that is missing, ``None``
        or empty in the per-rank entry falls back to the request-level
        value, so an entry with an empty ``host`` never yields an
        unusable empty address.
    """
    if not remote_multi_nodes_meta_mapping:
        return remote_host, remote_engine_id

    rank = str(remote_handshake_port - base_port)
    info = remote_multi_nodes_meta_mapping.get(rank)
    if not info:
        return remote_host, remote_engine_id

    # `or` (rather than a .get default) also covers keys that are present
    # but hold an empty string or None.
    return (info.get("host") or remote_host,
            info.get("engine_id") or remote_engine_id)

def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
    """Return the remote ranks for ``req_id`` as one flat list.

    Args:
        req_id: request identifier passed to ``_get_remote_ranks_for_req``.

    Returns:
        The rank groups returned by ``_get_remote_ranks_for_req``,
        concatenated in order.
    """
    # A flattening comprehension is linear; `sum(groups, [])` re-copies
    # the accumulated list for every group.
    return [
        rank for rank_group in self._get_remote_ranks_for_req(req_id)
        for rank in rank_group
    ]

Expand Down
16 changes: 14 additions & 2 deletions vllm_ascend/worker/worker_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
Expand Down Expand Up @@ -374,7 +376,17 @@ def get_model(self) -> nn.Module:
return self.model_runner.get_model()

def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
    """Get KV connector metadata from this worker if available.

    Returns:
        ``{self.rank: metadata}`` when a KV transfer group exists and its
        connector reports handshake metadata; otherwise ``None``.
    """
    # The previous unconditional `return None` stub is dropped here; it
    # made everything below unreachable.
    if not has_kv_transfer_group():
        return None

    connector = get_kv_transfer_group()

    # Return None for connectors that don't need to exchange handshake
    # metadata across workers.
    if (metadata := connector.get_handshake_metadata()) is None:
        return None
    return {self.rank: metadata}

def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
return self.model_runner.get_kv_cache_spec()
Expand Down
Loading