Description
I am getting a number of errors from the folder-existence checks in _src/metadata/checkpoint.py when trying to create a gs:// checkpoint on TPU (v6e, v2-alpha-tpuv6e), e.g.:

  _src/metadata/checkpoint.py", line 45, in _sanitize_metadata_path
    raise FileNotFoundError(f'Path does not exist: {path}')

For some reason the error does not happen elsewhere (e.g. locally on a Mac).
Sample code:
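import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

path = "somebucket/somepath"
checkpointer = ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
checkpointer.save(f"gs://{path}", (jnp.ones(10),), force=True)
# save() on an AsyncCheckpointer returns before the write is committed;
# block here so restore() does not race the in-flight save.
checkpointer.wait_until_finished()
print(checkpointer.restore(f"gs://{path}", (ocp.RestoreArgs(restore_type=np.ndarray),)))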
Disabling those existence checks (patch attached) seems to resolve the issue, but that is obviously more of a band-aid than a fix.
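For illustration only, here is something narrower than removing the checks outright. This is a hypothetical stand-in for `_sanitize_metadata_path` (I have not checked how the real function tests for existence): it keeps the local existence check for local paths but lets remote URIs such as gs:// through to the storage layer. `sanitize_path`, `is_remote` and `REMOTE_SCHEMES` are made-up names, not Orbax API.

```python
import os
from urllib.parse import urlparse

# Hypothetical set of URI schemes that a local os.path check cannot resolve.
REMOTE_SCHEMES = {"gs", "s3"}


def is_remote(path: str) -> bool:
    """True if the path carries a remote URI scheme such as gs://."""
    return urlparse(path).scheme in REMOTE_SCHEMES


def sanitize_path(path: str) -> str:
    """Hypothetical stand-in for _sanitize_metadata_path.

    Only enforces the existence check for local paths; remote paths are
    returned unchanged so the storage backend can report its own errors.
    """
    if not is_remote(path) and not os.path.exists(path):
        raise FileNotFoundError(f'Path does not exist: {path}')
    return path
```

A real fix would presumably check remote paths through whatever filesystem abstraction Orbax already uses for reads and writes, rather than skipping them.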