
ocp.tree.serialize_tree filtering logic for sequences with empty leaves #1356

@JesseFarebro

Hi,

I came across `ocp.tree.serialize_tree`, but its serialization logic appears to break when a sequence contains empty leaves, i.e. subtrees with no leaves of their own, such as `None` or `optax.EmptyState()`. This comes up frequently with optax, where optimizer states often contain `optax.EmptyState()` inside a tuple. Here's a minimal reproduction:

```python
import optax
import orbax.checkpoint as ocp

tree = (0, optax.EmptyState(), 1)  # or None, etc.
ocp.tree.serialize_tree(tree)
```

resulting in:

```
  File ".../orbax/checkpoint/tree/utils.py", line 79, in _extend_list
    assert idx <= len(ls)
           ^^^^^^^^^^^^^^
AssertionError
```
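The traceback is consistent with empty subtrees leaving gaps in sequence indices: `None` and `optax.EmptyState()` (a `NamedTuple` with no fields) contribute no leaves when a pytree is flattened, so the surviving leaves of `(0, None, 1)` sit at indices 0 and 2. The pure-Python sketch below is a simplified, hypothetical model of that behavior, not Orbax's implementation; `flatten_with_paths`, `extend_list`, and `rebuild_flat` are illustrative names, and only one level of nesting is rebuilt.

```python
from typing import Any


def flatten_with_paths(tree: Any, path: tuple = ()) -> list:
    """Hypothetical flattener returning (path, leaf) pairs.

    Mirrors JAX pytree semantics for this case: None and empty
    tuples/lists (including a field-less NamedTuple such as
    optax.EmptyState) contribute no leaves, so their index is skipped.
    """
    if tree is None:
        return []
    if isinstance(tree, (tuple, list)):
        pairs = []
        for i, child in enumerate(tree):
            pairs.extend(flatten_with_paths(child, path + (i,)))
        return pairs
    return [(path, tree)]


def extend_list(ls: list, idx: int, value: Any) -> None:
    """Hypothetical helper with the same check as in the traceback."""
    assert idx <= len(ls)  # fails if an earlier index never arrived
    if idx == len(ls):
        ls.append(value)
    else:
        ls[idx] = value


def rebuild_flat(pairs: list) -> list:
    """Rebuild a single-level sequence from (path, leaf) pairs."""
    out: list = []
    for path, leaf in pairs:
        extend_list(out, path[0], leaf)
    return out


pairs = flatten_with_paths((0, None, 1))
print(pairs)  # [((0,), 0), ((2,), 1)]: index 1 is missing
try:
    rebuild_flat(pairs)
except AssertionError:
    print("AssertionError: index 2 arrived while the list had length 1")
```

Under this model, `rebuild_flat` appends the leaf at index 0, then receives index 2 while the list still has length 1, which trips the same `idx <= len(ls)` check.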

I'm not sure what the ideal fix is here; I don't have enough context on the intended purpose of `serialize_tree` and `deserialize_tree`.
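One possible direction, sketched here only as an assumption about how a fix might look: a deserializer could pad skipped indices with a placeholder instead of asserting. `rebuild_flat_padded` is a hypothetical name, not part of Orbax, and it rebuilds a single level only.

```python
from typing import Any


def rebuild_flat_padded(pairs: list, fill: Any = None) -> list:
    """Hypothetical gap-tolerant rebuild of a single-level sequence.

    Instead of asserting idx <= len(out), pad any skipped positions
    (left behind by empty subtrees) with `fill`.
    """
    out: list = []
    for path, leaf in pairs:
        idx = path[0]
        while len(out) < idx:
            out.append(fill)  # placeholder for an empty subtree
        if idx == len(out):
            out.append(leaf)
        else:
            out[idx] = leaf
    return out


# (path, leaf) pairs for (0, optax.EmptyState(), 1): index 1 has no leaves.
pairs = [((0,), 0), ((2,), 1)]
print(rebuild_flat_padded(pairs))  # [0, None, 1]
```

The trade-off is that padding restores positions but not the original `optax.EmptyState()` object, and it cannot detect an empty subtree in the last position, so a faithful round trip would likely still need the original tree structure.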
