2424from vllm import envs
2525from vllm .config import VllmConfig
2626from vllm .distributed .kv_transfer .kv_connector .v1 .base import (
27- KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole )
27+ KVConnectorBase_V1 , KVConnectorHandshakeMetadata , KVConnectorMetadata ,
28+ KVConnectorRole )
2829from vllm .distributed .parallel_state import (
2930 get_decode_context_model_parallel_rank ,
3031 get_decode_context_model_parallel_world_size , get_pp_group ,
@@ -64,6 +65,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
6465 te_rpc_port : int
6566 kv_caches_base_addr : list [int ]
6667 num_blocks : int
68+ local_ip : str = ""
6769
6870
6971@dataclass
@@ -75,6 +77,7 @@ class ReqMeta:
7577 remote_engine_id : str
7678 remote_pcp_size : int
7779 remote_dcp_size : int
80+ remote_multi_nodes_meta_mapping : dict [int , dict [str , Any ]]
7881
7982
8083@dataclass
@@ -685,6 +688,8 @@ def add_new_req(
685688 remote_port = kv_transfer_params ["remote_port" ],
686689 remote_pcp_size = kv_transfer_params .get ("remote_pcp_size" , 1 ),
687690 remote_dcp_size = kv_transfer_params .get ("remote_dcp_size" , 1 ),
691+ remote_multi_nodes_meta_mapping = kv_transfer_params .get (
692+ "remote_multi_nodes_meta_mapping" , {}),
688693 )
689694
690695
@@ -772,6 +777,30 @@ def wait_for_save(self):
772777 """MooncakeConnector does not save explicitly."""
773778 pass
774779
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
    """Return the handshake metadata held by the worker-side connector.

    The metadata is used for the out-of-band connector handshake between
    P/D workers.

    Returns:
        The worker's ``xfer_handshake_metadata`` attribute, which may be
        ``None`` when no handshake metadata is available.

    Raises:
        AssertionError: if this connector has no worker-side component.
    """
    worker = self.connector_worker
    assert worker is not None
    return worker.xfer_handshake_metadata
792+
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Forward handshake metadata to the scheduler-side connector.

    Args:
        metadata (dict): the handshake metadata to set; it is passed
            unchanged to the scheduler's ``set_xfer_handshake_metadata``.

    Raises:
        AssertionError: if this connector has no scheduler-side component.
    """
    scheduler = self.connector_scheduler
    assert scheduler is not None
    scheduler.set_xfer_handshake_metadata(metadata)
803+
775804
776805class MooncakeConnectorScheduler :
777806 """Implementation of Scheduler side methods"""
@@ -805,6 +834,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
805834 self ._reqs_need_recv : dict [str , tuple [Request , list [int ]]] = {}
806835 self ._reqs_need_send : dict [str , float ] = {}
807836
837+ # master-slave meta information for cross-nodes
838+ self .multi_nodes_meta_mapping : dict [int , dict [str , Any ]] = {}
839+
808840 def get_num_new_matched_tokens (
809841 self , request : "Request" ,
810842 num_computed_tokens : int ) -> tuple [int , bool ]:
@@ -928,8 +960,23 @@ def request_finished(
928960 remote_pcp_size = self .pcp_size ,
929961 remote_dcp_size = self .dcp_size ,
930962 last_token_id = request .output_token_ids [- 1 ],
963+ remote_multi_nodes_meta_mapping = self .multi_nodes_meta_mapping ,
931964 )
932965
def set_xfer_handshake_metadata(
        self, metadata: dict[int, KVConnectorHandshakeMetadata]) -> None:
    """Record the host and engine id advertised by each rank.

    Every entry of ``metadata`` is stored in
    ``self.multi_nodes_meta_mapping`` under the *stringified* rank, as a
    dict with the keys ``"host"`` and ``"engine_id"``. Existing entries
    for the same rank are overwritten.

    Args:
        metadata (dict): rank -> handshake metadata; each value must
            expose ``local_ip`` and ``engine_id`` attributes.
    """
    for rank, handshake in metadata.items():
        entry = {
            "host": handshake.local_ip,
            "engine_id": handshake.engine_id,
        }
        # Keys are strings, not ints.
        self.multi_nodes_meta_mapping[str(rank)] = entry
979+
933980
934981class MooncakeConnectorWorker :
935982 """Implementation of Worker side methods"""
@@ -989,6 +1036,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
9891036 self .kv_send_thread : Optional [KVCacheSendingThread ] = None
9901037 self .kv_recv_thread : Optional [KVCacheRecvingThread ] = None
9911038
1039+ # Handshake metadata of this worker
1040+ self .xfer_handshake_metadata : MooncakeAgentMetadata | None = None
1041+
9921042 # kv_transfer variables
9931043 self .vllm_config = vllm_config
9941044 self .block_size = vllm_config .cache_config .block_size
@@ -1118,7 +1168,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
11181168 te_rpc_port = self .te_rpc_port ,
11191169 kv_caches_base_addr = kv_caches_base_addr ,
11201170 num_blocks = self .num_blocks ,
1171+ local_ip = get_ip (),
11211172 )
1173+ self .xfer_handshake_metadata = metadata
11221174
11231175 ready_event = threading .Event ()
11241176 if self .kv_role == 'kv_producer' :
@@ -1266,13 +1318,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12661318 continue
12671319 for i in range (self .tp_num_need_pulls ):
12681320 assert self .kv_recv_thread is not None
1321+ remote_host , remote_engine_id = self ._get_remote_host_info_by_port (
1322+ meta .remote_port ,
1323+ remote_handshake_port_list [pcp_dcp_rank ][i ],
1324+ meta .remote_host , meta .remote_engine_id ,
1325+ meta .remote_multi_nodes_meta_mapping )
12691326 self .kv_recv_thread .add_request (
12701327 request_id = req_id ,
12711328 local_block_ids = local_block_ids_list [pcp_dcp_rank ],
12721329 remote_block_ids = remote_block_ids_list [
12731330 pcp_dcp_rank ],
1274- remote_engine_id = meta . remote_engine_id ,
1275- remote_host = meta . remote_host ,
1331+ remote_engine_id = remote_engine_id ,
1332+ remote_host = remote_host ,
12761333 remote_handshake_port = remote_handshake_port_list [
12771334 pcp_dcp_rank ][i ],
12781335 offset = i ,
@@ -1287,12 +1344,16 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12871344 for x in choosen_rank_list ]
12881345 for i in range (self .tp_num_need_pulls * self ._prefill_pp_size ):
12891346 assert self .kv_recv_thread is not None
1347+ remote_host , remote_engine_id = self ._get_remote_host_info_by_port (
1348+ meta .remote_port , remote_handshake_port_list [i ][0 ],
1349+ meta .remote_host , meta .remote_engine_id ,
1350+ meta .remote_multi_nodes_meta_mapping )
12901351 self .kv_recv_thread .add_request (
12911352 request_id = req_id ,
12921353 local_block_ids = meta .local_block_ids ,
12931354 remote_block_ids = meta .remote_block_ids ,
1294- remote_engine_id = meta . remote_engine_id ,
1295- remote_host = meta . remote_host ,
1355+ remote_engine_id = remote_engine_id ,
1356+ remote_host = remote_host ,
12961357 remote_handshake_port = remote_handshake_port_list [i ][0 ],
12971358 offset = i ,
12981359 tp_num_need_pulls = self .tp_num_need_pulls ,
@@ -1307,6 +1368,18 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
13071368 else :
13081369 self .kv_send_thread .add_not_transfer_request (req_id )
13091370
1371+ def _get_remote_host_info_by_port (self , base_port : int ,
1372+ remote_handshake_port : int ,
1373+ remote_host : str , remote_engine_id : str ,
1374+ remote_multi_nodes_meta_mapping : dict ):
1375+ rank = str (remote_handshake_port - base_port )
1376+ if remote_multi_nodes_meta_mapping is None or remote_multi_nodes_meta_mapping .get (
1377+ rank , None ) is None :
1378+ return remote_host , remote_engine_id
1379+ info = remote_multi_nodes_meta_mapping [rank ]
1380+ return info .get ("host" , remote_host ), info .get ("engine_id" ,
1381+ remote_engine_id )
1382+
13101383 def _prefill_get_remote_rank (self , req_id : str ) -> List [int ]:
13111384 return sum (self ._get_remote_ranks_for_req (req_id ), [])
13121385
0 commit comments