Skip to content

Commit 5e5c90d

Browse files
committed
[feat] Support cross-machine PD (prefill/decode) disaggregation
Signed-off-by: weiguihua2 <[email protected]>
1 parent 2497bbb commit 5e5c90d

File tree

2 files changed

+92
-7
lines changed

2 files changed

+92
-7
lines changed

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
from vllm import envs
2525
from vllm.config import VllmConfig
2626
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
27-
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
27+
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
28+
KVConnectorRole)
2829
from vllm.distributed.parallel_state import (
2930
get_decode_context_model_parallel_rank,
3031
get_decode_context_model_parallel_world_size, get_pp_group,
@@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
6465
te_rpc_port: int
6566
kv_caches_base_addr: list[int]
6667
num_blocks: int
68+
local_ip: str = ""
6769

6870

6971
@dataclass
@@ -75,6 +77,7 @@ class ReqMeta:
7577
remote_engine_id: str
7678
remote_pcp_size: int
7779
remote_dcp_size: int
80+
remote_multi_nodes_meta_mapping: dict[int, dict[str, Any]]
7881

7982

8083
@dataclass
@@ -685,6 +688,8 @@ def add_new_req(
685688
remote_port=kv_transfer_params["remote_port"],
686689
remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1),
687690
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
691+
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
692+
"remote_multi_nodes_meta_mapping", {}),
688693
)
689694

690695

@@ -772,6 +777,30 @@ def wait_for_save(self):
772777
"""MooncakeConnector does not save explicitly."""
773778
pass
774779

780+
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """
    Return the handshake metadata held by this connector's worker.

    The metadata is used for the out-of-band connector handshake
    between P/D workers.

    Returns:
        KVConnectorHandshakeMetadata: the worker's
            ``xfer_handshake_metadata``.
        None if the worker has no handshake metadata.

    Raises:
        AssertionError: if ``self.connector_worker`` is None.
    """
    worker = self.connector_worker
    assert worker is not None
    # NOTE(review): the worker attribute is annotated as
    # ``MooncakeAgentMetadata | None`` and MooncakeAgentMetadata is declared
    # as a msgspec.Struct -- confirm it is acceptable wherever a
    # KVConnectorHandshakeMetadata is expected.
    return worker.xfer_handshake_metadata
