Open
Description
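Reopening an issue about Orbax being incompatible with Haiku's parameter naming convention (similar to a previous issue). Saving parameters whose module path contains `~` (e.g. `mlp/~/linear_0`, as produced by `hk.nets.MLP`) fails. This was not a problem in v0.3.5.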
Sample code:
from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk
@hk.transform
def forward_fn(inputs):
    """Apply a small Haiku MLP to ``inputs`` (minimal reproduction).

    ``hk.nets.MLP`` nests its layers under paths such as ``mlp/~/linear_0``.
    The ``~`` in that parameter path is what makes the checkpoint save fail:
    the error reports it as an invalid JSON Pointer escape
    (``"/mlp/~/linear_0.b"``). A single ``hk.Linear`` layer (commented out
    below) does not trigger the failure.
    """
    # net = hk.Linear(output_size=2)  # This works
    net = hk.nets.MLP(
        output_sizes=[2, 2], activate_final=True)  # This doesn't work
    return net(inputs)
# Initialise the transformed model to obtain its parameter pytree.
prng_seq = hk.PRNGSequence(0)
init_key = next(prng_seq)
params = forward_fn.init(init_key, jnp.ones((1, 5)))

# Checkpoint manager with a single 'state' item, keeping at most one step.
ckpt_dir = '/tmp/my-checkpoints/'
checkpointers = {'state': ocp.PyTreeCheckpointer()}
mngr_options = ocp.CheckpointManagerOptions(max_to_keep=1)
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    checkpointers,
    options=mngr_options,
)

# With the MLP-based forward_fn this raises a ValueError because the
# parameter path contains '~' (rejected as a JSON Pointer escape).
orbax_mngr.save(step=0, items={'state': params})
The error:
Traceback (most recent call last):
File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
orbax_mngr.save(step=0, items={'state': params})
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
self._checkpointers[k].save(item_dir, item, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
self._handler.save(tmpdir, *args, **kwargs)
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
asyncio.run(async_save(directory, item, *args, **kwargs))
File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
commit_futures = await self.async_save(*args, **kwargs) # pytype: disable=bad-return-type
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
commit_futures = await asyncio.gather(*serialize_ops)
File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited