Skip to content

Commit 5d251d2

Browse files
Make the _get_state function more Pythonic and concise
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3857036 commit 5d251d2

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

guides/orbax_checkpoint.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,8 @@ def __init__(
8181

8282
def _get_state(self):
8383
"""Gets the model state and metrics"""
84-
model_state = self._model.get_state_tree()
85-
state = {}
86-
metrics = None
87-
for k, v in model_state.items():
88-
if k == "metrics_variables":
89-
metrics = v
90-
else:
91-
state[k] = v
84+
state = self._model.get_state_tree().copy()
85+
metrics = state.pop("metrics_variables", None)
9286
return state, metrics
9387

9488
def save_state(self, epoch):

0 commit comments

Comments
 (0)