Commit 8935518

github-actions[bot], rjpower, and claude committed
fix: register endpoint before jax.distributed.initialize to avoid deadlock
jax.distributed.initialize() blocks until all processes connect, so registering the endpoint after calling it would deadlock: tasks 1..N-1 would never discover the coordinator address. JAX's internal gRPC retry handles the brief window between registration and the coordinator starting to listen.

Co-authored-by: Russell Power <rjpower@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 3ccc7d0 commit 8935518
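
The ordering problem described in the commit message can be reproduced without JAX. The following is a minimal sketch, not the project's code: a `threading.Barrier` stands in for the blocking `jax.distributed.initialize()` call, a plain dict stands in for the endpoint registry, and the function name `simulate`, the address `host:1234`, and the timeout values are all made up for illustration.

```python
import threading
import time


def simulate(register_first: bool, num_tasks: int = 3, timeout: float = 0.5) -> bool:
    """Return True if every task gets through the blocking initialize step."""
    registry: dict[str, str] = {}            # hypothetical stand-in for the endpoint registry
    barrier = threading.Barrier(num_tasks)   # stand-in for "blocks until all processes connect"
    outcomes: list[bool] = []
    outcomes_lock = threading.Lock()

    def record(ok: bool) -> None:
        with outcomes_lock:
            outcomes.append(ok)

    def blocking_initialize() -> bool:
        # Succeeds only if all num_tasks callers arrive before the timeout.
        try:
            barrier.wait(timeout=timeout)
            return True
        except threading.BrokenBarrierError:
            return False

    def coordinator_task() -> None:
        if register_first:
            # New order: advertise the address, then block.
            registry["jax"] = "host:1234"
            record(blocking_initialize())
        else:
            # Old order: block first. Nobody can find the address, so this times out.
            ok = blocking_initialize()
            registry["jax"] = "host:1234"
            record(ok)

    def worker_task() -> None:
        # Workers poll the registry before they can join the barrier.
        deadline = time.monotonic() + timeout
        while "jax" not in registry:
            if time.monotonic() >= deadline:
                record(False)
                return
            time.sleep(0.01)
        record(blocking_initialize())

    threads = [threading.Thread(target=coordinator_task)]
    threads += [threading.Thread(target=worker_task) for _ in range(num_tasks - 1)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(outcomes)


if __name__ == "__main__":
    print("register first:", simulate(register_first=True))
    print("initialize first:", simulate(register_first=False))
```

With the old order the simulation only terminates because of the artificial timeout. Whether a real job would hang or eventually fail depends on timeouts configured elsewhere, which this sketch does not model.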

1 file changed (+6, -3)

lib/iris/src/iris/runtime/jax_init.py

Lines changed: 6 additions & 3 deletions
@@ -96,13 +96,16 @@ def initialize_jax(
     if task_index == 0:
         bound_port = job_info.ports.get("jax", port)
         coordinator = f"{job_info.advertise_host}:{bound_port}"
-        # Initialize first so the gRPC coordinator is listening before we
-        # advertise the address to other tasks via the endpoint registry.
-        jax.distributed.initialize(coordinator, job_info.num_tasks, task_index)
+        # Register the endpoint first so other tasks can discover the
+        # coordinator address. jax.distributed.initialize() blocks until
+        # all processes connect, so registering after would deadlock.
+        # JAX's internal gRPC retry handles the brief window between
+        # endpoint registration and the coordinator starting to listen.
         endpoint_id = ctx.registry.register(endpoint_name, coordinator)
         # Best-effort cleanup: if the process crashes, the controller's
         # cascade delete on task cleanup handles endpoint removal.
         atexit.register(ctx.registry.unregister, endpoint_id)
+        jax.distributed.initialize(coordinator, job_info.num_tasks, task_index)
     else:
         coordinator = _poll_for_coordinator(ctx.resolver, endpoint_name, poll_timeout, poll_interval)
         jax.distributed.initialize(coordinator, job_info.num_tasks, task_index)
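
The non-coordinator branch relies on `_poll_for_coordinator`, whose body is not part of this diff. As a rough illustration of what a poll-with-timeout helper of that kind might look like, here is a sketch under stated assumptions: the name `poll_for_address`, the `lookup` callable (returning `None` until the endpoint is registered), and the choice of `TimeoutError` are all hypothetical and do not reflect the actual iris resolver API.

```python
import time
from typing import Callable, Optional


def poll_for_address(
    lookup: Callable[[str], Optional[str]],  # hypothetical resolver: name -> address or None
    endpoint_name: str,
    timeout: float,
    interval: float,
) -> str:
    """Poll `lookup` until it returns an address or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while True:
        address = lookup(endpoint_name)
        if address is not None:
            return address
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"endpoint {endpoint_name!r} not registered within {timeout}s"
            )
        time.sleep(interval)


if __name__ == "__main__":
    # Toy registry that is already populated, as it would be once task 0 has registered.
    registry = {"jax": "10.0.0.1:8476"}
    print(poll_for_address(registry.get, "jax", timeout=1.0, interval=0.05))
```

A helper like this only returns once the endpoint exists, which is why task 0 has to register before it enters the blocking `jax.distributed.initialize()` call.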
