Skip to content

[tx] add retries for getting random port used for JAX coordinator address#1652

Open
andrewsykim wants to merge 1 commit into
NovaSky-AI:mainfrom
andrewsykim:ray-tx
Open

[tx] add retries for getting random port used for JAX coordinator address#1652
andrewsykim wants to merge 1 commit into
NovaSky-AI:mainfrom
andrewsykim:ray-tx

Conversation

@andrewsykim
Copy link
Copy Markdown
Contributor

Follow up to #1418 (comment)

…ress

Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the _get_random_port function in skyrl/backends/ray_jax.py to include a retry loop that attempts to find an available port up to 10 times. The reviewer identifies several issues with this approach, noting that socket.bind(("", 0)) already requests an available port from the OS, making collisions at that specific line unlikely. Furthermore, the feedback points out a race condition where the port is released before JAX can use it and criticizes the lack of backoff in the retry loop, suggesting that the retry logic should instead wrap the JAX initialization process.

Comment thread skyrl/backends/ray_jax.py
Comment on lines +13 to +21
# try a few different ports in case another process is using randomly assigned port
for _ in range(10):
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
continue
raise RuntimeError("Could not allocate a free port")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The retry logic added here is unlikely to resolve the common "port in use" issues encountered during JAX initialization, and the accompanying comment is somewhat misleading.

  1. Misleading Comment: socket.bind(("", 0)) requests any available port from the OS. It will not return a port that is already in use. An OSError here typically indicates system-wide ephemeral port exhaustion or a network configuration issue, not a collision with a specific port.
  2. Race Condition: The typical failure mode occurs because the port is released when the with block ends (line 18) and is then hijacked by another process before JAX can bind to it in the setup method (which happens much later in the orchestration). Retrying inside _get_random_port does not protect against this window of vulnerability.
  3. Tight Loop: The loop retries immediately without any backoff or delay. If the system is indeed out of ports, a tight loop is inefficient and unlikely to succeed.

To effectively handle port collisions, the retry logic should ideally encompass the jax.distributed.initialize call or the orchestration step where the port is actually consumed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant