|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | 4 | import os |
| 5 | +import socket |
5 | 6 | from collections.abc import Callable |
6 | 7 | from typing import TYPE_CHECKING, Any, Literal, overload |
7 | 8 |
|
@@ -266,33 +267,9 @@ class is dynamically inherited by the worker class. This is used to inject |
266 | 267 | Set to be private as it's not intended to be configured by users. |
267 | 268 | """ |
268 | 269 |
|
269 | | - _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) |
270 | | - """List of open ports for stateless DP groups when enable_elastic_ep is True. |
271 | | - Set to be private as it's not intended to be configured by users. |
272 | | - It is a list of list[int], with each inner list contains a set of 3 ports |
273 | | - to be used for setting up the stateless CPU/device/TCPStore groups |
274 | | - in StatelessGroupCoordinator. The number of inner lists is equal to |
275 | | - the number of DP groups, |
276 | | - i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, |
277 | | - and len(self._stateless_dp_group_port_list[i]) == 3 for all i. |
278 | | - """ |
279 | | - |
280 | | - _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) |
281 | | - """List of open ports for stateless EP groups when enable_elastic_ep is True. |
282 | | - Set to be private as it's not intended to be configured by users. |
283 | | - len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, |
284 | | - """ |
285 | | - |
286 | | - _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list) |
287 | | - """List of open ports for stateless EPLB groups when enable_elastic_ep is True. |
288 | | - Same topology as EP but separate NCCL communicator to avoid deadlocks. |
289 | | - """ |
290 | | - |
291 | | - _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) |
292 | | - """List of open ports for stateless world group when enable_elastic_ep is True. |
293 | | - Set to be private as it's not intended to be configured by users. |
294 | | - len(self._stateless_world_group_port_list) == 1, |
295 | | - """ |
    _coord_store_port: int = 0
    """Port of the coordination TCPStore used for elastic EP setup. Defaults
    to 0, meaning no coordination store is available and pre-allocated port
    lists are used instead. Can be set by the API server; workers connect to
    this store as clients to exchange self-picked group ports at runtime.
    Set to be private as it's not intended to be configured by users."""
296 | 273 |
|
297 | 274 | decode_context_parallel_size: int = 1 |
298 | 275 | """Number of decode context parallel groups, because the world size does |
@@ -465,65 +442,32 @@ def get_next_dp_init_port(self) -> int: |
465 | 442 |
|
466 | 443 | return answer |
467 | 444 |
|
468 | | - def allocate_elastic_ep_ports(self) -> None: |
469 | | - """Allocate all ports for elastic EP (stateless groups + DP master). |
| 445 | + def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]: |
| 446 | + """Return ``(port, listen_socket)`` for DP group init. |
470 | 447 |
|
471 | | - Must be called AFTER ray.init() so that ports claimed by Ray's |
472 | | - idle worker pool are already in use and won't be returned by |
473 | | - get_open_ports_list(). |
| 448 | + With a coord store, rank 0 binds a socket and publishes the port; |
| 449 | + others read it. Without one, pops a pre-allocated port and |
| 450 | + returns ``listen_socket=None``. |
474 | 451 | """ |
475 | | - if not self.enable_elastic_ep: |
476 | | - return |
477 | | - if self._stateless_world_group_port_list: |
478 | | - return |
479 | | - |
480 | | - num_world_groups = 1 |
481 | | - dp_size = self.data_parallel_size |
482 | | - ep_size = self.data_parallel_size * self.world_size_across_dp |
483 | | - num_dp_groups = max(1, self.world_size_across_dp // dp_size) |
484 | | - num_ep_groups = max(1, self.world_size_across_dp // ep_size) |
485 | | - num_eplb_groups = num_ep_groups |
486 | | - total_stateless_ports = ( |
487 | | - num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups |
488 | | - ) * 3 |
489 | | - num_dp_master_ports = 5 |
490 | | - |
491 | | - all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) |
492 | | - |
493 | | - self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] |
494 | | - self.data_parallel_master_port = self._data_parallel_master_port_list.pop() |
495 | | - all_ports = all_ports[:-num_dp_master_ports] |
496 | | - |
497 | | - self._stateless_world_group_port_list = [ |
498 | | - all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) |
499 | | - ] |
500 | | - start_idx = num_world_groups * 3 |
501 | | - self._stateless_dp_group_port_list = [ |
502 | | - all_ports[i : i + 3] |
503 | | - for i in range(start_idx, start_idx + num_dp_groups * 3, 3) |
504 | | - ] |
505 | | - start_idx += num_dp_groups * 3 |
506 | | - self._stateless_ep_group_port_list = [ |
507 | | - all_ports[i : i + 3] |
508 | | - for i in range(start_idx, start_idx + num_ep_groups * 3, 3) |
509 | | - ] |
510 | | - start_idx += num_ep_groups * 3 |
511 | | - self._stateless_eplb_group_port_list = [ |
512 | | - all_ports[i : i + 3] |
513 | | - for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) |
514 | | - ] |
515 | | - |
516 | | - def get_next_stateless_world_group_port(self) -> list[int]: |
517 | | - return self._stateless_world_group_port_list.pop() |
518 | | - |
519 | | - def get_next_stateless_dp_group_port(self) -> list[int]: |
520 | | - return self._stateless_dp_group_port_list.pop() |
521 | | - |
522 | | - def get_next_stateless_ep_group_port(self) -> list[int]: |
523 | | - return self._stateless_ep_group_port_list.pop() |
524 | | - |
525 | | - def get_next_stateless_eplb_group_port(self) -> list[int]: |
526 | | - return self._stateless_eplb_group_port_list.pop() |
| 452 | + if not self._coord_store_port: |
| 453 | + return self.get_next_dp_init_port(), None |
| 454 | + |
| 455 | + from vllm.distributed.utils import get_cached_tcp_store_client |
| 456 | + |
| 457 | + store = get_cached_tcp_store_client( |
| 458 | + self.data_parallel_master_ip, self._coord_store_port |
| 459 | + ) |
| 460 | + |
| 461 | + key = "dp_master_port" |
| 462 | + if self.data_parallel_rank == 0: |
| 463 | + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 464 | + s.bind((self.data_parallel_master_ip, 0)) |
| 465 | + s.listen() |
| 466 | + port = s.getsockname()[1] |
| 467 | + store.set(key, str(port).encode()) |
| 468 | + return port, s |
| 469 | + else: |
| 470 | + return int(store.get(key).decode()), None |
527 | 471 |
|
528 | 472 | @overload |
529 | 473 | def stateless_init_dp_group( |
@@ -553,14 +497,16 @@ def stateless_init_dp_group( |
553 | 497 | last_exc: Exception | None = None |
554 | 498 | for _ in range(max_retries): |
555 | 499 | try: |
| 500 | + port, listen_socket = self._pick_stateless_dp_port() |
556 | 501 | # use gloo since the engine process might not have cuda device |
557 | 502 | return stateless_init_torch_distributed_process_group( |
558 | 503 | self.data_parallel_master_ip, |
559 | | - self.get_next_dp_init_port(), |
| 504 | + port, |
560 | 505 | self.data_parallel_rank, |
561 | 506 | self.data_parallel_size, |
562 | 507 | backend="gloo", |
563 | 508 | return_store=return_store, |
| 509 | + listen_socket=listen_socket, |
564 | 510 | ) |
565 | 511 | except DistNetworkError as e: |
566 | 512 | # We only want to retry when the root cause is EADDRINUSE. |
|
0 commit comments