
Checkpointing issue with multiple data clamps #785

@kyralianaka

After making data clamping possible for multiple state variables (#773), I found that I had missed something. When checkpointing is used, line 443 in integrate.py uses an example_key to create the dummy_external that pads the externals' state arrays. Suppose one clamps several different states, and the number of clamped cells differs between those states. The code then tries to concatenate a dummy_external of the wrong shape. Its second dimension is the number of cells clamped by the first clamp (example_key), not by the clamp currently being padded. So jnp.concatenate fails for every clamp whose cell count differs from the first one.

The offending lines:

            dummy_external = jnp.zeros(
                (size_difference, externals[example_key].shape[1])
            )
            for key in externals.keys():
                externals[key] = jnp.concatenate([externals[key], dummy_external])
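
To make the failure concrete, here is a minimal NumPy sketch of the padding step above. np.zeros and np.concatenate stand in for their jnp counterparts; jnp.concatenate rejects the same shapes, although it may raise a different exception type. The sketch makes three assumptions:

- The concatenation along axis 0 implies that each external is stored as (rows, clamped cells).
- The helper pad_externals_buggy and the keys "v" and "HH_m" are hypothetical.
- example_key is taken to be the first key.

```python
import numpy as np


def pad_externals_buggy(externals, size_difference):
    """Hypothetical helper mirroring the original logic: ONE padding block
    whose width comes from example_key is reused for every clamp."""
    example_key = next(iter(externals))  # assumption: first key is the example
    dummy_external = np.zeros((size_difference, externals[example_key].shape[1]))
    padded = {}
    for key in externals:
        # Fails as soon as a clamp has a different number of columns (cells).
        padded[key] = np.concatenate([externals[key], dummy_external])
    return padded


# Hypothetical clamps: "v" clamped in 3 cells, "HH_m" clamped in 5 cells.
externals = {
    "v": np.ones((10, 3)),
    "HH_m": np.ones((10, 5)),
}

try:
    pad_externals_buggy(externals, size_difference=2)
except ValueError as err:
    print("shape mismatch:", err)
```

"v" is padded without error because it supplied the width. "HH_m" then fails because its 5 columns do not match the 3-column padding block.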

This should be easy to fix by creating dummy_external inside the loop, so that each clamp gets padding of its own width:

            for key in externals.keys():
                # Build the padding per key so its second dimension matches this clamp.
                dummy_external = jnp.zeros(
                    (size_difference, externals[key].shape[1])
                )
                externals[key] = jnp.concatenate([externals[key], dummy_external])
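
A minimal NumPy sketch of the proposed fix follows, under the same assumptions as the original code:

- The helper pad_externals and the keys "v" and "HH_m" are hypothetical, used only for illustration.
- The padding block is rebuilt per key, so each clamp keeps its own number of columns.
- A new dict is returned instead of mutating externals in place, which keeps the example side-effect free.

```python
import numpy as np


def pad_externals(externals, size_difference):
    """Hypothetical helper mirroring the proposed fix: the zero block is
    built inside the loop with shape (size_difference, values.shape[1])."""
    padded = {}
    for key, values in externals.items():
        dummy_external = np.zeros((size_difference, values.shape[1]))
        padded[key] = np.concatenate([values, dummy_external])
    return padded


# Hypothetical clamps with different numbers of clamped cells.
externals = {
    "v": np.ones((10, 3)),
    "HH_m": np.ones((10, 5)),
}
padded = pad_externals(externals, size_difference=2)
for key, values in padded.items():
    print(key, values.shape)
# v (12, 3)
# HH_m (12, 5)
```

As an alternative, jnp.pad(values, ((0, size_difference), (0, 0))) produces the same zero padding without building dummy_external by hand.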

Let me know if there are any objections or alternate/preferred solutions; otherwise I can submit a PR soon.