792+
793+
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """
    Forward the KV connector handshake metadata to the scheduler side.

    Args:
        metadata (dict): the handshake metadata to set; it is passed
            unchanged to ``connector_scheduler.set_xfer_handshake_metadata``.

    Raises:
        AssertionError: if ``self.connector_scheduler`` is None.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)
803+
775804

776805
class MooncakeConnectorScheduler:
777806
"""Implementation of Scheduler side methods"""
@@ -805,6 +834,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
805834
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
806835
self._reqs_need_send: dict[str, float] = {}
807836

837+
# master-slave meta information for cross-nodes
838+
self.multi_nodes_meta_mapping: dict[int, dict[str, Any]] = {}
839+
808840
def get_num_new_matched_tokens(
809841
self, request: "Request",
810842
num_computed_tokens: int) -> tuple[int, bool]:
@@ -928,8 +960,23 @@ def request_finished(
928960
remote_pcp_size=self.pcp_size,
929961
remote_dcp_size=self.dcp_size,
930962
last_token_id=request.output_token_ids[-1],
963+
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
931964
)
932965

966+
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """
    Record per-rank host / engine id from the given handshake metadata.

    Each entry of ``metadata`` is stored in
    ``self.multi_nodes_meta_mapping`` under the *stringified* rank as
    ``{"host": <local_ip>, "engine_id": <engine_id>}``. Existing entries
    for other ranks are left untouched; an existing entry for the same
    rank is overwritten.

    Args:
        metadata (dict): the handshake metadata to set, keyed by rank.
            Every value must expose ``local_ip`` and ``engine_id``.
    """
    for rank, rank_meta in metadata.items():
        entry = {
            "host": rank_meta.local_ip,
            "engine_id": rank_meta.engine_id,
        }
        # Keys are strings because the consumer
        # (_get_remote_host_info_by_port) looks ranks up by str(...).
        # NOTE(review): the attribute itself is annotated with int keys.
        self.multi_nodes_meta_mapping[str(rank)] = entry
979+
933980

934981
class MooncakeConnectorWorker:
935982
"""Implementation of Worker side methods"""
@@ -989,6 +1036,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
9891036
self.kv_send_thread: Optional[KVCacheSendingThread] = None
9901037
self.kv_recv_thread: Optional[KVCacheRecvingThread] = None
9911038

1039+
# Handshake metadata of this worker
1040+
self.xfer_handshake_metadata: MooncakeAgentMetadata | None = None
1041+
9921042
# kv_transfer variables
9931043
self.vllm_config = vllm_config
9941044
self.block_size = vllm_config.cache_config.block_size
@@ -1118,7 +1168,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
11181168
te_rpc_port=self.te_rpc_port,
11191169
kv_caches_base_addr=kv_caches_base_addr,
11201170
num_blocks=self.num_blocks,
1171+
local_ip=get_ip(),
11211172
)
1173+
self.xfer_handshake_metadata = metadata
11221174

11231175
ready_event = threading.Event()
11241176
if self.kv_role == 'kv_producer':
@@ -1266,13 +1318,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12661318
continue
12671319
for i in range(self.tp_num_need_pulls):
12681320
assert self.kv_recv_thread is not None
1321+
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
1322+
meta.remote_port,
1323+
remote_handshake_port_list[pcp_dcp_rank][i],
1324+
meta.remote_host, meta.remote_engine_id,
1325+
meta.remote_multi_nodes_meta_mapping)
12691326
self.kv_recv_thread.add_request(
12701327
request_id=req_id,
12711328
local_block_ids=local_block_ids_list[pcp_dcp_rank],
12721329
remote_block_ids=remote_block_ids_list[
12731330
pcp_dcp_rank],
1274-
remote_engine_id=meta.remote_engine_id,
1275-
remote_host=meta.remote_host,
1331+
remote_engine_id=remote_engine_id,
1332+
remote_host=remote_host,
12761333
remote_handshake_port=remote_handshake_port_list[
12771334
pcp_dcp_rank][i],
12781335
offset=i,
@@ -1287,12 +1344,16 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12871344
for x in choosen_rank_list]
12881345
for i in range(self.tp_num_need_pulls * self._prefill_pp_size):
12891346
assert self.kv_recv_thread is not None
1347+
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
1348+
meta.remote_port, remote_handshake_port_list[i][0],
1349+
meta.remote_host, meta.remote_engine_id,
1350+
meta.remote_multi_nodes_meta_mapping)
12901351
self.kv_recv_thread.add_request(
12911352
request_id=req_id,
12921353
local_block_ids=meta.local_block_ids,
12931354
remote_block_ids=meta.remote_block_ids,
1294-
remote_engine_id=meta.remote_engine_id,
1295-
remote_host=meta.remote_host,
1355+
remote_engine_id=remote_engine_id,
1356+
remote_host=remote_host,
12961357
remote_handshake_port=remote_handshake_port_list[i][0],
12971358
offset=i,
12981359
tp_num_need_pulls=self.tp_num_need_pulls,
@@ -1307,6 +1368,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
13071368
else:
13081369
self.kv_send_thread.add_not_transfer_request(req_id)
13091370

1371+
def _get_remote_host_info_by_port(self, base_port: int,
1372+
remote_handshake_port: int,
1373+
remote_host: str, remote_engine_id: str,
1374+
remote_multi_nodes_meta_mapping: dict):
1375+
rank = str(remote_handshake_port - base_port)
1376+
if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping.get(
1377+
rank, None) is None:
1378+
return remote_host, remote_engine_id
1379+
info = remote_multi_nodes_meta_mapping[rank]
1380+
return info.get("host", remote_host), info.get("engine_id",
1381+
remote_engine_id)
1382+
13101383
def _prefill_get_remote_rank(self, req_id: str) -> List[int]:
13111384
return sum(self._get_remote_ranks_for_req(req_id), [])
13121385

vllm_ascend/worker/worker_v1.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
from vllm.distributed import (ensure_model_parallel_initialized,
3232
init_distributed_environment)
3333
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
34-
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
34+
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
35+
get_kv_transfer_group,
36+
has_kv_transfer_group)
3537
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
3638
from vllm.logger import logger
3739
from vllm.lora.request import LoRARequest
@@ -374,7 +376,17 @@ def get_model(self) -> nn.Module:
374376
return self.model_runner.get_model()
375377

376378
def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
    """Return this worker's KV connector handshake metadata, if any.

    Returns:
        ``{self.rank: metadata}`` when a KV transfer group exists and its
        connector reports handshake metadata; ``None`` otherwise.
    """
    if has_kv_transfer_group():
        handshake = get_kv_transfer_group().get_handshake_metadata()
        # Connectors that don't need to exchange handshake metadata
        # across workers report None; that is passed through as None.
        if handshake is not None:
            return {self.rank: handshake}
    return None
378390

379391
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
380392
return self.model_runner.get_kv_cache_spec()

0 commit comments

Comments
 (0)