
Incompatibility with Haiku #528

Open
@chriscarmona


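Reopening an issue about incompatibility with Haiku's module naming conventions (similar to a previous issue). Haiku gives nested submodules names containing "~" (for example "mlp/~/linear_0"), and saving parameters with such names fails. This was not a problem in v0.3.5.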
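The ValueError in the traceback comes from tensorstore parsing a JSON Pointer (RFC 6901). In that syntax "~" is an escape character: a literal "~" must be written "~0" and a literal "/" inside a key must be written "~1". The pointer "/mlp/~/linear_0.b" has a bare "~", so it is rejected. The sketch below only illustrates the escaping rule. It is not Orbax's code, and the helper names escape_token and is_valid_pointer are hypothetical.

```python
def escape_token(token: str) -> str:
    """Escape one JSON Pointer reference token per RFC 6901.

    '~' must be escaped first ('~' -> '~0'), then '/' -> '~1'.
    """
    return token.replace("~", "~0").replace("/", "~1")


def is_valid_pointer(pointer: str) -> bool:
    """Return True if the pointer starts with '/' (or is empty) and every '~'
    is followed by '0' or '1'."""
    if pointer and not pointer.startswith("/"):
        return False
    for i, ch in enumerate(pointer):
        # Slicing avoids an IndexError when '~' is the last character.
        if ch == "~" and pointer[i + 1 : i + 2] not in ("0", "1"):
            return False
    return True


# Pointer as it appears in the traceback (Haiku name left unescaped).
raw = "/mlp/~/linear_0.b"
# The same parameter path treated as a single, properly escaped token.
escaped = "/" + escape_token("mlp/~/linear_0.b")

print(is_valid_pointer(raw))      # False: '~' is followed by '/'
print(escaped)                    # /mlp~1~0~1linear_0.b
print(is_valid_pointer(escaped))  # True
```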

Sample code:

from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk


@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2) # This works
  net = hk.nets.MLP(
      output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)


prng_seq = hk.PRNGSequence(0)
# Haiku names the MLP's layers "mlp/~/linear_0" and "mlp/~/linear_1".
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = '/tmp/my-checkpoints/'
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    {'state': ocp.PyTreeCheckpointer()},
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})

The error:

Traceback (most recent call last):
  File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
    orbax_mngr.save(step=0, items={'state': params})
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
    self._checkpointers[k].save(item_dir, item, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
    self._handler.save(tmpdir, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
    open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited
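A possible stopgap, until the naming issue is handled on the Orbax side, is to rename the top-level module keys so they contain no "~" before saving, and then undo the renaming after restoring. This is a sketch under assumptions. The placeholder string "__tilde__" and the helpers sanitize_keys and restore_keys are hypothetical. The placeholder is assumed never to appear in a real Haiku module name. Plain lists stand in for the JAX arrays in params.

```python
from typing import Any, Dict

# Hypothetical placeholder; assumed not to occur in any real module name.
TILDE_PLACEHOLDER = "__tilde__"


def sanitize_keys(params: Dict[str, Any]) -> Dict[str, Any]:
    """Replace '~' in top-level Haiku module names before checkpointing."""
    return {name.replace("~", TILDE_PLACEHOLDER): value
            for name, value in params.items()}


def restore_keys(params: Dict[str, Any]) -> Dict[str, Any]:
    """Undo sanitize_keys on a restored parameter tree."""
    return {name.replace(TILDE_PLACEHOLDER, "~"): value
            for name, value in params.items()}


# Stand-in for Haiku params: {module_name: {param_name: array}}.
params = {
    "mlp/~/linear_0": {"w": [[0.0, 0.0]], "b": [0.0, 0.0]},
    "mlp/~/linear_1": {"w": [[0.0, 0.0]], "b": [0.0, 0.0]},
}

safe = sanitize_keys(params)
print(sorted(safe))  # ['mlp/__tilde__/linear_0', 'mlp/__tilde__/linear_1']
assert restore_keys(safe) == params
```

In the reproduction above, this would amount to calling orbax_mngr.save(step=0, items={'state': sanitize_keys(params)}) and then applying restore_keys to whatever is restored. This has not been verified against the affected Orbax version.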
