orbax calls to jax.lax.slice_in_dim trigger compile storm #2934

@jfc4050

Description

hello! we are seeing the following failure pattern, which leads to timeouts in our medium-sized (512-node) training runs on GPUs:

  1. everyone finishes the train step around the same time and calls manager.save around the same time
  2. save involves slice operations; when jax hits XLA ops in eager mode, it compiles each one down to a single-op graph
  3. we get tens to hundreds of slice calls, each of which goes through compilation and autotuning
  4. compilation cache access (in our case, and probably for others too) funnels to a single NFS path, so some ranks get starved waiting for file locks
  5. the starved ranks don't make it to the barrier in time and cause a timeout
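As a minimal illustration of steps 2 and 3 (a sketch, not orbax's actual code): every eagerly executed lax op with a new output shape is dispatched as its own single-op XLA program, so each distinct slice shape pays a compile.

```python
import jax
import jax.numpy as jnp

# Log a message for every XLA compilation so the per-shape compiles are visible.
jax.config.update("jax_log_compiles", True)

x = jnp.ones((16, 8))

# Each distinct output shape dispatches its own single-op "slice" program,
# so three differently sized slices mean three separate compilations
# (plus autotuning on GPU backends).
slices = [jax.lax.slice_in_dim(x, 0, n, axis=0) for n in (4, 5, 6)]
```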

example call stack of a starved rank:

Thread 3266587 (idle): "MainThread"
    backend_compile_and_load (jax/_src/compiler.py:362)
    wrapper (jax/_src/profiler.py:384)
    _compile_and_write_cache (jax/_src/compiler.py:746)
    compile_or_get_cached (jax/_src/compiler.py:478)
    _cached_compilation (jax/_src/interpreters/pxla.py:2843)
    from_hlo (jax/_src/interpreters/pxla.py:3066)
    compile (jax/_src/interpreters/pxla.py:2515)
    _pjit_call_impl_python (jax/_src/pjit.py:1207)
    _run_python_pjit (jax/_src/pjit.py:140)
    cache_miss (jax/_src/pjit.py:255)
    reraise_with_filtered_traceback (jax/_src/traceback_util.py:197)
    apply_primitive (jax/_src/dispatch.py:91)
    _slice_impl (jax/_src/lax/slicing.py:1484)
    process_primitive (jax/_src/core.py:1208)
    bind_with_trace (jax/_src/core.py:664)
    _true_bind (jax/_src/core.py:652)
    bind (jax/_src/core.py:636)
    slice (jax/_src/lax/slicing.py:113)
    slice_in_dim (jax/_src/lax/slicing.py:1016)
    data (orbax/checkpoint/_src/serialization/replica_slices.py:100)
    async_transfer_slice (orbax/checkpoint/_src/serialization/replica_slices.py:444)
    transfer_arrays_to_host (orbax/checkpoint/_src/serialization/replica_slices.py:471)
    _serialize_arrays_batches_without_dispatcher (orbax/checkpoint/_src/serialization/jax_array_handlers.py:363)
    _serialize_arrays (orbax/checkpoint/_src/serialization/jax_array_handlers.py:447)
    serialize (orbax/checkpoint/_src/serialization/jax_array_handlers.py:1073)
    _logging_serialize (orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py:149)

repro script:

import os

# Simulate four devices on a CPU host so we can build a 2x2 mesh.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
os.environ["JAX_PLATFORMS"] = "cpu"

import tempfile

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

# Log every compilation; each distinct slice shape triggers one.
jax.config.update("jax_log_compiles", True)

mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(2, 2), ("data", "model"))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data", None))

# Several differently shaped sharded arrays; saving slices each one eagerly.
pytree = {
    f"p{i}": jax.device_put(jnp.ones(s), sharding)
    for i, s in enumerate(
        [(8, 4), (16, 8), (32, 16), (64, 32), (128, 64), (24, 12), (48, 6)]
    )
}

with tempfile.TemporaryDirectory() as tmpdir:
    with ocp.CheckpointManager(tmpdir) as mgr:
        mgr.save(0, args=ocp.args.StandardSave(pytree))

for now we're working around this by monkeypatching orbax so that it batches/jits the slices by mesh: https://gist.github.com/jfc4050/27b8817e497b70a467e13720b5de20af
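The core of that workaround can be sketched like this (hypothetical names, not the actual gist code): pass the slice specs as a static argument and run all slices inside one jitted call, so a single XLA program is compiled instead of one per slice.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical sketch of the batching idea: run every slice inside a single
# jitted function so one XLA program is compiled instead of one per slice.
# `specs` is a static (hashable) tuple of (start, limit, axis) per array.
@partial(jax.jit, static_argnums=1)
def batched_slices(arrays, specs):
    return tuple(
        jax.lax.slice_in_dim(a, start, limit, axis=axis)
        for a, (start, limit, axis) in zip(arrays, specs)
    )
```

Repeated calls with the same array shapes and specs then hit the jit cache instead of recompiling.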

Maybe there's a way to make jax avoid compiling eagerly called slice ops. Otherwise, happy to open a PR if it would be helpful.

Metadata

Labels: type:bug (Something isn't working)