Skip to content

Commit d600675

Browse files
[iris] Make initialize_jax idempotent (#6320)
initialize_jax now returns early when jax.distributed.is_initialized(), so calling it after JAX has already initialized the XLA backend is a no-op instead of raising. JAX 0.9+ added a backends_are_initialized() check inside jax.distributed.initialize() that turns this ordering into a hard error; pre-0.9 silently no-op'd. Lets callers initialize explicitly before levanter's DistributedConfig.initialize().
1 parent 1d065b5 commit d600675

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

lib/iris/src/iris/runtime/jax_init.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,17 @@ def initialize_jax(
123123
"""
124124
import jax # noqa: PLC0415 # optional dep: jax (iris does not depend on jax)
125125

126+
# Idempotent: if the caller already initialized jax.distributed (either by
127+
# calling us explicitly first, or because user code touched JAX before our
128+
# call site — JAX 0.9+ raises on a second `jax.distributed.initialize()`
129+
# via the `xla_bridge.backends_are_initialized()` check), just return.
130+
# Callers that touch JAX before levanter.initialize (e.g. via `hax.named`
131+
# → `jnp.asarray` building loss-config args) can call `initialize_jax()`
132+
# themselves first; levanter's later call lands here and is a no-op.
133+
if jax.distributed.is_initialized():
134+
logger.info("jax.distributed already initialized; skipping")
135+
return
136+
126137
# TPU has its own coordinator discovery via the TPU runtime, so avoid the
127138
# Iris endpoint dance. We still call JAX distributed initialization to
128139
# create the host-side distributed client used by Levanter multihost

0 commit comments

Comments
 (0)