129 changes: 62 additions & 67 deletions src/srtctl/backends/vllm.py
@@ -179,7 +179,9 @@ def _is_dp_mode(self, mode: WorkerMode) -> bool:
         """Check if this mode uses Data Parallel + Expert Parallel pattern.
 
         DP+EP mode is detected when data-parallel-size is set in the mode's config.
-        In this mode, each GPU runs its own process (rather than TP across GPUs).
+        We launch one dynamo.vllm process per node and let vLLM spawn the local
+        DP engines (hybrid_lb), so all GPUs on a node share a CUDA namespace and
+        symm-mem / NVLink P2P fast paths work.
         """
         config = self.get_config_for_mode(mode)
         return config.get("data-parallel-size") is not None or config.get("data_parallel_size") is not None
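
For reviewers skimming the detection rule: a minimal standalone sketch of the check above, assuming mode configs are plain dicts (the `is_dp_mode` helper below is illustrative, not the actual class method):

```python
def is_dp_mode(mode_config: dict) -> bool:
    # DP+EP is keyed purely off data-parallel-size being present;
    # both kebab-case and snake_case spellings are accepted.
    return (
        mode_config.get("data-parallel-size") is not None
        or mode_config.get("data_parallel_size") is not None
    )

assert is_dp_mode({"data-parallel-size": 16})
assert not is_dp_mode({"tensor-parallel-size": 8})
```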
@@ -196,8 +198,9 @@ def endpoints_to_processes(
     ) -> list[Process]:
         """Convert endpoints to processes.
 
-        For DP+EP mode (data-parallel-size set), creates one process per GPU.
-        For standard TP mode, creates one process per node.
+        For DP+EP mode (data-parallel-size set), creates one process per NODE
+        with all local GPUs visible — vLLM internally spawns the local DP engines
+        (hybrid_lb). For standard TP mode, also one process per node.
         """
         from srtctl.core.topology import NodePortAllocator, Process, endpoints_to_processes
 
@@ -208,71 +211,47 @@ def endpoints_to_processes(
         # Standard TP mode: one process per node
         return endpoints_to_processes(endpoints, base_sys_port=base_sys_port)
 
-        # DP+EP mode: one process per GPU
+        # DP+EP mode: one process per node
         processes: list[Process] = []
         current_sys_port = base_sys_port
         port_allocator = NodePortAllocator()
 
         for endpoint in endpoints:
-            if not self._is_dp_mode(endpoint.mode):
-                # Non-DP endpoints get standard processing
-                # (This shouldn't happen in practice since all modes should be consistent)
-                for node_rank, node in enumerate(endpoint.nodes):
-                    is_leader = node_rank == 0
-                    http_port = port_allocator.next_http_port(node) if is_leader else 0
-                    bootstrap_port = (
-                        port_allocator.next_bootstrap_port(node) if endpoint.mode == "prefill" and is_leader else None
-                    )
-                    kv_events_port = port_allocator.next_kv_events_port()
-                    nixl_port = port_allocator.next_nixl_port()
-
-                    processes.append(
-                        Process(
-                            node=node,
-                            gpu_indices=endpoint.gpu_indices,
-                            sys_port=current_sys_port,
-                            http_port=http_port,
-                            endpoint_mode=endpoint.mode,
-                            endpoint_index=endpoint.index,
-                            node_rank=node_rank,
-                            bootstrap_port=bootstrap_port,
-                            kv_events_port=kv_events_port,
-                            nixl_port=nixl_port,
-                        )
-                    )
-                    current_sys_port += 1
-            else:
-                # DP+EP mode: one process per GPU
-                # Each process gets a single GPU and a unique dp_rank
-                dp_rank = 0
-                for _node_rank, node in enumerate(endpoint.nodes):
-                    for gpu_idx in sorted(endpoint.gpu_indices):
-                        is_leader = dp_rank == 0
-                        http_port = port_allocator.next_http_port(node) if is_leader else 0
-                        bootstrap_port = (
-                            port_allocator.next_bootstrap_port(node)
-                            if endpoint.mode == "prefill" and is_leader
-                            else None
-                        )
-                        kv_events_port = port_allocator.next_kv_events_port()
-                        nixl_port = port_allocator.next_nixl_port()
-
-                        processes.append(
-                            Process(
-                                node=node,
-                                gpu_indices=frozenset([gpu_idx]),  # Single GPU per process
-                                sys_port=current_sys_port,
-                                http_port=http_port,
-                                endpoint_mode=endpoint.mode,
-                                endpoint_index=endpoint.index,
-                                node_rank=dp_rank,  # dp_rank stored in node_rank for now
-                                bootstrap_port=bootstrap_port,
-                                kv_events_port=kv_events_port,
-                                nixl_port=nixl_port,
-                            )
-                        )
-                        current_sys_port += 1
-                        dp_rank += 1
+            # local_dp_size = local DP ranks per node (= number of GPUs the
+            # endpoint claims on each node). Reserve that many KV-events ports
+            # because dynamo opens one ZMQ publisher per local rank inside the
+            # same dynamo.vllm process (port = base ± dp_rank).
+            local_dp_size = len(endpoint.gpu_indices)
+
+            for node_rank, node in enumerate(endpoint.nodes):
+                is_leader = node_rank == 0
+                http_port = port_allocator.next_http_port(node) if is_leader else 0
+                bootstrap_port = (
+                    port_allocator.next_bootstrap_port(node) if endpoint.mode == "prefill" and is_leader else None
+                )
+                kv_events_port = port_allocator.next_kv_events_port()
+                if self._is_dp_mode(endpoint.mode):
+                    # Reserve trailing slots so the next allocation can't collide
+                    # with this process's per-rank publishers on the same host.
+                    for _ in range(local_dp_size - 1):
+                        port_allocator.next_kv_events_port()
+                nixl_port = port_allocator.next_nixl_port()
+
+                processes.append(
+                    Process(
+                        node=node,
+                        gpu_indices=endpoint.gpu_indices,
+                        sys_port=current_sys_port,
+                        http_port=http_port,
+                        endpoint_mode=endpoint.mode,
+                        endpoint_index=endpoint.index,
+                        node_rank=node_rank,
+                        bootstrap_port=bootstrap_port,
+                        kv_events_port=kv_events_port,
+                        nixl_port=nixl_port,
+                    )
+                )
+                current_sys_port += 1
 
         return processes

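Why the trailing reservations above matter: dynamo publishes KV events on `base ± dp_rank`, so handing one base port to a node-level process implicitly claims a contiguous block of `local_dp_size` ports. A toy model of that arithmetic, with an assumed simple incrementing allocator and an assumed starting port (illustrative numbers, not srtctl's real defaults):

```python
def allocate_kv_events_bases(num_processes: int, local_dp_size: int, start: int = 20000) -> list[int]:
    """Toy allocator: hand out one base per process, then skip the
    slots its per-rank publishers will occupy (base .. base+local-1)."""
    bases, port = [], start
    for _ in range(num_processes):
        bases.append(port)
        port += 1                  # the base slot itself
        port += local_dp_size - 1  # trailing slots for the other local ranks
    return bases

# Two DP+EP processes on one host, 8 local ranks each: the second base
# starts past 20000..20007, so the publishers never collide.
assert allocate_kv_events_bases(2, 8) == [20000, 20008]
```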
@@ -349,21 +328,37 @@ def build_worker_command(
         is_dp_mode = self._is_dp_mode(mode)
 
         if is_dp_mode:
-            # DP+EP mode: each GPU runs its own process
-            # process.node_rank is the dp_rank (set in endpoints_to_processes)
-            dp_rank = process.node_rank
+            # Hybrid DP+EP: one dynamo.vllm process per node, vLLM spawns the
+            # local DP engines internally. process.node_rank is the node index
+            # within the endpoint (0 for leader, 1+ for workers).
+            #
+            # We deliberately do NOT pass --data-parallel-rank: vLLM auto-flips
+            # to external_lb the moment that flag is set and forces
+            # data_parallel_size_local=1, which puts each rank in its own CUDA
+            # namespace and breaks symm-mem rendezvous (overlapping cuda:0).
+            # See vllm/engine/arg_utils.py:1702-1717.
+            local_dp_size = len(process.gpu_indices)
+            start_rank = process.node_rank * local_dp_size
             dp_rpc_port = config.pop("data-parallel-rpc-port", None) or config.pop("data_parallel_rpc_port", 13345)
 
             cmd.extend(
                 [
-                    "--data-parallel-rank",
-                    str(dp_rank),
+                    "--data-parallel-hybrid-lb",
+                    "--data-parallel-size-local",
+                    str(local_dp_size),
+                    "--data-parallel-start-rank",
+                    str(start_rank),
                     "--data-parallel-address",
                     leader_ip,
                     "--data-parallel-rpc-port",
                     str(dp_rpc_port),
                 ]
             )
+            # We pass --data-parallel-hybrid-lb explicitly because vLLM's auto-detect
+            # (arg_utils.py:1721 `if self.data_parallel_start_rank and not headless`)
+            # uses Python truthiness, so a leader node with start_rank=0 silently
+            # falls out of hybrid_lb and rejects worker engines as "Remote engine N
+            # must use --headless unless in external or hybrid dp lb mode".
             # Note: --data-parallel-size is added via _config_to_cli_args from vllm_config
         elif is_multi_node:
             # Standard TP+PP multi-node coordination flags
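
The rank arithmetic and flag set in isolation, plus the truthiness pitfall the comments cite. This is a sketch under two assumptions: the flag names are exactly those in the diff, and the paraphrased auto-detect condition (`if self.data_parallel_start_rank and not headless`) matches the vLLM version pinned here; the helper itself is hypothetical, not srtctl code:

```python
def hybrid_lb_flags(node_rank: int, local_dp_size: int, leader_ip: str, rpc_port: int = 13345) -> list[str]:
    start_rank = node_rank * local_dp_size
    return [
        "--data-parallel-hybrid-lb",  # explicit: auto-detect misses start_rank=0
        "--data-parallel-size-local", str(local_dp_size),
        "--data-parallel-start-rank", str(start_rank),
        "--data-parallel-address", leader_ip,
        "--data-parallel-rpc-port", str(rpc_port),
    ]

leader = hybrid_lb_flags(node_rank=0, local_dp_size=8, leader_ip="10.0.0.1")
worker = hybrid_lb_flags(node_rank=1, local_dp_size=8, leader_ip="10.0.0.1")
assert leader[leader.index("--data-parallel-start-rank") + 1] == "0"
assert worker[worker.index("--data-parallel-start-rank") + 1] == "8"
# The pitfall in miniature: the leader's start_rank of 0 is falsy, so an
# auto-detect written as `if start_rank and not headless` skips the leader.
assert not (0 and True)
```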
89 changes: 39 additions & 50 deletions tests/test_configs.py
@@ -967,8 +967,8 @@ def test_dp_mode_detection(self):
         assert backend_dp._is_dp_mode("decode") is True
         assert backend_dp._get_dp_size("prefill") == 16
 
-    def test_dp_mode_creates_per_gpu_processes(self):
-        """Test that DP mode creates one process per GPU instead of per node."""
+    def test_dp_mode_creates_per_node_processes(self):
+        """Test that DP mode creates one process per node (hybrid_lb)."""
         from srtctl.backends import VLLMProtocol, VLLMServerConfig
         from srtctl.core.topology import Endpoint
@@ -989,31 +989,21 @@ def test_dp_mode_creates_per_gpu_processes(self):
 
         processes = backend.endpoints_to_processes([endpoint])
 
-        # Should create 16 processes (1 per GPU), not 2 (1 per node)
-        assert len(processes) == 16
-
-        # Each process should have exactly 1 GPU
-        for proc in processes:
-            assert len(proc.gpu_indices) == 1
-
-        # First 8 processes on node0, next 8 on node1
-        node0_processes = [p for p in processes if p.node == "node0"]
-        node1_processes = [p for p in processes if p.node == "node1"]
-        assert len(node0_processes) == 8
-        assert len(node1_processes) == 8
-
-        # GPU indices should be 0-7 on each node
-        node0_gpus = {list(p.gpu_indices)[0] for p in node0_processes}
-        node1_gpus = {list(p.gpu_indices)[0] for p in node1_processes}
-        assert node0_gpus == {0, 1, 2, 3, 4, 5, 6, 7}
-        assert node1_gpus == {0, 1, 2, 3, 4, 5, 6, 7}
-
-        # dp_rank (stored in node_rank) should go from 0 to 15
-        dp_ranks = [p.node_rank for p in processes]
-        assert dp_ranks == list(range(16))
+        # One process per node. vLLM spawns local DP engines inside each.
+        assert len(processes) == 2
+        assert [p.node for p in processes] == ["node0", "node1"]
+
+        # Each process owns the full local GPU set so the inner engines share
+        # a CUDA namespace (no per-process CUDA_VISIBLE_DEVICES restriction).
+        for proc in processes:
+            assert proc.gpu_indices == frozenset(range(8))
+
+        # node_rank is the node index within the endpoint (used to compute
+        # --data-parallel-start-rank in build_worker_command).
+        assert [p.node_rank for p in processes] == [0, 1]
 
     def test_dp_mode_command_includes_dp_flags(self):
-        """Test that DP mode command includes correct DP flags instead of TP flags."""
+        """Test that DP mode command uses hybrid_lb flags (size-local + start-rank)."""
         from pathlib import Path
         from unittest.mock import MagicMock, patch
 
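As a quick sanity check on the topology these tests encode (values taken from the tests themselves: data-parallel-size 16 across 2 nodes of 8 GPUs):

```python
nodes, gpus_per_node = 2, 8
data_parallel_size = 16
local_dp_size = gpus_per_node  # one local DP engine per GPU
assert data_parallel_size == nodes * local_dp_size
# Per-node start ranks, as build_worker_command derives them:
assert [n * local_dp_size for n in range(nodes)] == [0, 8]
```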
@@ -1030,43 +1020,31 @@ def test_dp_mode_command_includes_dp_flags(self):
             )
         )
 
-        # Create a process representing GPU 5 with dp_rank=5
+        # Worker node (node_rank=1) of a 2-node endpoint with 8 GPUs each.
+        # Hybrid_lb: one process per node owns the full local GPU set.
         process = Process(
-            node="node0",
-            gpu_indices=frozenset([5]),
-            sys_port=8081,
+            node="node1",
+            gpu_indices=frozenset(range(8)),
+            sys_port=8082,
             http_port=0,
             endpoint_mode="prefill",
             endpoint_index=0,
-            node_rank=5,  # dp_rank
+            node_rank=1,
         )
 
-        # Create endpoint_processes spanning 2 nodes
         endpoint_processes = [
             Process(
                 node="node0",
-                gpu_indices=frozenset([i]),
-                sys_port=8081 + i,
-                http_port=0,
-                endpoint_mode="prefill",
-                endpoint_index=0,
-                node_rank=i,
-            )
-            for i in range(8)
-        ] + [
-            Process(
-                node="node1",
-                gpu_indices=frozenset([i]),
-                sys_port=8089 + i,
-                http_port=0,
+                gpu_indices=frozenset(range(8)),
+                sys_port=8081,
+                http_port=30000,
                 endpoint_mode="prefill",
                 endpoint_index=0,
-                node_rank=8 + i,
-            )
-            for i in range(8)
+                node_rank=0,
+            ),
+            process,
         ]
 
-        # Mock runtime context
         mock_runtime = MagicMock()
         mock_runtime.model_path = Path("/model")
         mock_runtime.is_hf_model = False
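
One detail the fixtures above encode: only the endpoint leader carries a real HTTP port (the worker gets 0), matching endpoints_to_processes, which allocates http_port only when node_rank == 0. A minimal restatement of that invariant (port values taken from the fixtures):

```python
fixtures = [
    {"node": "node0", "node_rank": 0, "http_port": 30000},  # leader
    {"node": "node1", "node_rank": 1, "http_port": 0},      # worker
]
for p in fixtures:
    # exactly the leader, and only the leader, exposes HTTP
    assert (p["http_port"] != 0) == (p["node_rank"] == 0)
```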
@@ -1078,17 +1056,25 @@
             runtime=mock_runtime,
         )
 
-        # Should include DP flags
-        assert "--data-parallel-rank" in cmd
-        assert "5" in cmd  # dp_rank = 5
+        # hybrid_lb flags. We pass --data-parallel-hybrid-lb explicitly because
+        # vLLM's auto-detect uses truthy on start_rank, so leader (start_rank=0)
+        # would silently fall out of hybrid_lb otherwise.
+        assert "--data-parallel-hybrid-lb" in cmd
+        assert "--data-parallel-size-local" in cmd
+        i = cmd.index("--data-parallel-size-local")
+        assert cmd[i + 1] == "8"  # 8 local DP ranks per node
+        assert "--data-parallel-start-rank" in cmd
+        i = cmd.index("--data-parallel-start-rank")
+        assert cmd[i + 1] == "8"  # node_rank=1 → start_rank = 1 * 8
         assert "--data-parallel-address" in cmd
         assert "10.0.0.1" in cmd
         assert "--data-parallel-rpc-port" in cmd
         assert "13345" in cmd
         assert "--data-parallel-size" in cmd
         assert "16" in cmd
 
-        # Should NOT include TP multi-node flags
+        # external_lb / TP multi-node flags must not leak in
+        assert "--data-parallel-rank" not in cmd
         assert "--master-addr" not in cmd
         assert "--nnodes" not in cmd
         assert "--node-rank" not in cmd
@@ -1246,6 +1232,9 @@ def test_tp_mode_command_includes_multinode_flags(self):
 
         # Should NOT include DP flags
         assert "--data-parallel-rank" not in cmd
+        assert "--data-parallel-hybrid-lb" not in cmd
+        assert "--data-parallel-size-local" not in cmd
+        assert "--data-parallel-start-rank" not in cmd
         assert "--data-parallel-address" not in cmd
 
     def test_tp_mode_leader_not_headless(self):