Skip to content

Commit 407e19d

Browse files
committed
Address PR review: signature, scope, types
- Restore BackendProtocol.endpoints_to_processes signature by adding port_allocator parameter and threading it through to the topology helper so per-job port jitter applies uniformly.
- Drop unrelated VLLM_NIXL_SIDE_CHANNEL_HOST / get_hostname_ip import that was not mooncake-related.
- Widen store_config to dict[str, Any] (matches MooncakeStoreConfig's mix of str/int/size-string fields).
1 parent f7a1016 commit 407e19d

1 file changed

Lines changed: 13 additions & 10 deletions

File tree

src/srtctl/backends/vllm.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
if TYPE_CHECKING:
3333
from srtctl.backends.base import SrunConfig
3434
from srtctl.core.runtime import RuntimeContext
35-
from srtctl.core.topology import Endpoint, Process
35+
from srtctl.core.topology import Endpoint, NodePortAllocator, Process
3636

3737
# Type alias for worker modes
3838
WorkerMode = Literal["prefill", "decode", "agg"]
@@ -92,7 +92,12 @@ class VLLMMooncakeKVStoreConfig:
9292

9393
container: str | None = None
9494
env: dict[str, str] = field(default_factory=dict)
95-
store_config: dict[str, str] | None = None
95+
# ``store_config`` values are JSON-serialized into MOONCAKE_CONFIG_PATH and
96+
# parsed by vLLM's ``MooncakeStoreConfig`` dataclass — fields are a mix of
97+
# str (e.g. ``protocol``), int (e.g. ``port``), and human-readable sizes
98+
# (e.g. ``"4GB"``). Type as ``dict[str, Any]`` to avoid forcing users to
99+
# quote numeric values.
100+
store_config: dict[str, Any] | None = None
96101

97102
Schema: ClassVar[builtins.type[Schema]] = Schema
98103

@@ -201,16 +206,12 @@ def get_process_environment(self, process: Process) -> dict[str, str]:
201206
vLLM with dynamo requires unique ports for each worker:
202207
- DYN_VLLM_KV_EVENT_PORT: ZMQ port for KV events publishing
203208
- VLLM_NIXL_SIDE_CHANNEL_PORT: Port for NIXL side channel transfers
204-
- VLLM_NIXL_SIDE_CHANNEL_HOST: Routable IP for NIXL side channel (not 0.0.0.0/localhost)
205209
"""
206-
from srtctl.core.slurm import get_hostname_ip
207-
208210
env: dict[str, str] = {}
209211
if process.kv_events_port is not None:
210212
env["DYN_VLLM_KV_EVENT_PORT"] = str(process.kv_events_port)
211213
if process.nixl_port is not None:
212214
env["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(process.nixl_port)
213-
env["VLLM_NIXL_SIDE_CHANNEL_HOST"] = get_hostname_ip(process.node)
214215
return env
215216

216217
def get_mooncake_worker_env(self, infra_node_ip: str, local_hostname: str) -> dict[str, str]:
@@ -238,7 +239,7 @@ def get_mooncake_worker_env(self, infra_node_ip: str, local_hostname: str) -> di
238239
"MOONCAKE_CONFIG_PATH": MOONCAKE_STORE_CONFIG_CONTAINER_PATH,
239240
}
240241

241-
def build_mooncake_store_config(self, infra_node_ip: str) -> dict[str, str]:
242+
def build_mooncake_store_config(self, infra_node_ip: str) -> dict[str, Any]:
242243
"""Build the JSON payload for vLLM's ``MooncakeStoreConfig.load_from_env()``.
243244
244245
Keys map 1:1 to vLLM's ``MooncakeStoreConfig`` dataclass. Values come
@@ -247,7 +248,7 @@ def build_mooncake_store_config(self, infra_node_ip: str) -> dict[str, str]:
247248
infra node IP (any user-provided value is overridden — the user can't
248249
know the infra IP at config time).
249250
"""
250-
user_cfg: dict[str, str] = {}
251+
user_cfg: dict[str, Any] = {}
251252
if self.mooncake_kv_store is not None and self.mooncake_kv_store.store_config:
252253
user_cfg = dict(self.mooncake_kv_store.store_config)
253254

@@ -313,6 +314,7 @@ def endpoints_to_processes(
313314
self,
314315
endpoints: list[Endpoint],
315316
base_sys_port: int = 8081,
317+
port_allocator: NodePortAllocator | None = None,
316318
) -> list[Process]:
317319
"""Convert endpoints to processes.
318320
@@ -326,12 +328,13 @@ def endpoints_to_processes(
326328

327329
if not has_dp_mode:
328330
# Standard TP mode: one process per node
329-
return endpoints_to_processes(endpoints, base_sys_port=base_sys_port)
331+
return endpoints_to_processes(endpoints, base_sys_port=base_sys_port, port_allocator=port_allocator)
330332

331333
# DP+EP mode: one process per GPU
332334
processes: list[Process] = []
333335
current_sys_port = base_sys_port
334-
port_allocator = NodePortAllocator()
336+
if port_allocator is None:
337+
port_allocator = NodePortAllocator()
335338

336339
for endpoint in endpoints:
337340
if not self._is_dp_mode(endpoint.mode):

0 commit comments

Comments
 (0)