Skip to content

Commit afa6b6c

Browse files
pengwu22Superjomn
authored andcommitted
[ci] fix: occasional CI failures caused by sglang server port conflicts (verl-project#5310)
### What does this PR do? Fix CI failures caused by occasional port conflicts ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) - [x] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
1 parent deafcef commit afa6b6c

File tree

6 files changed

+50
-30
lines changed

6 files changed

+50
-30
lines changed

verl/checkpoint_engine/hccl_checkpoint_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada
165165

166166
def _start_zmq_server(self):
167167
self.ip = ray.util.get_node_ip_address().strip("[]")
168-
self.zmq_port, self.listen_sock = get_free_port(self.ip)
168+
self.zmq_port, _ = get_free_port(self.ip)
169169

170170
context = zmq.Context()
171171
self.socket = context.socket(zmq.PUB)

verl/checkpoint_engine/nccl_checkpoint_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def build_topology(cls, trainer_world_size: int, rollout_world_size: int, metada
164164

165165
def _start_zmq_server(self):
166166
self.ip = ray.util.get_node_ip_address().strip("[]")
167-
self.listen_port, self.listen_sock = get_free_port(self.ip)
167+
self.listen_port, _ = get_free_port(self.ip)
168168

169169
context = zmq.Context()
170170
self.socket = context.socket(zmq.PUB)

verl/checkpoint_engine/nixl_checkpoint_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def get_agent_metadata(self) -> NixlAgentMetadata:
8282

8383
def start_zmq_server(self):
8484
self.ip = ray.util.get_node_ip_address().strip("[]")
85-
self.listen_port, self.listen_sock = get_free_port(self.ip)
85+
self.listen_port, _ = get_free_port(self.ip)
8686

8787
context = zmq.asyncio.Context()
8888
self.socket = context.socket(zmq.PULL)

verl/utils/net_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,22 @@ def is_valid_ipv6_address(address: str) -> bool:
7070
return False
7171

7272

73-
def get_free_port(address: str) -> tuple[int, socket.socket]:
74-
family = socket.AF_INET
75-
if is_valid_ipv6_address(address):
76-
family = socket.AF_INET6
73+
def get_free_port(address: str, with_alive_sock: bool = False) -> tuple[int, socket.socket | None]:
74+
"""Find a free port on the given address.
75+
76+
By default the socket is closed internally, suitable for immediate use.
77+
Set with_alive_sock=True to keep the socket open as a port reservation,
78+
preventing other calls from getting the same port. The caller is
79+
responsible for closing the socket before the port is actually bound
80+
by the target service (e.g. NCCL, uvicorn).
81+
"""
82+
family = socket.AF_INET6 if is_valid_ipv6_address(address) else socket.AF_INET
7783

7884
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
7985
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
80-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
8186
sock.bind((address, 0))
82-
8387
port = sock.getsockname()[1]
84-
return port, sock
88+
if with_alive_sock:
89+
return port, sock
90+
sock.close()
91+
return port, None

verl/workers/rollout/sglang_rollout/async_sglang_server.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,19 @@ def __init__(
123123
profiler_config = None
124124
self.profiler_controller = DistProfiler(self.replica_rank, config=profiler_config, tool_config=tool_config)
125125

126-
# used for NCCL process group
127-
if self.node_rank == 0:
126+
# For multi-node, we need dist_init_addr so nodes can coordinate NCCL init.
127+
# For single-node, let SGLang handle port selection internally via nccl_port,
128+
# which also avoids port conflicts.
129+
self._master_address = None
130+
self._master_port = None
131+
self._master_sock = None
132+
if self.nnodes > 1 and self.node_rank == 0:
128133
self._master_address = self._server_address
129-
self._master_port, self._master_sock = get_free_port(self._server_address)
134+
self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
130135
logger.info(
131136
f"SGLangHttpServer, replica_rank: {self.replica_rank}, "
132137
f"master address: {self._master_address}, port: {self._master_port}"
133138
)
134-
else:
135-
self._master_address = None
136-
self._master_port = None
137139

138140
def get_master_address(self):
139141
"""Get master address and port for init NCCL process group."""
@@ -145,10 +147,13 @@ def get_server_address(self):
145147
return self._server_address, self._server_port
146148

147149
async def launch_server(self, master_address: str = None, master_port: int = None):
148-
if self.node_rank != 0:
149-
assert master_address and master_port, "non-master node should provide master address and port"
150-
self._master_address = master_address
151-
self._master_port = master_port
150+
if self.nnodes > 1:
151+
if self.node_rank != 0:
152+
assert master_address and master_port, "non-master node should provide master address and port"
153+
self._master_address = master_address
154+
self._master_port = master_port
155+
else:
156+
self._master_sock.close()
152157

153158
engine_kwargs = self.config.get("engine_kwargs", {}).get("sglang", {}) or {}
154159
attention_backend = engine_kwargs.pop("attention_backend", None)
@@ -167,11 +172,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
167172
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
168173
else:
169174
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")
170-
dist_init_addr = (
171-
f"[{self._master_address}]:{self._master_port}"
172-
if is_valid_ipv6_address(self._master_address)
173-
else f"{self._master_address}:{self._master_port}"
174-
)
175175
infer_tp = self.config.tensor_model_parallel_size * self.config.data_parallel_size
176176
args = {
177177
"model_path": self.model_config.local_path,
@@ -186,7 +186,6 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
186186
"ep_size": self.config.expert_parallel_size,
187187
"node_rank": self.node_rank,
188188
"load_format": self.config.load_format,
189-
"dist_init_addr": dist_init_addr,
190189
"nnodes": self.nnodes,
191190
"trust_remote_code": self.model_config.trust_remote_code,
192191
"max_running_requests": self.config.get("max_num_seqs", None),
@@ -202,6 +201,16 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
202201
**engine_kwargs,
203202
}
204203

204+
# Only set dist_init_addr for multi-node; for single-node, let SGLang
205+
# handle port selection internally via nccl_port to avoid conflicts.
206+
if self.nnodes > 1:
207+
dist_init_addr = (
208+
f"[{self._master_address}]:{self._master_port}"
209+
if is_valid_ipv6_address(self._master_address)
210+
else f"{self._master_address}:{self._master_port}"
211+
)
212+
args["dist_init_addr"] = dist_init_addr
213+
205214
if self.config.prometheus.enable:
206215
if self.config.prometheus.served_model_name:
207216
# Extract model name from path if it's a full path
@@ -510,7 +519,9 @@ async def launch_servers(self):
510519
self.servers.append(server)
511520

512521
# launch http server in each node
513-
master_address, master_port = await self.servers[0].get_master_address.remote()
522+
master_address, master_port = None, None
523+
if self.nnodes > 1:
524+
master_address, master_port = await self.servers[0].get_master_address.remote()
514525
await asyncio.gather(
515526
*[
516527
server.launch_server.remote(master_address=master_address, master_port=master_port)

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ def __init__(
153153
if self.node_rank == 0:
154154
self._master_address = self._server_address
155155
# used for torch.distributed.init_process_group
156-
self._master_port, self._master_sock = get_free_port(self._server_address)
156+
self._master_port, self._master_sock = get_free_port(self._server_address, with_alive_sock=True)
157157
# used for data parallel: --data-parallel-address, --data-parallel-rpc-port
158-
self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address)
159-
self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address)
158+
self._dp_rpc_port, self._dp_rpc_sock = get_free_port(self._server_address, with_alive_sock=True)
159+
self._dp_master_port, self._dp_master_sock = get_free_port(self._server_address, with_alive_sock=True)
160160
else:
161161
self._master_address = None
162162
self._master_port = None
@@ -424,6 +424,8 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
424424
# 3. launch server
425425
if self.node_rank == 0:
426426
self._master_sock.close()
427+
self._dp_rpc_sock.close()
428+
self._dp_master_sock.close()
427429
await self.run_server(server_args)
428430
else:
429431
# TODO: avoid connect before master_sock close

0 commit comments

Comments
 (0)