Description
Working from a modified ImageNet Linen example, I've added two state attributes to hold the Polyak-averaged (EMA) values, as follows:
from typing import Any

import flax


@flax.struct.dataclass
class TrainState:
    step: int
    optimizer: flax.optim.Optimizer
    model_state: Any
    dynamic_scale: flax.optim.DynamicScale
    ema_params: flax.core.FrozenDict = None  # lazy init on first step
    ema_model_state: flax.core.FrozenDict = None  # lazy init on first step
Restoring a checkpoint into that state causes an error, because the two FrozenDicts are restored as plain dicts. I'm not sure whether this is a bug or a feature request (i.e. whether this behaviour is expected). I noticed there is a registration function for restoring state dicts, and FrozenDict is among the registered types. Should that not cover this case? Or should I wrap my EMA state in another class and register my own state-dict restore function that freezes the dicts?
For now I'm applying this hack after restoring to work around the issue:
if step_offset > 0:
    state = state.replace(
        ema_params=flax.core.freeze(state.ema_params),
        ema_model_state=flax.core.freeze(state.ema_model_state))