Skip to content

Commit 3857036

Browse files
committed
Add Orbax checkpointing guide
1 parent 9184480 commit 3857036

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

guides/orbax_checkpoint.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""
2+
Title: Orbax Checkpointing in Keras
3+
Author: [Samaneh Saadat](https://github.com/SamanehSaadat/)
4+
Date created: 2025/08/20
5+
Last modified: 2025/08/20
6+
Description: A guide on how to save Orbax checkpoints during model training with the Jax backend.
7+
Accelerator: GPU
8+
"""
9+
10+
"""
11+
## Introduction
12+
Orbax is the default checkpointing library recommended for JAX ecosystem
13+
users. It is a high-level checkpointing library which provides functionality
14+
for both checkpoint management and composable and extensible serialization.
15+
This guide explains how to do Orbax checkpointing when training a model in
16+
the Jax backend.
17+
18+
The default `.keras` format doesn't support multi-host checkpointing so if
19+
you are using the Keras distribution API for multi-host training, you need to
20+
use Orbax checkpointing.
21+
"""
22+
23+
"""
24+
## Setup
25+
Let's start by installing the Orbax checkpointing library:
26+
"""
27+
28+
"""shell
29+
pip install -q -U orbax-checkpoint
30+
"""
31+
32+
"""
33+
We need to set the Keras backend to Jax as this guide is intended for the
34+
Jax backend. Then we import Keras and other libraries needed including the
35+
Orbax checkpointing library.
36+
"""
37+
38+
import os
39+
40+
os.environ["KERAS_BACKEND"] = "jax"
41+
42+
import keras
43+
import numpy as np
44+
import orbax.checkpoint as ocp
45+
46+
"""
47+
## Orbax Callback
48+
We need to create two main utilities to manage Orbax checkpointing in Keras:
49+
1. `KerasOrbaxCheckpointManager`: A wrapper around
50+
`orbax.checkpoint.CheckpointManager` for Keras models.
51+
`KerasOrbaxCheckpointManager` uses `Model`'s `get_state_tree` and
52+
`set_state_tree` APIs to save and restore the model variables.
53+
2. `OrbaxCheckpointCallback`: A Keras callback that uses
54+
`KerasOrbaxCheckpointManager` to automatically save and restore model states
55+
during training.
56+
57+
Orbax checkpointing in Keras is as simple as copying these utilities to your
58+
own codebase and passing an `OrbaxCheckpointCallback` instance to `fit()`.
59+
"""
60+
61+
62+
class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
    """Orbax `CheckpointManager` subclass that checkpoints a Keras model.

    Model variables are read through `Model.get_state_tree()` and written
    back through `Model.set_state_tree()`. Intended for the Jax backend.
    """

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Creates the manager with synchronous checkpointing.

        Args:
            model: The Keras model whose state tree is saved and restored.
            checkpoint_dir: Directory handed to the Orbax manager.
            max_to_keep: Forwarded to `ocp.CheckpointManagerOptions`.
            steps_per_epoch: Multiplier that converts an epoch index into
                the step number a checkpoint is saved under.
            **kwargs: Extra keyword arguments forwarded to
                `ocp.CheckpointManagerOptions`. Passing
                `enable_async_checkpointing` here raises a `TypeError`
                because that option is always set to `False` below.
        """
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        manager_options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep,
            enable_async_checkpointing=False,
            **kwargs,
        )
        super().__init__(checkpoint_dir, options=manager_options)

    def _get_state(self):
        """Splits the model state tree into `(state, metrics)`.

        Returns:
            A tuple `(state, metrics)`. `state` holds every top-level entry
            of the model's state tree except `"metrics_variables"`;
            `metrics` is that entry, or `None` when it is absent.
        """
        tree = self._model.get_state_tree()
        metrics = tree.get("metrics_variables")
        state = {
            name: subtree
            for name, subtree in tree.items()
            if name != "metrics_variables"
        }
        return state, metrics

    def save_state(self, epoch):
        """Saves the current model state as a checkpoint.

        Args:
            epoch: The epoch number at which the state is saved. The
                checkpoint step is `epoch * steps_per_epoch`.
        """
        model_state, metrics_state = self._get_state()
        step = epoch * self._steps_per_epoch
        self.save(
            step,
            args=ocp.args.StandardSave(item=model_state),
            metrics=metrics_state,
        )

    def restore_state(self, step=None):
        """Loads a checkpoint into the model.

        Args:
            step: The step number to restore from. When `None`, the value
                of `latest_step()` is used.
        """
        target_step = self.latest_step() if step is None else step
        # Only the non-metric part of the state is restored; the current
        # model state serves as the template passed to Orbax.
        template, _ = self._get_state()
        restore_args = ocp.args.StandardRestore(item=template)
        self._model.set_state_tree(self.restore(target_step, args=restore_args))
