Skip to content

Commit 0e1105f

Browse files
committed
use open port at runtime
Signed-off-by: wuhang <wuhang6@huawei.com>
1 parent 74efb05 commit 0e1105f

File tree

4 files changed, with 127 additions and 52 deletions.

vllm_omni/entrypoints/cli/serve.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
load_stage_configs_from_yaml,
3232
resolve_model_config_path,
3333
)
34-
from vllm_omni.entrypoints.zmq_utils import ZmqQueueSpec
3534

3635
logger = init_logger(__name__)
3736

@@ -356,13 +355,40 @@ def run_headless(args: argparse.Namespace) -> None:
356355

357356
omni_master_address = getattr(args, "omni_master_address", None) or "127.0.0.1"
358357
omni_master_port = int(getattr(args, "omni_master_port", 5555) or 5555)
359-
base_port = omni_master_port + 1
360-
in_endpoint = f"tcp://{omni_master_address}:{base_port + single_stage_id * 2}"
361-
out_endpoint = f"tcp://{omni_master_address}:{base_port + single_stage_id * 2 + 1}"
362358

363-
in_q_spec = ZmqQueueSpec(endpoint=in_endpoint, socket_type=zmq.PULL, bind=False)
364-
out_q_spec = ZmqQueueSpec(endpoint=out_endpoint, socket_type=zmq.PUSH, bind=False)
359+
# Perform handshake with orchestrator to get dynamically allocated endpoints
365360
zmq_ctx = zmq.Context()
361+
handshake_socket = zmq_ctx.socket(zmq.REQ)
362+
handshake_socket.linger = 0
363+
handshake_endpoint = f"tcp://{omni_master_address}:{omni_master_port}"
364+
365+
try:
366+
handshake_socket.connect(handshake_endpoint)
367+
handshake_msg = {"type": "handshake", "stage_id": single_stage_id}
368+
handshake_socket.send_pyobj(handshake_msg)
369+
370+
# Wait for response with timeout
371+
if handshake_socket.poll(timeout=10000): # 10 second timeout
372+
response = handshake_socket.recv_pyobj()
373+
if not response.get("ok", False):
374+
error_msg = response.get("error", "unknown error")
375+
raise RuntimeError(f"Handshake failed for stage-{single_stage_id}: {error_msg}")
376+
377+
in_q_spec = response.get("in_spec")
378+
out_q_spec = response.get("out_spec")
379+
380+
if in_q_spec is None or out_q_spec is None:
381+
raise RuntimeError(f"Handshake response missing specs for stage-{single_stage_id}")
382+
383+
logger.info(
384+
f"[Headless] Stage-{single_stage_id} received endpoints via handshake: "
385+
f"in={in_q_spec.endpoint}, out={out_q_spec.endpoint}"
386+
)
387+
else:
388+
raise TimeoutError(f"Handshake timeout for stage-{single_stage_id} at {handshake_endpoint}")
389+
390+
finally:
391+
handshake_socket.close(0)
366392
in_q = None
367393
out_q = None
368394

vllm_omni/entrypoints/omni.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tqdm.auto import tqdm
1919
from vllm import SamplingParams
2020
from vllm.logger import init_logger
21+
from vllm.utils.network_utils import get_open_port
2122

