Closed
Description
After making data-clamping possible for multiple state variables (#773), I found that I missed something. When using checkpointing, line 443 in integrate.py uses an example_key to create the dummy_external that is used to extend the externals' state arrays. If different states are clamped and the number of clamped cells differs between states, the code will attempt to concatenate a dummy_external of the wrong shape, that is, one sized for the number of cells in the first clamp (x time) rather than for the state being extended.
The offending lines:
    dummy_external = jnp.zeros(
        (size_difference, externals[example_key].shape[1])
    )
    for key in externals.keys():
        externals[key] = jnp.concatenate([externals[key], dummy_external])
This should be easily fixable by creating the dummy_external inside the loop, so that each one takes its shape from its own key:
    for key in externals.keys():
        dummy_external = jnp.zeros(
            (size_difference, externals[key].shape[1])
        )
        externals[key] = jnp.concatenate([externals[key], dummy_external])
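For reference, here is a minimal standalone sketch of the per-key padding, outside of integrate.py. The helper name `pad_externals`, the keys `"v"` and `"i"`, and the array sizes are made up for illustration. It also assumes that each externals entry is laid out as (time, number of clamped cells), which is what the snippet above implies.

```python
import jax.numpy as jnp


def pad_externals(externals, size_difference):
    """Extend every externals array along axis 0 with zeros.

    Hypothetical helper for illustration only. The zero block is built per
    key, so its second dimension always matches the array being extended.
    """
    padded = {}
    for key, value in externals.items():
        # Shape follows this key's own cell count, not that of an example key.
        dummy_external = jnp.zeros((size_difference, value.shape[1]))
        padded[key] = jnp.concatenate([value, dummy_external])
    return padded


# Two clamped states with different numbers of clamped cells (assumed sizes).
externals = {
    "v": jnp.ones((5, 3)),  # 5 time steps, 3 clamped cells
    "i": jnp.ones((5, 2)),  # 5 time steps, 2 clamped cells
}

padded = pad_externals(externals, size_difference=4)
shapes = {key: value.shape for key, value in padded.items()}
```

With a single shared dummy_external built from the `"v"` entry, the concatenation for `"i"` would fail because the second dimensions (3 and 2) do not match. Building the zero block per key avoids that.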
Let me know if there are any objections or alternate/preferred solutions; otherwise I can submit a PR soon.