Labels: type:bug (Something isn't working)
Description
Hello! We are seeing the following failure pattern leading to timeouts in our medium-sized (512-node) GPU training runs:
- everyone finishes the train step around the same time and calls `manager.save` around the same time
- save involves slice operations; when JAX hits XLA ops in eager mode, it compiles them down to single-op graphs
- we get 10s-100s of slice calls, each going through compilation and autotuning
- compilation-cache access (in our case, and probably for others too) funnels through a single NFS path, so some ranks get starved waiting for file locks
- the starved ranks don't make it to the barrier in time and cause a timeout
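The contention step above hinges on all ranks sharing one persistent compilation-cache directory. As a hedged sketch of one way to sidestep the shared NFS locks (not something the issue proposes): `jax_compilation_cache_dir` is a real JAX config option, and pointing it at node-local scratch space is our assumption here, trading cross-node cache sharing for lock-free access.

```python
# Hedged sketch: route JAX's persistent compilation cache to a
# node-local directory so cache file locks don't all funnel through
# a single shared NFS path. Assumes local scratch space exists.
import tempfile

import jax

local_cache = tempfile.mkdtemp(prefix="jax_compilation_cache_")
jax.config.update("jax_compilation_cache_dir", local_cache)
```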
Example call stack from a starved rank:

```
Thread 3266587 (idle): "MainThread"
backend_compile_and_load (jax/_src/compiler.py:362)
wrapper (jax/_src/profiler.py:384)
_compile_and_write_cache (jax/_src/compiler.py:746)
compile_or_get_cached (jax/_src/compiler.py:478)
_cached_compilation (jax/_src/interpreters/pxla.py:2843)
from_hlo (jax/_src/interpreters/pxla.py:3066)
compile (jax/_src/interpreters/pxla.py:2515)
_pjit_call_impl_python (jax/_src/pjit.py:1207)
_run_python_pjit (jax/_src/pjit.py:140)
cache_miss (jax/_src/pjit.py:255)
reraise_with_filtered_traceback (jax/_src/traceback_util.py:197)
apply_primitive (jax/_src/dispatch.py:91)
_slice_impl (jax/_src/lax/slicing.py:1484)
process_primitive (jax/_src/core.py:1208)
bind_with_trace (jax/_src/core.py:664)
_true_bind (jax/_src/core.py:652)
bind (jax/_src/core.py:636)
slice (jax/_src/lax/slicing.py:113)
slice_in_dim (jax/_src/lax/slicing.py:1016)
data (orbax/checkpoint/_src/serialization/replica_slices.py:100)
async_transfer_slice (orbax/checkpoint/_src/serialization/replica_slices.py:444)
transfer_arrays_to_host (orbax/checkpoint/_src/serialization/replica_slices.py:471)
_serialize_arrays_batches_without_dispatcher (orbax/checkpoint/_src/serialization/jax_array_handlers.py:363)
_serialize_arrays (orbax/checkpoint/_src/serialization/jax_array_handlers.py:447)
serialize (orbax/checkpoint/_src/serialization/jax_array_handlers.py:1073)
_logging_serialize (orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:149)
```

Repro script:
```python
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_PLATFORMS"] = "cpu"

import tempfile

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

jax.config.update("jax_log_compiles", True)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 2), ("data", "model"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data", None))

pytree = {
    f"p{i}": jax.device_put(jnp.ones(s), sharding)
    for i, s in enumerate(
        [(8, 4), (16, 8), (32, 16), (64, 32), (128, 64), (24, 12), (48, 6)]
    )
}

with tempfile.TemporaryDirectory() as tmpdir:
    with ocp.CheckpointManager(tmpdir) as mgr:
        mgr.save(0, args=ocp.args.StandardSave(pytree))
```

For now we're working around this by monkeypatching Orbax so that it batches/jits the slices by mesh: https://gist.github.com/jfc4050/27b8817e497b70a467e13720b5de20af.
Maybe there's a way to keep JAX from compiling eagerly called slice ops. Otherwise, happy to open a PR if it would be helpful.
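For reference, the core of the batching idea can be sketched as follows. This is a hypothetical standalone helper (not the actual patched Orbax code from the gist): issuing all slices inside one jitted function means JAX compiles a single graph per input layout instead of one single-op graph per slice.

```python
# Sketch of the batching idea (hypothetical helper, not an Orbax API):
# issue all slices inside one jitted function so JAX compiles one
# graph per (shape, bounds) combination instead of one per slice.
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def batched_slices(x, bounds):
    """Return slices of `x` along axis 0, one per (start, limit) pair."""
    return tuple(
        jax.lax.slice_in_dim(x, start, limit, axis=0) for start, limit in bounds
    )


x = jnp.arange(32.0).reshape(8, 4)
# One compilation covers all three slices.
outs = batched_slices(x, ((0, 2), (2, 4), (4, 8)))
```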