Issues restoring checkpoint of struct.dataclass w/ FrozenDict attr #676

Open
@rwightman

Working from a modified ImageNet Linen example, I've added two state attributes that hold the Polyak-averaged (EMA) values, like so:

from typing import Any

import flax

@flax.struct.dataclass
class TrainState:
    step: int
    optimizer: flax.optim.Optimizer
    model_state: Any
    dynamic_scale: flax.optim.DynamicScale
    ema_params: flax.core.FrozenDict = None  # lazy init on first step
    ema_model_state: flax.core.FrozenDict = None  # lazy init on first step

Restoring a checkpoint into this state raises an error because the FrozenDicts are restored as plain dicts. I'm not sure whether this is a bug or a feature request (i.e. is this expected?). I noticed there is a registration function for restoring state dicts, and FrozenDict is among the registered types, so shouldn't that cover this case? Or should I wrap my EMA state in another class and register my own state-dict restore function that freezes the dicts?

I'm currently doing this hack after restore to work around the issue...

    if step_offset > 0:
        state = state.replace(
            ema_params=flax.core.freeze(state.ema_params),
            ema_model_state=flax.core.freeze(state.ema_model_state))

Labels

Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)
Status: pull requests welcome (We agree with the direction proposed, feel free to give it a shot and file a pull request.)
