Multi-host Checkpointing Error #999

Open
@ssenan

Hi Everyone,

I've been trying to checkpoint training using Orbax in a project (linked here). When I test the code locally I'm able to checkpoint successfully, but when training on a TPU v4-32 VM I encounter an error about directories not being found.

I've put together a simpler example using code from the Orbax docs, which outputs a similar error.

import jax
import numpy as np
import orbax.checkpoint as ocp
from jax.experimental import mesh_utils


def test_checkpointing():
    jax.distributed.initialize()

    if jax.process_index() == 0:
        print("Number of devices: ", jax.device_count())
        print("Local devices: ", jax.local_device_count())

    devices = mesh_utils.create_device_mesh((jax.device_count(),))
    mesh = jax.sharding.Mesh(devices, ("data",))
    sharding = jax.NamedSharding(
        mesh,
        jax.sharding.PartitionSpec(),
    )

    create_sharded_array = lambda x: jax.device_put(x, sharding)
    state = {
        "a": np.arange(16),
        "b": np.ones(16),
    }
    state = jax.tree_util.tree_map(create_sharded_array, state)
    abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)
    print(jax.tree_util.tree_map(lambda x: x.shape, state))

    path = ocp.test_utils.erase_and_create_empty("/tmp/checkpoint")
    
    global_metadata = {'global_property': 'foo'}
    with ocp.CheckpointManager(path, item_names=("state", "custom_data"), metadata=global_metadata) as mngr:
        mngr.save(
            0,
            args=ocp.args.Composite(
                state=ocp.args.PyTreeSave(state),
                custom_data=ocp.args.JsonSave({"lang": "en", "version": 1.2}),
            ),
        )

    print("Checkpoint saved!")


if __name__ == "__main__":
    test_checkpointing()

This appears to succeed on process index 0 but fails on the rest of the hosts.

Here is the error I see:

Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 50, in <module>
    test_checkpointing()
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 35, in test_checkpointing
    ckptr.save(0, state)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1110, in save
    self._checkpointer.save(save_directory, args=args)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py", line 328, in save
    commit_ops = asyncio.run(self._handler.async_save(tmpdir, args=ckpt_args))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 646, in run_until_complete
    return future.result()
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 358, in async_save
    path.mkdir(parents=False, exist_ok=True)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/gpath.py", line 205, in mkdir
    self._backend.mkdir(self._path_str, exist_ok=exist_ok, mode=mode)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/backend.py", line 180, in mkdir
    os.mkdir(path, mode=mode)
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/checkpoint/0.orbax-checkpoint-tmp-0/default'
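
For reference, the last frames show `path.mkdir(parents=False, exist_ok=True)` ending in `os.mkdir`. With `parents=False`, that call raises `FileNotFoundError` whenever the parent directory is missing, so the error would be consistent with `0.orbax-checkpoint-tmp-0` not existing on the failing host (that is an assumption, not something the traceback proves). A minimal sketch of the same behavior using only the standard library; the helper and directory names are hypothetical and nothing here calls Orbax:

```python
import os
import tempfile


def mkdir_no_parents(path):
    """Mimic mkdir(parents=False, exist_ok=True) with the standard library."""
    try:
        os.mkdir(path)
    except FileExistsError:
        # exist_ok=True semantics: only tolerate an existing directory.
        if not os.path.isdir(path):
            raise


with tempfile.TemporaryDirectory() as root:
    # Hypothetical stand-in for a temporary checkpoint dir that was never
    # created on this host.
    missing_parent = os.path.join(root, "tmp-dir-not-created-here")
    child = os.path.join(missing_parent, "default")
    try:
        mkdir_no_parents(child)
        outcome = "created"
    except FileNotFoundError:
        outcome = "FileNotFoundError"
    print(outcome)
```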

Finally, here's the command I normally use to install all my dependencies on the VM:

pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html  flax optax pandas numpy scipy wandb tqdm orbax-checkpoint gcsfs

Is this issue related to my sharding and the directory not being created on all of the hosts, or is it something on the Orbax end? Any assistance is greatly appreciated!
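
To narrow down the first possibility, a small check along these lines could be run on every host to see whether the checkpoint directory is visible there. This is only a sketch using the standard library; `describe_path` is a hypothetical helper, not part of Orbax or JAX, and `/tmp/checkpoint` is the path from the example above.

```python
import os
import socket


def describe_path(path):
    """Report whether `path` is visible as a directory from the current host."""
    return {
        "host": socket.gethostname(),
        "path": path,
        "is_dir": os.path.isdir(path),
        "writable": os.access(path, os.W_OK),
    }


# Run this on each host and compare the results across hosts.
print(describe_path("/tmp/checkpoint"))
```

If `is_dir` differs between hosts, the directory is local to each machine rather than shared.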
