Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions examples/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def _eval_batch(
global_step: Array,
params: Params,
func_state: FuncState | None,
opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState,
opt_state: kfac_jax.Optimizer.State | optimizers.OptaxState | None,
rng: PRNGKey | None,
batch: Batch,
) -> dict[str, Array]:
Expand All @@ -725,7 +725,7 @@ def _eval_batch(

stats["loss"] = loss

if hasattr(opt_state, "data_seen"):
if opt_state is not None and hasattr(opt_state, "data_seen"):
stats["data_seen"] = opt_state.data_seen

return stats
Expand Down Expand Up @@ -868,12 +868,16 @@ def run_evaluation(
class JaxlineExperiment(SupervisedExperiment, experiment.AbstractExperiment):
"""A Jaxline supervised experiment."""

CHECKPOINT_ATTRS = {
"_params": "params",
"_params_polyak": "params_polyak",
"_state": "state",
"_opt_state": "opt_state",
}
@property
def CHECKPOINT_ATTRS(self) -> dict[str, str]:
attrs = {
"_params": "params",
"_params_polyak": "params_polyak",
"_state": "state",
}
if self.mode != "eval" or self._schedule_free_enabled:
attrs["_opt_state"] = "opt_state"
return attrs

NON_BROADCAST_CHECKPOINT_ATTRS = {"_python_step": "python_step"}

Expand Down
Loading