Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion tests/v1/distributed/test_external_lb_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import socket
import threading
import time
from contextlib import AsyncExitStack
Expand All @@ -20,6 +21,29 @@
DP_SIZE = int(os.environ.get("DP_SIZE", "2"))
# Tensor parallel size used by default (override with the TP_SIZE env var)
TP_SIZE = int(os.environ.get("TP_SIZE", "1"))
# Export the CCL worker count so it matches the data parallel size
os.environ["CCL_WORKER_COUNT"] = f"{DP_SIZE}"


def is_port_available(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if a TCP socket can currently be bound to ``host:port``.

    The check binds to the port instead of connecting to it, which is more
    reliable than trying to connect. The probe socket is closed before
    returning, so the result is only a snapshot: another process may still
    grab the port afterwards.

    Args:
        port: TCP port number to probe.
        host: Local IPv4 address to bind on.

    Returns:
        True when the bind succeeds; False when it fails with ``OSError``
        or when ``port`` is outside the valid 0-65535 range.
    """
    # socket.bind() raises OverflowError (not OSError) for out-of-range
    # ports; report those as unavailable instead of crashing the caller.
    if not 0 <= port <= 65535:
        return False
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((host, port))
        except OSError:
            return False
    return True


def get_unique_port(start_port: int = 8000, search_range: int = 100) -> int:
    """Find an available port, scanning upwards from ``start_port``.

    Ports ``start_port`` through ``start_port + search_range`` (inclusive)
    are probed in order with ``is_port_available`` and the first free one
    is returned.

    Args:
        start_port: First port to try.
        search_range: How far above ``start_port`` the scan may go.

    Returns:
        The first port in the scanned range that is available.

    Raises:
        RuntimeError: If no port in the scanned range is available.
    """
    # Valid TCP ports end at 65535; never probe past that.
    highest_port = min(start_port + search_range, 65535)
    for port in range(start_port, highest_port + 1):
        if is_port_available(port):
            return port
    raise RuntimeError("No available ports")


class ExternalLBServerManager:
Expand All @@ -44,6 +68,14 @@ def __init__(

def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for external LB mode."""

allocated_ports = []
last_port = 7999
for _ in range(self.dp_size):
port = get_unique_port(start_port=last_port + 1)
allocated_ports.append(port)
last_port = port

for rank in range(self.dp_size):
# Create server args for this specific rank
server_args = self.base_server_args.copy()
Expand All @@ -60,7 +92,7 @@ def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + rank), # Different port for each rank
str(allocated_ports[rank]), # Different port for each rank
"--api-server-count",
str(self.api_server_count),
]
Expand Down