Open
Description
Hi Everyone,
I've been trying to checkpoint training using Orbax in a project (linked here). When I test the code locally I'm able to checkpoint successfully, but when training on a TPU v4-32 VM I encounter an issue related to directories not being found.
I've put together a simpler example using code from the Orbax docs, which outputs a similar error.
import jax
import numpy as np
import orbax.checkpoint as ocp
from jax.experimental import mesh_utils
def test_checkpointing(checkpoint_dir: str = "/tmp/checkpoint") -> None:
    """Save a single Orbax checkpoint from every process of a JAX job.

    Initializes the JAX distributed runtime, places a small two-array state
    on a 1-D device mesh spanning all devices, and saves it as step 0
    (together with a JSON item and global metadata) through
    ``ocp.CheckpointManager``.

    Args:
        checkpoint_dir: Directory the checkpoint is written to. It is erased
            and recreated before saving. Defaults to ``"/tmp/checkpoint"``.

    NOTE(review): Orbax's multi-host save appears to expect ``checkpoint_dir``
    to be visible to every host (for example a ``gs://`` bucket or an NFS
    mount). With a host-local path such as the default, hosts other than the
    one that created the temporary step directory can fail with
    ``FileNotFoundError``. Confirm against the Orbax documentation.
    """
    jax.distributed.initialize()

    # Report the topology once instead of once per host.
    if jax.process_index() == 0:
        print("Number of devices: ", jax.device_count())
        print("Local devices: ", jax.local_device_count())

    devices = mesh_utils.create_device_mesh((jax.device_count(),))
    mesh = jax.sharding.Mesh(devices, ("data",))
    # An empty PartitionSpec assigns no array dimension to the "data" axis,
    # so every array is replicated across the mesh rather than split.
    sharding = jax.NamedSharding(
        mesh,
        jax.sharding.PartitionSpec(),
    )

    state = {
        "a": np.arange(16),
        "b": np.ones(16),
    }
    state = jax.tree_util.tree_map(
        lambda x: jax.device_put(x, sharding), state
    )
    print(jax.tree_util.tree_map(lambda x: x.shape, state))

    # NOTE(review): every process runs this erase-and-recreate; if
    # checkpoint_dir is shared between hosts, confirm that this cannot race
    # with another host's save.
    path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
    global_metadata = {'global_property': 'foo'}
    with ocp.CheckpointManager(
        path,
        item_names=("state", "custom_data"),
        metadata=global_metadata,
    ) as mngr:
        mngr.save(
            0,
            args=ocp.args.Composite(
                state=ocp.args.PyTreeSave(state),
                custom_data=ocp.args.JsonSave({"lang": "en", "version": 1.2}),
            ),
        )
    # Printed after the manager has been closed; presumably closing waits for
    # any in-flight asynchronous save to finish (TODO confirm).
    print("Checkpoint saved!")
if __name__ == "__main__":
    # Script entry point: run the checkpointing repro when executed directly.
    test_checkpointing()
This appears to succeed on the primary process (process index 0), but fails on the rest of the hosts.
Here is the error I see:
Traceback (most recent call last):
File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 50, in <module>
test_checkpointing()
File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 35, in test_checkpointing
ckptr.save(0, state)
File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1110, in save
self._checkpointer.save(save_directory, args=args)
File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py", line 328, in save
commit_ops = asyncio.run(self._handler.async_save(tmpdir, args=ckpt_args))
File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/usr/lib/python3.10/asyncio/base_events.py", line 646, in run_until_complete
return future.result()
File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py", line 358, in async_save
path.mkdir(parents=False, exist_ok=True)
File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/gpath.py", line 205, in mkdir
self._backend.mkdir(self._path_str, exist_ok=exist_ok, mode=mode)
File "/home/simonsenan/.local/lib/python3.10/site-packages/etils/epath/backend.py", line 180, in mkdir
os.mkdir(path, mode=mode)
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/checkpoint/0.orbax-checkpoint-tmp-0/default'
Finally, here's the command I normally use when installing all my dependencies on the VM:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html flax optax pandas numpy scipy wandb tqdm orbax-checkpoint gcsfs
Is this issue related to my sharding and the directory not being created on all of the hosts or something on the Orbax end? Any assistance is greatly appreciated!