Skip to content

Commit 22ae079

Browse files
james-martensKfacJaxDev
authored andcommitted
Excluding opt_state from eval worker when possible.
PiperOrigin-RevId: 858564989
1 parent 25dd648 commit 22ae079

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

examples/training.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def _eval_batch(
704704
global_step: Array,
705705
params: Params,
706706
func_state: FuncState | None,
707-
opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState,
707+
opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState | None,
708708
rng: PRNGKey | None,
709709
batch: Batch,
710710
) -> dict[str, Array]:
@@ -725,7 +725,7 @@ def _eval_batch(
725725

726726
stats["loss"] = loss
727727

728-
if hasattr(opt_state, "data_seen"):
728+
if opt_state is not None and hasattr(opt_state, "data_seen"):
729729
stats["data_seen"] = opt_state.data_seen
730730

731731
return stats
@@ -868,12 +868,16 @@ def run_evaluation(
868868
class JaxlineExperiment(SupervisedExperiment, experiment.AbstractExperiment):
869869
"""A Jaxline supervised experiment."""
870870

871-
CHECKPOINT_ATTRS = {
872-
"_params": "params",
873-
"_params_polyak": "params_polyak",
874-
"_state": "state",
875-
"_opt_state": "opt_state",
876-
}
871+
@property
872+
def CHECKPOINT_ATTRS(self) -> dict[str, str]:
873+
attrs = {
874+
"_params": "params",
875+
"_params_polyak": "params_polyak",
876+
"_state": "state",
877+
}
878+
if self.mode != "eval" or self._schedule_free_enabled:
879+
attrs["_opt_state"] = "opt_state"
880+
return attrs
877881

878882
NON_BROADCAST_CHECKPOINT_ATTRS = {"_python_step": "python_step"}
879883

0 commit comments

Comments
 (0)