2223
from vllm_omni.distributed.omni_connectors import (
2324
get_stage_connector_config,
@@ -163,7 +164,7 @@ def __init__(self, model: str, **kwargs: Any) -> None:
163164
self._zmq_handshake_socket: zmq.Socket | None = None
164165
self._zmq_handshake_thread: threading.Thread | None = None
165166
self._zmq_handshake_stop: threading.Event | None = None
166-
self._zmq_handshake_specs: dict[int, ZmqQueueSpec] = {}
167+
self._zmq_handshake_specs: dict[int, tuple[ZmqQueueSpec, ZmqQueueSpec]] = {}
167168
self._zmq_handshake_seen: set[int] = set()
168169
self._total_stage_count: int = 0
169170
self._single_stage_id: int | None = None
@@ -264,6 +265,9 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None:
264265

265266
base_engine_args = {"tokenizer": tokenizer} if tokenizer is not None else None
266267

268+
# TODO(wuhang):
269+
# Remove kwargs as parameters in the future.
270+
# Use dataclass directly.
267271
parallel_keys = [
268272
"tensor_parallel_size",
269273
"pipeline_parallel_size",
@@ -329,8 +333,6 @@ def _build_stage(idx_cfg: tuple[int, Any]) -> tuple[int, OmniStage]:
329333
idx, cfg = idx_cfg
330334
return idx, OmniStage(cfg, stage_init_timeout=stage_init_timeout)
331335

332-
logger.info(f"====== stage configs:\n{pformat(OmegaConf.to_container(self.stage_configs))}")
333-
334336
with ThreadPoolExecutor(max_workers=min(len(self.stage_configs), max(1, os.cpu_count() or 1))) as executor:
335337
futures = [executor.submit(_build_stage, (idx, cfg)) for idx, cfg in enumerate(self.stage_configs)]
336338
results: list[tuple[int, OmniStage]] = []
@@ -373,17 +375,33 @@ def _start_stages(self, model: str) -> None:
373375
if self.worker_backend != "ray":
374376
self._ensure_zmq_handshake_server()
375377

376-
base_port = int(self._zmq_master_port or 5555) + 1
378+
# Pre-allocate ports for all stages using dynamic port allocation
379+
stage_ports: dict[int, tuple[int, int]] = {}
377380
if self.worker_backend != "ray":
378-
self._zmq_handshake_specs = {}
379381
total_stages = self._total_stage_count or len(self.stage_list)
380382
for sid in range(total_stages):
381-
out_endpoint = f"tcp://{self._zmq_master_address}:{base_port + sid * 2 + 1}"
382-
self._zmq_handshake_specs[sid] = ZmqQueueSpec(
383+
in_port = get_open_port()
384+
out_port = get_open_port()
385+
stage_ports[sid] = (in_port, out_port)
386+
logger.debug(f"[{self._name}] Allocated ports for stage-{sid}: in={in_port}, out={out_port}")
387+
388+
# Build handshake specs with allocated ports
389+
self._zmq_handshake_specs = {}
390+
for sid in range(total_stages):
391+
in_port, out_port = stage_ports[sid]
392+
in_endpoint = f"tcp://{self._zmq_master_address}:{in_port}"
393+
out_endpoint = f"tcp://{self._zmq_master_address}:{out_port}"
394+
in_spec = ZmqQueueSpec(
395+
endpoint=in_endpoint,
396+
socket_type=zmq.PULL,
397+
bind=False,
398+
)
399+
out_spec = ZmqQueueSpec(
383400
endpoint=out_endpoint,
384401
socket_type=zmq.PUSH,
385402
bind=False,
386403
)
404+
self._zmq_handshake_specs[sid] = (in_spec, out_spec)
387405

388406
for stage_id, stage in enumerate[OmniStage](self.stage_list):
389407
if self.worker_backend == "ray":
@@ -392,8 +410,9 @@ def _start_stages(self, model: str) -> None:
392410
in_spec = None
393411
out_spec = None
394412
else:
395-
in_endpoint = f"tcp://{self._zmq_master_address}:{base_port + stage_id * 2}"
396-
out_endpoint = f"tcp://{self._zmq_master_address}:{base_port + stage_id * 2 + 1}"
413+
in_port, out_port = stage_ports[stage_id]
414+
in_endpoint = f"tcp://{self._zmq_master_address}:{in_port}"
415+
out_endpoint = f"tcp://{self._zmq_master_address}:{out_port}"
397416
in_q = ZmqQueue(self._zmq_ctx, zmq.PUSH, bind=in_endpoint)
398417
out_q = ZmqQueue(self._zmq_ctx, zmq.PULL, bind=out_endpoint)
399418
in_spec = ZmqQueueSpec(endpoint=in_endpoint, socket_type=zmq.PULL, bind=False)
@@ -635,12 +654,17 @@ def _serve() -> None:
635654
try:
636655
if isinstance(msg, dict) and msg.get("type") == "handshake":
637656
stage_id = int(msg.get("stage_id"))
638-
out_spec = self._zmq_handshake_specs.get(stage_id)
639-
if out_spec is None:
657+
specs = self._zmq_handshake_specs.get(stage_id)
658+
if specs is None:
640659
resp = {"ok": False, "error": f"unknown stage_id: {stage_id}"}
641660
else:
642661
self._zmq_handshake_seen.add(stage_id)
643-
resp = {"ok": True, "out_spec": out_spec}
662+
in_spec, out_spec = specs
663+
resp = {
664+
"ok": True,
665+
"in_spec": in_spec,
666+
"out_spec": out_spec,
667+
}
644668
logger.info(
645669
"[%s] Handshake received from stage-%s",
646670
self._name,

vllm_omni/entrypoints/omni_stage.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
ZmqQueue,
5454
ZmqQueueSpec,
5555
create_zmq_queue,
56-
request_zmq_out_spec,
5756
)
5857
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams, OmniTokensPrompt
5958
from vllm_omni.outputs import OmniRequestOutput
@@ -326,9 +325,6 @@ def init_stage_worker(
326325
"engine_input_source": self.engine_input_source,
327326
"final_output": self.final_output,
328327
"final_output_type": self.final_output_type,
329-
"zmq_master_address": self._zmq_master_address,
330-
"zmq_master_port": self._zmq_master_port,
331-
"zmq_use_handshake": self._zmq_use_handshake,
332328
}
333329
try:
334330
old_env = os.environ.get("VLLM_LOGGING_PREFIX")
@@ -576,21 +572,6 @@ def _stage_worker(
576572
if stage_type != "diffusion":
577573
_resolve_worker_cls(engine_args)
578574

579-
zmq_master_address = stage_payload.get("zmq_master_address")
580-
zmq_master_port = stage_payload.get("zmq_master_port")
581-
use_zmq_handshake = bool(stage_payload.get("zmq_use_handshake", False))
582-
583-
if use_zmq_handshake and zmq_master_address and zmq_master_port:
584-
try:
585-
master_endpoint = f"tcp://{zmq_master_address}:{int(zmq_master_port)}"
586-
out_q = request_zmq_out_spec(master_endpoint, stage_id)
587-
except Exception as e:
588-
logger.warning(
589-
"[Stage-%s] ZMQ handshake failed, falling back to provided out_q spec: %s",
590-
stage_id,
591-
e,
592-
)
593-
594575
# Resolve ZMQ queue specs if needed
595576
zmq_ctx = None
596577
if isinstance(in_q, ZmqQueueSpec) or isinstance(out_q, ZmqQueueSpec):
@@ -1145,24 +1126,10 @@ async def _stage_worker_async(
11451126
stage_type = stage_payload.get("stage_type", "llm")
11461127
final_output = stage_payload.get("final_output", False)
11471128
final_output_type = stage_payload.get("final_output_type", None)
1148-
zmq_master_address = stage_payload.get("zmq_master_address")
1149-
zmq_master_port = stage_payload.get("zmq_master_port")
1150-
use_zmq_handshake = bool(stage_payload.get("zmq_use_handshake", False))
11511129

11521130
if stage_type != "diffusion":
11531131
_resolve_worker_cls(engine_args)
11541132

1155-
if use_zmq_handshake and zmq_master_address and zmq_master_port:
1156-
try:
1157-
master_endpoint = f"tcp://{zmq_master_address}:{int(zmq_master_port)}"
1158-
out_q = request_zmq_out_spec(master_endpoint, stage_id)
1159-
except Exception as e:
1160-
logger.warning(
1161-
"[Stage-%s] ZMQ handshake failed, falling back to provided out_q spec: %s",
1162-
stage_id,
1163-
e,
1164-
)
1165-
11661133
# Resolve ZMQ queue specs if needed
11671134
zmq_ctx = None
11681135
if isinstance(in_q, ZmqQueueSpec) or isinstance(out_q, ZmqQueueSpec):

vllm_omni/entrypoints/zmq_utils.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ def request_zmq_out_spec(
9797
*,
9898
timeout_ms: int = 30000,
9999
) -> ZmqQueueSpec:
100-
"""Request the output queue spec for a stage via the master handshake."""
100+
"""Request the output queue spec for a stage via the master handshake.
101+
102+
Note: This function only returns out_spec for backward compatibility.
103+
Use request_zmq_specs() to get both in_spec and out_spec.
104+
"""
101105

102106
ctx = zmq.Context.instance()
103107
sock = ctx.socket(zmq.REQ)
@@ -124,3 +128,57 @@ def request_zmq_out_spec(
124128
if isinstance(out_spec, dict):
125129
return ZmqQueueSpec(**out_spec)
126130
raise RuntimeError(f"Invalid out_spec type: {type(out_spec)}")
131+
132+
133+
def _coerce_zmq_spec(name: str, value: object) -> ZmqQueueSpec:
    """Return *value* as a ``ZmqQueueSpec``.

    Args:
        name: Field name (``"in_spec"`` or ``"out_spec"``), used only to
            build the error message.
        value: Either a ready ``ZmqQueueSpec`` or a dict of its constructor
            keyword arguments.

    Raises:
        RuntimeError: If *value* is neither a ``ZmqQueueSpec`` nor a dict.
    """
    if isinstance(value, ZmqQueueSpec):
        return value
    if isinstance(value, dict):
        return ZmqQueueSpec(**value)
    raise RuntimeError(f"Invalid {name} type: {type(value)}")


def request_zmq_specs(
    master_endpoint: str,
    stage_id: int,
    *,
    timeout_ms: int = 30000,
) -> tuple[ZmqQueueSpec, ZmqQueueSpec]:
    """Request both input and output queue specs for a stage via the master handshake.

    Sends ``{"type": "handshake", "stage_id": <stage_id>}`` over a REQ socket
    connected to *master_endpoint* and interprets the reply dict.

    Args:
        master_endpoint: ZMQ endpoint of the handshake server,
            e.g. ``"tcp://127.0.0.1:5555"``.
        stage_id: Stage whose queue specs are requested.
        timeout_ms: Send/receive timeout applied to the socket, in milliseconds.

    Returns:
        tuple[ZmqQueueSpec, ZmqQueueSpec]: A tuple of (in_spec, out_spec)

    Raises:
        RuntimeError: If the reply is not a dict, reports ``ok`` as falsy,
            lacks ``in_spec``/``out_spec``, or carries a spec of an
            unsupported type.

    Note:
        Socket-level failures (including an expired timeout) propagate as
        whatever exception pyzmq raises; they are not wrapped here.
    """
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    try:
        # Configure and connect inside the ``try`` so the socket is released
        # even when ``connect`` rejects the endpoint.
        sock.linger = 0
        sock.rcvtimeo = int(timeout_ms)
        sock.sndtimeo = int(timeout_ms)
        sock.connect(master_endpoint)
        sock.send_pyobj({"type": "handshake", "stage_id": int(stage_id)})
        # SECURITY: recv_pyobj() unpickles the payload. Only point this at a
        # trusted orchestrator endpoint; pickle data from an untrusted peer
        # can execute arbitrary code.
        resp = sock.recv_pyobj()
    finally:
        # Best-effort cleanup: a failure while closing must not mask the
        # exception (if any) raised by the handshake itself.
        try:
            sock.close(0)
        except zmq.ZMQError:
            pass

    if not isinstance(resp, dict):
        raise RuntimeError(f"Invalid handshake response: {type(resp)}")

    # The handshake server answers an unknown stage with
    # {"ok": False, "error": ...}; surface that message instead of the less
    # helpful "missing specs". A reply without an "ok" key is still accepted.
    if not resp.get("ok", True):
        error_msg = resp.get("error", "unknown error")
        raise RuntimeError(f"Handshake failed for stage-{stage_id}: {error_msg}")

    in_spec_data = resp.get("in_spec")
    out_spec_data = resp.get("out_spec")
    if in_spec_data is None or out_spec_data is None:
        raise RuntimeError(f"Handshake response missing specs: {resp}")

    # Convert to ZmqQueueSpec if needed
    in_spec = _coerce_zmq_spec("in_spec", in_spec_data)
    out_spec = _coerce_zmq_spec("out_spec", out_spec_data)
    return in_spec, out_spec

0 commit comments

Comments
 (0)