"""
Title: Orbax Checkpointing in Keras
Author: [Samaneh Saadat](https://github.com/SamanehSaadat/)
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints during model training with the JAX backend.
Accelerator: GPU
"""

"""
## Introduction
Orbax is the default checkpointing library recommended for users of the
JAX ecosystem. It is a high-level checkpointing library that provides
functionality for both checkpoint management and composable, extensible
serialization. This guide explains how to save Orbax checkpoints when
training a model with the JAX backend.

The default `.keras` format doesn't support multi-host checkpointing, so
if you are using the Keras distribution API for multi-host training, you
need to use Orbax checkpointing.
"""

"""
## Setup
Let's start by installing the Orbax checkpointing library:
"""

"""shell
pip install -q -U orbax-checkpoint
"""

"""
We need to set the Keras backend to JAX, as this guide is intended for the
JAX backend. Then we import Keras and the other libraries we need,
including the Orbax checkpointing library.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np
import orbax.checkpoint as ocp

"""
## Orbax Callback
We need to create two main utilities to manage Orbax checkpointing in Keras:
1. `KerasOrbaxCheckpointManager`: A wrapper around
`orbax.checkpoint.CheckpointManager` for Keras models.
`KerasOrbaxCheckpointManager` uses the `Model` APIs `get_state_tree` and
`set_state_tree` to save and restore the model variables.
2. `OrbaxCheckpointCallback`: A Keras callback that uses
`KerasOrbaxCheckpointManager` to automatically save and restore the model
state during training.

Orbax checkpointing in Keras is then as simple as copying these utilities
into your own codebase and passing an `OrbaxCheckpointCallback` to `fit`.
"""
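"""
Before we define them, here is a quick illustration of what
`get_state_tree` returns, using a toy model. This is only a sketch to show
the layout; the exact key set and its ordering are an assumption based on
current Keras versions and may vary.
"""

demo_model = keras.Sequential([keras.layers.Dense(1)])
demo_model.compile(optimizer="adam", loss="mse")
demo_model.build((None, 4))
# The state tree is a nested dict of variable values grouped by category;
# `metrics_variables` is the entry that the manager below filters out of
# the saved state.
print(list(demo_model.get_state_tree().keys()))
# Expected (assumed) output:
# ['trainable_variables', 'non_trainable_variables', 'metrics_variables',
#  'optimizer_variables']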


class KerasOrbaxCheckpointManager(ocp.CheckpointManager):
    """A wrapper over the Orbax CheckpointManager for Keras with the JAX
    backend."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
        )
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        super().__init__(checkpoint_dir, options=options)

    def _get_state(self):
        """Gets the model state and metrics."""
        model_state = self._model.get_state_tree()
        state = {}
        metrics = None
        for k, v in model_state.items():
            if k == "metrics_variables":
                metrics = v
            else:
                state[k] = v
        return state, metrics

    def save_state(self, epoch):
        """Saves the model to the checkpoint directory.

        Args:
            epoch: The epoch number at which the state is saved.
        """
        state, metrics_value = self._get_state()
        self.save(
            epoch * self._steps_per_epoch,
            args=ocp.args.StandardSave(item=state),
            metrics=metrics_value,
        )

    def restore_state(self, step=None):
        """Restores the model from the checkpoint directory.

        Args:
            step: The step number to restore the state from. If None (the
                default), the latest step is restored.
        """
        if step is None:
            step = self.latest_step()
        # Restore the model state only, not the metrics.
        state, _ = self._get_state()
        restored_state = self.restore(step, args=ocp.args.StandardRestore(item=state))
        self._model.set_state_tree(restored_state)


class OrbaxCheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax."""

    def __init__(
        self,
        model,
        checkpoint_dir,
        max_to_keep=5,
        steps_per_epoch=1,
        **kwargs,
    ):
        if keras.config.backend() != "jax":
            raise ValueError(
                "`OrbaxCheckpointCallback` is only supported on a "
                "`jax` backend. Provided backend is %s." % keras.config.backend()
            )
        self._checkpoint_manager = KerasOrbaxCheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs=None):
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `OrbaxCheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        if latest_epoch is not None:
            print("Load Orbax checkpoint on_train_begin.")
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(self, epoch, logs=None):
        print("Save Orbax checkpoint on_epoch_end.")
        self._checkpoint_manager.save_state(epoch)


"""
## An Orbax checkpointing example
Let's look at how we can use `OrbaxCheckpointCallback` to save Orbax
checkpoints during training. To get started, let's define a simple model
and a toy training dataset.
"""


def get_model():
    # Create a simple model.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1, name="dense")(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer=keras.optimizers.Adam(), loss="mean_squared_error")
    return model


model = get_model()

x_train = np.random.random((128, 32))
y_train = np.random.random((128, 1))

"""
Then, we create an Orbax checkpointing callback and pass it to the
`callbacks` argument of the `fit` method.
"""

orbax_callback = OrbaxCheckpointCallback(
    model,
    checkpoint_dir="/tmp/ckpt",
    max_to_keep=1,
    steps_per_epoch=1,
)
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=3,
    verbose=0,
    validation_split=0.2,
    callbacks=[orbax_callback],
)

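"""
We can also restore the checkpoint into a freshly constructed model by
using `KerasOrbaxCheckpointManager` directly. The snippet below is a
minimal sketch: it assumes the new model has the same architecture and
layer names as the checkpointed one, and it builds the optimizer variables
by hand (via `optimizer.build`) so that the state tree matches the saved
structure before restoring.
"""

restored_model = get_model()
# Build the optimizer variables so the state tree matches the checkpoint.
restored_model.optimizer.build(restored_model.trainable_variables)
manager = KerasOrbaxCheckpointManager(restored_model, checkpoint_dir="/tmp/ckpt")
# With no step given, `restore_state` loads the latest checkpoint.
manager.restore_state()
# The restored weights should match the weights of the trained model.
np.testing.assert_allclose(
    restored_model.get_weights()[0], model.get_weights()[0]
)
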
"""
Now, if you look at the Orbax checkpoint directory, you can see all the
files written as part of Orbax checkpointing.
"""
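"""
The exact layout and file names depend on the Orbax version, but a simple
directory walk makes the saved step directories visible:
"""

for root, _, files in os.walk("/tmp/ckpt"):
    # Print every file written under the checkpoint directory.
    for fname in files:
        print(os.path.join(root, fname))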