120+
121+
122+
class OrbaxCheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax.

    At the start of training it restores the latest checkpoint, if one
    exists; at the end of every epoch it saves a new checkpoint.
    """

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        """Creates the callback and its checkpoint manager.

        Args:
            model: The Keras model to checkpoint.
            checkpoint_dir: Directory passed to
                `KerasOrbaxCheckpointManager`.
            max_to_keep: Passed to `KerasOrbaxCheckpointManager`.
            steps_per_epoch: Passed to `KerasOrbaxCheckpointManager`.
            **kwargs: Extra keyword arguments passed to
                `KerasOrbaxCheckpointManager`.

        Raises:
            ValueError: If the active Keras backend is not `jax`.
        """
        # Initialize the base `Callback` so its own attributes exist even
        # before `fit()` attaches the model to this callback.
        super().__init__()
        if keras.config.backend() != "jax":
            raise ValueError(
                "`OrbaxCheckpointCallback` is only supported on a "
                "`jax` backend. Provided backend is %s." % keras.config.backend()
            )
        self._checkpoint_manager = KerasOrbaxCheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs=None):
        """Restores the latest checkpoint, if any, before training starts.

        Raises:
            ValueError: If the model or its optimizer is not built yet.
        """
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `OrbaxCheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        # `latest_step()` returning `None` is treated as "nothing to restore".
        if latest_epoch is not None:
            print("Load Orbax checkpoint on_train_begin.")
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(self, epoch, logs=None):
        """Saves a checkpoint for the epoch that just finished."""
        print("Save Orbax checkpoint on_epoch_end.")
        self._checkpoint_manager.save_state(epoch)
156+
157+
158+
"""
159+
## An Orbax checkpointing example
160+
Let's look at how we can use `OrbaxCheckpointCallback` to save Orbax
161+
checkpoints during the training. To get started, let's define a simple model
162+
and a toy training dataset.
163+
"""
164+
165+
166+
def get_model():
    """Builds and compiles a one-layer model for the example.

    Returns:
        A compiled `keras.Model` mapping inputs of shape `(32,)` through a
        single `Dense(1)` layer named `"dense"`, using the Adam optimizer
        and mean squared error loss.
    """
    features = keras.Input(shape=(32,))
    dense = keras.layers.Dense(1, name="dense")
    regressor = keras.Model(features, dense(features))
    regressor.compile(
        optimizer=keras.optimizers.Adam(),
        loss="mean_squared_error",
    )
    return regressor
173+
174+
175+
# Build the compiled example model.
model = get_model()

# Toy dataset: 128 random samples with 32 features each, and one random
# target value per sample.
x_train = np.random.random(size=(128, 32))
y_train = np.random.random(size=(128, 1))
179+
180+
"""
181+
Then, we create an Orbax checkpointing callback and pass it to the
182+
`callbacks` argument in the `fit` function.
183+
"""
184+
185+
# Checkpoints are written under /tmp/ckpt with `max_to_keep=1`.
orbax_callback = OrbaxCheckpointCallback(
    model=model,
    checkpoint_dir="/tmp/ckpt",
    steps_per_epoch=1,
    max_to_keep=1,
)

# Train for three epochs; the callback saves a checkpoint after each one.
history = model.fit(
    x=x_train,
    y=y_train,
    epochs=3,
    batch_size=32,
    validation_split=0.2,
    callbacks=[orbax_callback],
    verbose=0,
)
200+
201+
"""
202+
Now if you look at the Orbax checkpoint directory, you can see all the files
203+
saved as part of Orbax checkpointing.
204+
"""

scripts/guides_master.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@
119119
"path": "keras_nnx_guide",
120120
"title": "How to use Keras with NNX backend",
121121
},
122+
{
123+
"path": "orbax_checkpoint",
124+
"title": "Orbax Checkpointing in Keras",
125+
},
122126
# {
123127
# "path": "preprocessing_layers",
124128
# "title": "Working with preprocessing layers",

0 commit comments

Comments
 (0)