We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents f12e8b1 + 5d251d2 commit 895a43fCopy full SHA for 895a43f
guides/orbax_checkpoint.py
@@ -81,14 +81,8 @@ def __init__(
81
82
def _get_state(self):
83
"""Gets the model state and metrics"""
84
- model_state = self._model.get_state_tree()
85
- state = {}
86
- metrics = None
87
- for k, v in model_state.items():
88
- if k == "metrics_variables":
89
- metrics = v
90
- else:
91
- state[k] = v
+ state = self._model.get_state_tree().copy()
+ metrics = state.pop("metrics_variables", None)
92
return state, metrics
93
94
def save_state(self, epoch):
0 commit comments