Hi,
I noticed ocp.tree.serialize_tree, but its serialization logic seems to break when a sequence contains empty nodes, i.e. entries that flatten to no leaves. This happens quite often with optax, where you can end up with optax.EmptyState() inside a tuple. Here's a minimal reproduction:
import optax
import orbax.checkpoint as ocp
tree = (0, optax.EmptyState(), 1)  # or None, etc.
ocp.tree.serialize_tree(tree)
This results in:
File ".../orbax/checkpoint/tree/utils.py", line 79, in _extend_list
assert idx <= len(ls)
^^^^^^^^^^^^^^
AssertionError
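The assertion in _extend_list hints at what may be going on. Below is a minimal pure-Python sketch, not the orbax source: it assumes serialize_tree rebuilds lists from flattened (key path, leaf) pairs, and that the empty node in the middle of (0, optax.EmptyState(), 1) contributes no leaves, so the sequence indices jump from 0 to 2. The helper names (extend_list_strict, extend_list_padded, rebuild) are hypothetical and only mirror the assertion shown in the traceback.

```python
from typing import Any, Callable, List, Tuple

# Assumed (key path, leaf) pairs for (0, optax.EmptyState(), 1): the empty
# node contributes no leaves, so index 1 never appears.
FLAT_LEAVES: List[Tuple[Tuple[int, ...], Any]] = [((0,), 0), ((2,), 1)]


def extend_list_strict(ls: List[Any], idx: int, value: Any) -> None:
    """Hypothetical helper mirroring the traceback's check: indices must
    arrive in order with no gaps."""
    assert idx <= len(ls)
    if idx == len(ls):
        ls.append(value)
    else:
        ls[idx] = value


def extend_list_padded(ls: List[Any], idx: int, value: Any, fill: Any = None) -> None:
    """Hypothetical tolerant variant: pads skipped positions with `fill`."""
    while len(ls) < idx:
        ls.append(fill)
    if idx == len(ls):
        ls.append(value)
    else:
        ls[idx] = value


def rebuild(
    leaves: List[Tuple[Tuple[int, ...], Any]],
    extend: Callable[[List[Any], int, Any], None],
) -> List[Any]:
    """Rebuilds a flat list from (path, value) pairs using the first key only."""
    out: List[Any] = []
    for path, value in leaves:
        extend(out, path[0], value)
    return out


try:
    rebuild(FLAT_LEAVES, extend_list_strict)
except AssertionError:
    print("strict rebuild failed: index 2 arrived while the list had length 1")

print(rebuild(FLAT_LEAVES, extend_list_padded))  # [0, None, 1]
```

Padding like this would only restore the positions; it would not recreate optax.EmptyState() by itself, so a real fix would presumably also need the target tree structure when deserializing.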
I'm not sure what the ideal solution is here, since I don't have enough context on the intended purpose of serialize_tree and deserialize_tree.