File tree Expand file tree Collapse file tree 1 file changed +12
-8
lines changed
Expand file tree Collapse file tree 1 file changed +12
-8
lines changed Original file line number Diff line number Diff line change @@ -704,7 +704,7 @@ def _eval_batch(
704704 global_step : Array ,
705705 params : Params ,
706706 func_state : FuncState | None ,
707- opt_state : kfac_jax .Optimizer .State | optimizers .OptaxState ,
707+ opt_state : kfac_jax .Optimizer .State | optimizers .OptaxState | None ,
708708 rng : PRNGKey | None ,
709709 batch : Batch ,
710710 ) -> dict [str , Array ]:
@@ -725,7 +725,7 @@ def _eval_batch(
725725
726726 stats ["loss" ] = loss
727727
728- if hasattr (opt_state , "data_seen" ):
728+ if opt_state is not None and hasattr (opt_state , "data_seen" ):
729729 stats ["data_seen" ] = opt_state .data_seen
730730
731731 return stats
@@ -868,12 +868,16 @@ def run_evaluation(
868868class JaxlineExperiment (SupervisedExperiment , experiment .AbstractExperiment ):
869869 """A Jaxline supervised experiment."""
870870
871- CHECKPOINT_ATTRS = {
872- "_params" : "params" ,
873- "_params_polyak" : "params_polyak" ,
874- "_state" : "state" ,
875- "_opt_state" : "opt_state" ,
876- }
871+ @property
872+ def CHECKPOINT_ATTRS (self ) -> dict [str , str ]:
873+ attrs = {
874+ "_params" : "params" ,
875+ "_params_polyak" : "params_polyak" ,
876+ "_state" : "state" ,
877+ }
878+ if self .mode != "eval" or self ._schedule_free_enabled :
879+ attrs ["_opt_state" ] = "opt_state"
880+ return attrs
877881
878882 NON_BROADCAST_CHECKPOINT_ATTRS = {"_python_step" : "python_step" }
879883
You can’t perform that action at this time.
0 commit comments