Skip to content

Commit e862207

Browse files
committed
Fix some formatting issues
1 parent 6e5a849 commit e862207

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

guides/orbax_checkpoint.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
**kwargs,
7676
):
7777
"""Initialize the Keras Orbax Checkpoint Manager.
78-
78+
7979
Args:
8080
model: The Keras model to checkpoint.
8181
checkpoint_dir: Directory path where checkpoints will be saved.
@@ -95,10 +95,10 @@ def __init__(
9595

9696
def _get_state(self):
9797
"""Gets the model state and metrics.
98-
98+
9999
This method retrieves the complete state tree from the model and separates
100100
the metrics variables from the rest of the state.
101-
101+
102102
Returns:
103103
A tuple containing:
104104
- state: A dictionary containing the model's state (weights, optimizer state, etc.)
@@ -149,7 +149,7 @@ def __init__(
149149
**kwargs,
150150
):
151151
"""Initialize the Orbax checkpoint callback.
152-
152+
153153
Args:
154154
model: The Keras model to checkpoint.
155155
checkpoint_dir: Directory path where checkpoints will be saved.
@@ -158,9 +158,6 @@ def __init__(
158158
steps_per_epoch: Number of steps per epoch. Default is 1.
159159
**kwargs: Additional keyword arguments to pass to Orbax's
160160
CheckpointManagerOptions.
161-
162-
Raises:
163-
ValueError: If the backend is not JAX.
164161
"""
165162
if keras.config.backend() != "jax":
166163
raise ValueError(

0 commit comments

Comments
 (0)