Hi, we are trying out the Orbax (0.4.1) AsyncCheckpointer (used through CheckpointManager) and are getting "Array has been deleted" errors. It looks as if the async checkpointer is trying to copy a jax.Array from device to host memory after that array has already been deleted. The Orbax documentation says that "From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host", but I wonder whether there are any gotchas in how we should use Orbax checkpointing.
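The documented model (a blocking device-to-host copy, then a background write) only protects arrays that are actually copied before the save call returns. If some leaf is instead converted lazily inside the writer thread, and the next training step donates or otherwise frees that buffer first, the read fails with exactly this error. The sketch below reproduces that pattern without Orbax, under stated assumptions: `save_async` is a hypothetical stand-in for an async writer (not an Orbax API), and `jax.Array.delete()` stands in for buffer donation by a later `jit` step. Taking an owning host copy with `np.array` before handing the tree off avoids the error.

```python
import threading

import jax
import jax.numpy as jnp
import numpy as np


def save_async(tree, snapshot_first):
    """Hypothetical stand-in for an async checkpoint save (not an Orbax API)."""
    # Optional blocking device-to-host snapshot: np.array returns an owning copy.
    payload = jax.tree_util.tree_map(np.array, tree) if snapshot_first else tree
    results = {}
    start = threading.Event()

    def writer():
        start.wait()  # the writer only runs after the next "train step"
        try:
            # Same conversion Orbax's msgpack_utils._ndarray_to_bytes performs.
            results["value"] = np.array(payload["w"])
        except RuntimeError as e:  # "Array has been deleted with shape=..."
            results["error"] = str(e)

    thread = threading.Thread(target=writer)
    thread.start()
    return thread, start, results


outcomes = {}
for snapshot_first in (False, True):
    state = {"w": jnp.ones(256, dtype=jnp.float32)}
    thread, start, results = save_async(state, snapshot_first)
    # Stand-in for the next train step donating (freeing) the old buffer.
    state["w"].delete()
    start.set()
    thread.join()
    outcomes[snapshot_first] = "error" in results

print(outcomes)  # {False: True, True: False}
```

Calling `wait_until_finished()` on the manager before the next donating step is another way to order the two operations; the trade-off is that it blocks training for the whole write, whereas a host snapshot only blocks for the copy.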
Here is the stack trace:
Exception in thread Thread-314 (_finalize):
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.11/threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 956, in _finalize
    self.wait_until_finished(join_finalize_thread=False)
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py", line 888, in wait_until_finished
    checkpointer.wait_until_finished() # pytype: disable=attribute-error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 262, in wait_until_finished
    self._async_manager.wait_until_finished()
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 154, in wait_until_finished
    self.check_for_errors()
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 145, in check_for_errors
    raise exception # pylint: disable=raising-bad-type
    ^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/async_checkpointer.py", line 97, in _thread_func
    future.result()
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/aggregate_handlers.py", line 75, in _serialize_fn
    msgpack = msgpack_utils.msgpack_serialize(serializable_dict)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 216, in msgpack_serialize
    return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/msgpack/__init__.py", line 36, in packb
    return Packer(**kwargs).pack(o)
  File "msgpack/_packer.pyx", line 285, in msgpack._cmsgpack.Packer._pack
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 78, in _msgpack_ext_pack
    return msgpack.ExtType(_MsgpackExtType.NDARRAY, _ndarray_to_bytes(x))
                                                    ^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/orbax/checkpoint/msgpack_utils.py", line 40, in _ndarray_to_bytes
    arr = np.array(arr)
          ^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 377, in __array__
    return np.asarray(self._value, dtype=dtype)
                      ^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 340, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 562, in _value
    self._check_if_deleted()
  File ".venv/lib/python3.11/site-packages/jax/_src/array.py", line 530, in _check_if_deleted
    raise RuntimeError(
RuntimeError: Array has been deleted with shape=float32[256].
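The error only reports the shape (`float32[256]`), not which leaf of the train state was freed. A small helper can list the deleted `jax.Array` leaves by key path; calling it on the tree that was passed to the save, right after the following train step, should point at the offending leaf. This is a hypothetical debugging aid, not part of Orbax, and the `state` below is made-up example data.

```python
import jax
import jax.numpy as jnp


def find_deleted_leaves(tree):
    """Return key paths of jax.Array leaves whose device buffers were deleted."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if isinstance(leaf, jax.Array) and leaf.is_deleted()
    ]


# Hypothetical train state with one leaf matching the shape in the error.
state = {"params": {"bias": jnp.zeros(256, jnp.float32)}, "step": jnp.array(0)}
state["params"]["bias"].delete()  # simulate a buffer freed by a later step
print(find_deleted_leaves(state))  # lists the key path of the deleted bias leaf
```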