Skip to content

Commit 4d3ddb2

Browse files
itayalroy authored and khairulkabir1661 committed
elastic_ep: Fix stateless group port races (vllm-project#36330)
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
1 parent 2152b36 commit 4d3ddb2

12 files changed

Lines changed: 224 additions & 225 deletions

File tree

.buildkite/test_areas/expert_parallelism.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ steps:
2424

2525
- label: Elastic EP Scaling Test
2626
timeout_in_minutes: 20
27-
device: b200
28-
optional: true
27+
device: h100
2928
working_dir: "/vllm-workspace/tests"
3029
num_devices: 4
3130
source_file_dependencies:

vllm/config/parallel.py

Lines changed: 31 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import os
5+
import socket
56
from collections.abc import Callable
67
from typing import TYPE_CHECKING, Any, Literal, overload
78

@@ -266,33 +267,9 @@ class is dynamically inherited by the worker class. This is used to inject
266267
Set to be private as it's not intended to be configured by users.
267268
"""
268269

269-
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
270-
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
271-
Set to be private as it's not intended to be configured by users.
272-
It is a list of list[int], with each inner list contains a set of 3 ports
273-
to be used for setting up the stateless CPU/device/TCPStore groups
274-
in StatelessGroupCoordinator. The number of inner lists is equal to
275-
the number of DP groups,
276-
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
277-
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
278-
"""
279-
280-
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
281-
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
282-
Set to be private as it's not intended to be configured by users.
283-
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
284-
"""
285-
286-
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
287-
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
288-
Same topology as EP but separate NCCL communicator to avoid deadlocks.
289-
"""
290-
291-
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
292-
"""List of open ports for stateless world group when enable_elastic_ep is True.
293-
Set to be private as it's not intended to be configured by users.
294-
len(self._stateless_world_group_port_list) == 1,
295-
"""
270+
_coord_store_port: int = 0
271+
"""Port of the coordination TCPStore. Can be set by the API server; workers
272+
connect as clients to exchange self-picked group ports at runtime."""
296273

297274
decode_context_parallel_size: int = 1
298275
"""Number of decode context parallel groups, because the world size does
@@ -465,65 +442,32 @@ def get_next_dp_init_port(self) -> int:
465442

466443
return answer
467444

468-
def allocate_elastic_ep_ports(self) -> None:
469-
"""Allocate all ports for elastic EP (stateless groups + DP master).
445+
def _pick_stateless_dp_port(self) -> tuple[int, socket.socket | None]:
446+
"""Return ``(port, listen_socket)`` for DP group init.
470447
471-
Must be called AFTER ray.init() so that ports claimed by Ray's
472-
idle worker pool are already in use and won't be returned by
473-
get_open_ports_list().
448+
With a coord store, rank 0 binds a socket and publishes the port;
449+
others read it. Without one, pops a pre-allocated port and
450+
returns ``listen_socket=None``.
474451
"""
475-
if not self.enable_elastic_ep:
476-
return
477-
if self._stateless_world_group_port_list:
478-
return
479-
480-
num_world_groups = 1
481-
dp_size = self.data_parallel_size
482-
ep_size = self.data_parallel_size * self.world_size_across_dp
483-
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
484-
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
485-
num_eplb_groups = num_ep_groups
486-
total_stateless_ports = (
487-
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
488-
) * 3
489-
num_dp_master_ports = 5
490-
491-
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
492-
493-
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
494-
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
495-
all_ports = all_ports[:-num_dp_master_ports]
496-
497-
self._stateless_world_group_port_list = [
498-
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
499-
]
500-
start_idx = num_world_groups * 3
501-
self._stateless_dp_group_port_list = [
502-
all_ports[i : i + 3]
503-
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
504-
]
505-
start_idx += num_dp_groups * 3
506-
self._stateless_ep_group_port_list = [
507-
all_ports[i : i + 3]
508-
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
509-
]
510-
start_idx += num_ep_groups * 3
511-
self._stateless_eplb_group_port_list = [
512-
all_ports[i : i + 3]
513-
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
514-
]
515-
516-
def get_next_stateless_world_group_port(self) -> list[int]:
517-
return self._stateless_world_group_port_list.pop()
518-
519-
def get_next_stateless_dp_group_port(self) -> list[int]:
520-
return self._stateless_dp_group_port_list.pop()
521-
522-
def get_next_stateless_ep_group_port(self) -> list[int]:
523-
return self._stateless_ep_group_port_list.pop()
524-
525-
def get_next_stateless_eplb_group_port(self) -> list[int]:
526-
return self._stateless_eplb_group_port_list.pop()
452+
if not self._coord_store_port:
453+
return self.get_next_dp_init_port(), None
454+
455+
from vllm.distributed.utils import get_cached_tcp_store_client
456+
457+
store = get_cached_tcp_store_client(
458+
self.data_parallel_master_ip, self._coord_store_port
459+
)
460+
461+
key = "dp_master_port"
462+
if self.data_parallel_rank == 0:
463+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
464+
s.bind((self.data_parallel_master_ip, 0))
465+
s.listen()
466+
port = s.getsockname()[1]
467+
store.set(key, str(port).encode())
468+
return port, s
469+
else:
470+
return int(store.get(key).decode()), None
527471

528472
@overload
529473
def stateless_init_dp_group(
@@ -553,14 +497,16 @@ def stateless_init_dp_group(
553497
last_exc: Exception | None = None
554498
for _ in range(max_retries):
555499
try:
500+
port, listen_socket = self._pick_stateless_dp_port()
556501
# use gloo since the engine process might not have cuda device
557502
return stateless_init_torch_distributed_process_group(
558503
self.data_parallel_master_ip,
559-
self.get_next_dp_init_port(),
504+
port,
560505
self.data_parallel_rank,
561506
self.data_parallel_size,
562507
backend="gloo",
563508
return_store=return_store,
509+
listen_socket=listen_socket,
564510
)
565511
except DistNetworkError as e:
566512
# We only want to retry when the root cause is EADDRINUSE.

vllm/distributed/elastic_ep/elastic_execute.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,8 @@ def create_standby_groups(
162162
new_dp_size=new_dp_size,
163163
new_world_size_across_dp=new_world_size_across_dp,
164164
master_ip=reconfig_request.new_data_parallel_master_ip,
165-
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
166-
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
167-
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
168-
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
165+
coord_store_port=reconfig_request.coord_store_port,
166+
enable_eplb=updated_config.parallel_config.enable_eplb,
169167
)
170168
self.worker.model_runner.eep_eplb_suppressed = True
171169
standby_ep_group = get_standby_ep_group()

vllm/distributed/elastic_ep/elastic_state.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -563,15 +563,4 @@ def _update_parallel_config(self):
563563
parallel_config._data_parallel_master_port_list = (
564564
reconfig_request.new_data_parallel_master_port_list
565565
)
566-
parallel_config._stateless_world_group_port_list = (
567-
reconfig_request.new_stateless_world_group_port_list
568-
)
569-
parallel_config._stateless_dp_group_port_list = (
570-
reconfig_request.new_stateless_dp_group_port_list
571-
)
572-
parallel_config._stateless_ep_group_port_list = (
573-
reconfig_request.new_stateless_ep_group_port_list
574-
)
575-
parallel_config._stateless_eplb_group_port_list = (
576-
reconfig_request.new_stateless_eplb_group_port_list
577-
)
566+
parallel_config._coord_store_port = reconfig_request.coord_store_port

vllm/distributed/elastic_ep/standby_state.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,8 @@ def create_standby_groups(
3838
new_dp_size: int,
3939
new_world_size_across_dp: int,
4040
master_ip: str,
41-
world_group_ports: list[list[int]],
42-
dp_group_ports: list[list[int]],
43-
ep_group_ports: list[list[int]],
44-
eplb_group_ports: list[list[int]] | None = None,
41+
coord_store_port: int,
42+
enable_eplb: bool = True,
4543
backend: str | None = None,
4644
) -> None:
4745
global \
@@ -51,19 +49,23 @@ def create_standby_groups(
5149
_STANDBY_EP, \
5250
_STANDBY_EPLB
5351

52+
from vllm.distributed.utils import get_cached_tcp_store_client
53+
5454
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
5555
world_group = get_world_group()
5656
assert isinstance(world_group, StatelessGroupCoordinator)
5757
backend = backend or world_group.backend
5858

59+
coord_store = get_cached_tcp_store_client(master_ip, coord_store_port)
60+
5961
standby_world_ranks = [list(range(new_world_size_across_dp))]
6062
_STANDBY_WORLD = _init_stateless_group(
6163
standby_world_ranks,
6264
"world",
63-
world_group_ports,
6465
master_ip,
6566
backend,
6667
use_device_communicator=False,
68+
coord_store=coord_store,
6769
)
6870
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
6971

@@ -76,20 +78,24 @@ def create_standby_groups(
7678
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
7779
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
7880
_STANDBY_DP = _init_stateless_group(
79-
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
81+
standby_dp_ranks, "dp", master_ip, backend, coord_store=coord_store
8082
)
8183

8284
standby_ep_ranks = (
8385
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
8486
)
8587
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
8688
_STANDBY_EP = _init_stateless_group(
87-
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
89+
standby_ep_ranks, "ep", master_ip, backend, coord_store=coord_store
8890
)
8991

90-
if eplb_group_ports is not None:
92+
if enable_eplb:
9193
_STANDBY_EPLB = _init_stateless_group(
92-
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
94+
standby_ep_ranks,
95+
"eplb",
96+
master_ip,
97+
backend,
98+
coord_store=coord_store,
9399
)
94100

95101

vllm/distributed/parallel_state.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,16 @@
4040
import torch.distributed
4141
import torch.distributed._functional_collectives as funcol
4242
import torch.distributed._symmetric_memory
43-
from torch.distributed import Backend, ProcessGroup
43+
from torch.distributed import Backend, ProcessGroup, Store
4444

4545
import vllm.envs as envs
4646
from vllm.distributed.device_communicators.base_device_communicator import (
4747
DeviceCommunicatorBase,
4848
)
49-
from vllm.distributed.utils import StatelessProcessGroup
49+
from vllm.distributed.utils import (
50+
StatelessProcessGroup,
51+
get_cached_tcp_store_client,
52+
)
5053
from vllm.logger import init_logger
5154
from vllm.utils.import_utils import resolve_obj_by_qualname
5255
from vllm.utils.network_utils import get_distributed_init_method
@@ -1164,9 +1167,9 @@ def init_model_parallel_group(
11641167
def _init_stateless_group(
11651168
group_ranks: list[list[int]],
11661169
group_name: str,
1167-
group_ports: list[list[int]],
11681170
host: str,
11691171
backend: str,
1172+
coord_store: Store,
11701173
use_device_communicator: bool = True,
11711174
) -> "StatelessGroupCoordinator":
11721175
"""Create a StatelessGroupCoordinator with the given parameters."""
@@ -1180,7 +1183,7 @@ def _init_stateless_group(
11801183
use_device_communicator=use_device_communicator,
11811184
group_name=group_name,
11821185
host=host,
1183-
group_ports=group_ports,
1186+
coord_store=coord_store,
11841187
global_rank=world.rank,
11851188
global_world_size=world.world_size,
11861189
)
@@ -1321,15 +1324,17 @@ def _init_elastic_ep_world(
13211324
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
13221325
if global_rank in all_ranks:
13231326
group_ranks = [all_ranks]
1324-
group_ports = [parallel_config.get_next_stateless_world_group_port()]
1327+
coord_store = get_cached_tcp_store_client(
1328+
parallel_config.data_parallel_master_ip, parallel_config._coord_store_port
1329+
)
13251330
world = StatelessGroupCoordinator(
13261331
group_ranks=group_ranks,
13271332
local_rank=local_rank,
13281333
torch_distributed_backend=backend,
13291334
use_device_communicator=False,
13301335
group_name="world",
13311336
host=parallel_config.data_parallel_master_ip,
1332-
group_ports=group_ports,
1337+
coord_store=coord_store,
13331338
global_rank=global_rank,
13341339
global_world_size=global_world_size,
13351340
)
@@ -1513,7 +1518,13 @@ def initialize_model_parallel(
15131518
config = get_current_vllm_config()
15141519
data_parallel_size = config.parallel_config.data_parallel_size
15151520
enable_elastic_ep = config.parallel_config.enable_elastic_ep
1521+
parallel_config = config.parallel_config
1522+
coord_store: Store | None = None
15161523
if enable_elastic_ep:
1524+
coord_store = get_cached_tcp_store_client(
1525+
parallel_config.data_parallel_master_ip,
1526+
parallel_config._coord_store_port,
1527+
)
15171528
# Use stateless world group for global information
15181529
world_size = get_world_group().world_size
15191530
rank = get_world_group().rank
@@ -1633,16 +1644,12 @@ def initialize_model_parallel(
16331644
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
16341645
group_ranks = [x.tolist() for x in group_ranks]
16351646
if enable_elastic_ep:
1636-
parallel_config = config.parallel_config
1637-
dp_ports = [
1638-
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
1639-
]
16401647
_DP = _init_stateless_group(
16411648
group_ranks,
16421649
"dp",
1643-
dp_ports,
16441650
parallel_config.data_parallel_master_ip,
16451651
backend,
1652+
coord_store=coord_store,
16461653
)
16471654
else:
16481655
_DP = init_model_parallel_group(
@@ -1665,16 +1672,12 @@ def initialize_model_parallel(
16651672
)
16661673
group_ranks = [x.tolist() for x in group_ranks]
16671674
if enable_elastic_ep:
1668-
parallel_config = config.parallel_config
1669-
ep_ports = [
1670-
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
1671-
]
16721675
_EP = _init_stateless_group(
16731676
group_ranks,
16741677
"ep",
1675-
ep_ports,
16761678
parallel_config.data_parallel_master_ip,
16771679
backend,
1680+
coord_store=coord_store,
16781681
)
16791682
else:
16801683
_EP = init_model_parallel_group(
@@ -1693,16 +1696,12 @@ def initialize_model_parallel(
16931696
and config.parallel_config.enable_eplb
16941697
):
16951698
if enable_elastic_ep:
1696-
eplb_ports = [
1697-
parallel_config.get_next_stateless_eplb_group_port()
1698-
for _ in group_ranks
1699-
]
17001699
_EPLB = _init_stateless_group(
17011700
group_ranks,
17021701
"eplb",
1703-
eplb_ports,
17041702
parallel_config.data_parallel_master_ip,
17051703
backend,
1704+
coord_store=coord_store,
17061705
)
17071706
else:
17081707
_EPLB = init_model_parallel_group(

0 commit comments

Comments (0)