Skip to content

Commit 3b375af

Browse files
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and restore (#21762)
* Added OrbaxCheckpoint for keras 3.0 for Data centric saving and restore Supports following feature - Asynchronous Checkpointing - Composite Checkpointing - Preservation Policies - Save Decision Policies - Transformations - Custom Handlers * Fix unused variable in orbax checkpoint test * fixed failing cases * fixed review comments * Improve OrbaxCheckpoint implementation - Remove conditional export decorator to ensure OrbaxCheckpoint is always available - Remove unnecessary exception handling in state tree operations - Update process index check comment for clarity - Format code to comply with 80-character line limit - Add distribution_lib modules for backend-specific distributed training support * Fix code formatting and remove unused variable - Remove unused 'result' variable in _reconstruct_state_tree_with_values - Fix long comment line in test file - Apply code formatting changes * Add OrbaxCheckpoint callback with conditional exports and improved test handling - Implement OrbaxCheckpoint callback for async checkpointing with state tree handling - Add conditional exports for optional orbax-checkpoint dependency - Use pytest.importorskip for clean optional dependency testing - Ensure graceful handling when orbax-checkpoint is not installed * Improve OrbaxCheckpoint: preserve nested structures, enhance tests - Preserve nested state tree structures instead of flattening for better layer name preservation - Add backward compatibility for old flattened format checkpoints - Simplify test class by using self.get_temp_dir() instead of setUp/tearDown - Remove silent pytest.importorskip, add explicit skip conditions for backend-specific tests - Move process_id function from backend to distribution module - Update imports to use centralized LazyModule for orbax.checkpoint - Test across all backends (JAX, TensorFlow, PyTorch) - all passing * Fixed review comments * Migration to Orbax V1 * Fix sklearn wrapper CI tests by marking pipeline consistency checks as expected failures Neural networks are inherently non-deterministic, so pipeline consistency checks should be skipped rather than fail. Added check_pipeline_consistency to EXPECTED_FAILED_CHECKS for all sklearn wrapper types. * made distributed structure proper * Fixed sav decision between keras and orbax * Optimize Orbax checkpoint for JAX backend - Avoid unnecessary numpy conversion in _get_state_tree() for JAX backend - Preserve JAX arrays during saving instead of converting to numpy - Maintain cross-backend compatibility with proper loading conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism * Optimize Orbax checkpoint for JAX backend with compatibility check - Preserve JAX arrays during saving when jax.monitoring.record_scalar is available - Fall back to numpy conversion for older JAX versions that don't have record_scalar - Maintain cross-backend compatibility while avoiding unnecessary conversions - Update async waiting to use CheckpointManager.wait_until_finished() - Implement AlwaysSavePolicy for reliable save decisions - Add expected failures for sklearn tests due to neural network non-determinism * added checkpointer.wait() * Improve OrbaxCheckpoint callback with optimizations and cleanup - Optimize JAX array handling: avoid unnecessary numpy conversions for JAX >= 0.7.0 - Simplify step counting: use _total_batches_seen directly instead of dual mechanisms - Remove impossible error checks and verbose messages - Clean up unused Orbax exports that violated import policies - Update error message for consistency - All changes maintain backward compatibility and pass tests across JAX/TensorFlow/PyTorch backends * Simplify OrbaxCheckpoint API to match ModelCheckpoint parity - Remove extra features: save_metadata, save_data_iterator, post_finalization_callback, save_decision_policy, keep_period - Remove loading methods: load_checkpoint, load_latest, all_steps, _restore_model_state_from_full_tree - Replace save_optimizer_state/save_metrics_state with save_weights_only parameter - Add comprehensive test coverage for all remaining functionality - Maintain async saving and preservation policies as Orbax-specific advantages - All tests pass across JAX/TensorFlow/PyTorch backends * Removed the experimental import * Add comprehensive OrbaxCheckpoint tests with loading verification - Add test_checkpoint_loading: Verifies weights can be loaded from checkpoints - Add test_checkpoint_loading_weights_only: Tests save_weights_only=True loading - Add test_checkpoint_loading_with_optimizer_state: Tests full state loading with optimizer - Fix array comparison logic for JAX, TensorFlow, and PyTorch backends - Ensure all lines are within 80-character limit - All tests pass on JAX, TensorFlow, and PyTorch backends * Improve OrbaxCheckpoint: complete state preservation, cross-backend compatibility, and comprehensive testing - Add complete model state saving (trainable/non-trainable vars, optimizer, metrics) - Simplify save_weights_only logic to use full state tree when saving complete state - Remove unnecessary try-except fallback for wait() method (V1 API always has it) - Add comprehensive test coverage (13 tests) for all state components - Ensure cross-backend compatibility (JAX, TensorFlow, PyTorch) - Remove version dependencies and conditional imports - Update requirements-common.txt with orbax-checkpoint dependency * Add back try-except fallback for wait() method to support older Orbax versions * Use hasattr check instead of try-except for wait() method compatibility * Add JAX monitoring compatibility: mock jax.monitoring.record_scalar when missing - Prevents AttributeError in CI environments with older JAX versions - Adds no-op lambda function when record_scalar is not available - Ensures tests run across different JAX versions - All 13 tests pass on JAX, TensorFlow, and PyTorch backends * Re-run CI * Refactor LazyModule to use OrbaxLazyModule subclass for cleaner orbax.checkpoint.v1 import handling * Re-run CI * Changed the order of the param
1 parent 74fba84 commit 3b375af

File tree

7 files changed

+904
-0
lines changed

7 files changed

+904
-0
lines changed

keras/api/_tf_keras/keras/callbacks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from keras.src.callbacks.model_checkpoint import (
2020
ModelCheckpoint as ModelCheckpoint,
2121
)
22+
from keras.src.callbacks.orbax_checkpoint import (
23+
OrbaxCheckpoint as OrbaxCheckpoint,
24+
)
2225
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
2326
from keras.src.callbacks.reduce_lr_on_plateau import (
2427
ReduceLROnPlateau as ReduceLROnPlateau,

keras/api/callbacks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from keras.src.callbacks.model_checkpoint import (
2020
ModelCheckpoint as ModelCheckpoint,
2121
)
22+
from keras.src.callbacks.orbax_checkpoint import (
23+
OrbaxCheckpoint as OrbaxCheckpoint,
24+
)
2225
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
2326
from keras.src.callbacks.reduce_lr_on_plateau import (
2427
ReduceLROnPlateau as ReduceLROnPlateau,

keras/src/callbacks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
99
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
1010
from keras.src.callbacks.monitor_callback import MonitorCallback
11+
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
1112
from keras.src.callbacks.progbar_logger import ProgbarLogger
1213
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
1314
from keras.src.callbacks.remote_monitor import RemoteMonitor
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
import warnings
2+
3+
import numpy as np
4+
5+
from keras.src import backend
6+
from keras.src import tree
7+
from keras.src.api_export import keras_export
8+
from keras.src.callbacks.monitor_callback import (
9+
MonitorCallback, # For metric monitoring logic
10+
)
11+
from keras.src.utils.io_utils import print_msg
12+
from keras.src.utils.module_utils import ocp
13+
14+
# Context and AsyncOptions are accessed through the lazy-loaded ocp module
15+
16+
# JAX monitoring compatibility: ensure record_scalar exists
17+
# to prevent AttributeError in older JAX versions
18+
try:
19+
import jax
20+
21+
if not hasattr(jax.monitoring, "record_scalar"):
22+
jax.monitoring.record_scalar = lambda *args, **kwargs: None
23+
except ImportError:
24+
pass
25+
26+
27+
def _get_state_tree(model):
28+
"""Get the complete model state as a nested tree structure."""
29+
# For JAX backend, preserve native arrays for performance
30+
# For other backends, convert to numpy arrays
31+
if backend.backend() == "jax":
32+
state_tree = model.get_state_tree()
33+
did_numpy_conversion = False
34+
else:
35+
state_tree = model.get_state_tree(value_format="numpy_array")
36+
did_numpy_conversion = True
37+
38+
# Convert numpy scalar types to Python types for Orbax compatibility
39+
# Only needed when we did numpy conversion
40+
if did_numpy_conversion:
41+
42+
def convert_scalars(obj):
43+
if isinstance(obj, np.ndarray) and obj.ndim == 0:
44+
# Convert 0-dimensional numpy arrays (scalars) to Python types
45+
return obj.item()
46+
elif isinstance(obj, np.generic):
47+
# Convert numpy scalar types (like np.float32) to Python types
48+
return obj.item()
49+
else:
50+
return obj
51+
52+
return tree.map_structure(convert_scalars, state_tree)
53+
else:
54+
return state_tree
55+
56+
57+
@keras_export("keras.callbacks.OrbaxCheckpoint")
58+
class OrbaxCheckpoint(MonitorCallback):
59+
"""Callback to save and load model state using Orbax with a similar API to
60+
ModelCheckpoint.
61+
62+
This callback saves the model's weights and optimizer state asynchronously
63+
using Orbax, allowing training to continue without blocking for I/O.
64+
65+
Example:
66+
67+
```python
68+
model.compile(loss=..., optimizer=..., metrics=['accuracy'])
69+
70+
EPOCHS = 10
71+
checkpoint_dir = '/tmp/ckpt'
72+
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
73+
directory=checkpoint_dir,
74+
monitor='val_accuracy',
75+
mode='max',
76+
save_best_only=True)
77+
78+
# Model is saved at the end of every epoch, if it's the best seen so far.
79+
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
80+
81+
# Alternatively, save checkpoints every N batches -
82+
orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
83+
directory=checkpoint_dir,
84+
save_freq=100) # Save every 100 batches
85+
86+
model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
87+
```
88+
89+
Args:
90+
directory: path to the directory where to save the checkpoints.
91+
monitor: The metric name to monitor (e.g., 'val_loss').
92+
verbose: Verbosity mode, 0 or 1.
93+
save_best_only: if `save_best_only=True`, it only saves when the model
94+
is considered the "best" based on the monitored quantity.
95+
save_weights_only: if `save_weights_only=True`, only the model's
96+
weights will be saved. Otherwise, the full model state
97+
(weights, non-trainable variables, optimizer state, and
98+
metrics state) will be saved. Defaults to False.
99+
mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
100+
save_freq: `'epoch'` or integer. Frequency to save checkpoints.
101+
max_to_keep: Integer, maximum number of recent checkpoints to keep.
102+
If None, keeps all. Defaults to 1.
103+
save_on_background: Boolean, whether to save asynchronously in the
104+
background. Defaults to True.
105+
initial_value_threshold: Floating point initial "best" value for the
106+
monitor, used with `save_best_only`.
107+
"""
108+
109+
def __init__(
110+
self,
111+
directory,
112+
monitor="val_loss",
113+
verbose=0,
114+
save_best_only=False,
115+
save_weights_only=False,
116+
mode="auto",
117+
save_freq="epoch",
118+
initial_value_threshold=None,
119+
max_to_keep=1,
120+
save_on_background=True,
121+
):
122+
# Ensure orbax is available
123+
ocp.initialize()
124+
125+
# Initialize MonitorCallback for handling 'monitor', 'mode', 'best'
126+
# logic
127+
super().__init__(monitor, mode, initial_value_threshold)
128+
129+
self.directory = directory
130+
self.verbose = verbose
131+
self.save_best_only = save_best_only
132+
self.save_weights_only = save_weights_only
133+
self.save_freq = save_freq
134+
self.max_to_keep = max_to_keep
135+
self.save_on_background = save_on_background
136+
self._batches_seen_since_last_saving = 0
137+
self._last_batch_seen = 0
138+
self._current_epoch = 0 # Keep track of epoch
139+
self._total_batches_seen = 0 # Global batch counter for step tracking
140+
141+
if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
142+
raise ValueError(
143+
f"Unrecognized save_freq: {self.save_freq}. "
144+
"Expected save_freq are 'epoch' or integer values"
145+
)
146+
147+
# --- Orbax Checkpointer Setup (V1 API) ---
148+
policies = []
149+
if max_to_keep is not None:
150+
policies.append(
151+
ocp.training.preservation_policies.LatestN(max_to_keep)
152+
)
153+
154+
# Use AnyPreservationPolicy to combine them.
155+
preservation_policy = None
156+
if policies:
157+
preservation_policy = (
158+
ocp.training.preservation_policies.AnyPreservationPolicy(
159+
policies
160+
)
161+
)
162+
163+
# Create the V1 Checkpointer with direct parameter passing
164+
# Orbax will handle directory creation on all processes as needed
165+
self.checkpointer = ocp.training.Checkpointer(
166+
directory=directory,
167+
preservation_policy=preservation_policy,
168+
)
169+
170+
def _should_save_on_batch(self, batch):
171+
"""Check if we should save on this batch."""
172+
if self.save_freq == "epoch":
173+
return False
174+
175+
if batch <= self._last_batch_seen: # New epoch.
176+
add_batches = batch + 1
177+
else:
178+
add_batches = batch - self._last_batch_seen
179+
self._batches_seen_since_last_saving += add_batches
180+
self._last_batch_seen = batch
181+
self._total_batches_seen += add_batches
182+
183+
if self._batches_seen_since_last_saving >= self.save_freq:
184+
self._batches_seen_since_last_saving = 0
185+
return True
186+
return False
187+
188+
def _save_checkpoint(self, step, logs=None):
189+
"""Save a checkpoint at the given step."""
190+
191+
# --- Prepare Composite State (Backend-Agnostic) ---
192+
state_tree = _get_state_tree(self.model)
193+
194+
# Save the nested state structures directly (preserving layer
195+
# names and structure)
196+
if self.save_weights_only:
197+
composite_state = {
198+
"trainable_variables": state_tree["trainable_variables"],
199+
}
200+
if "non_trainable_variables" in state_tree:
201+
composite_state["non_trainable_variables"] = state_tree[
202+
"non_trainable_variables"
203+
]
204+
else:
205+
composite_state = state_tree
206+
207+
# --- Save Logic (V1 API) ---
208+
# All processes participate in distributed checkpointing
209+
# Checkpointer is configured to save unconditionally when
210+
# save_pytree is called
211+
if self.verbose > 0:
212+
print_msg(
213+
f"OrbaxCheckpoint: Triggering async save for step {step}..."
214+
)
215+
216+
# Use a single with statement. If context_options is empty,
217+
# Context() uses defaults.
218+
with ocp.Context():
219+
if self.save_on_background:
220+
self.checkpointer.save_pytree_async(step, composite_state)
221+
else:
222+
self.checkpointer.save_pytree(step, composite_state)
223+
224+
def on_train_batch_end(self, batch, logs=None):
225+
if self._should_save_on_batch(batch):
226+
# Handle save_best_only logic for batch-level saving
227+
should_save = True
228+
if self.save_best_only:
229+
current = logs.get(self.monitor) if logs else None
230+
if current is None:
231+
warnings.warn(
232+
f"Can save best model only with {self.monitor} "
233+
f"available, skipping save at batch {batch}.",
234+
stacklevel=2,
235+
)
236+
should_save = False
237+
elif not self._is_improvement(current, self.best):
238+
should_save = False
239+
else:
240+
# Update best value when there's improvement
241+
self.best = current
242+
243+
if should_save:
244+
# Use global batch count for Orbax save step
245+
step = self._total_batches_seen
246+
self._save_checkpoint(step=step, logs=logs)
247+
248+
def on_epoch_end(self, epoch, logs=None):
249+
self._current_epoch = epoch
250+
if self.monitor_op is None:
251+
self._set_monitor_op() # From MonitorCallback
252+
253+
# For save_freq="epoch", save at every epoch
254+
should_save = self.save_freq == "epoch"
255+
256+
# Handle save_best_only logic
257+
if should_save and self.save_best_only:
258+
current = logs.get(self.monitor) if logs else None
259+
if current is None:
260+
warnings.warn(
261+
f"Can save best model only with {self.monitor} available, "
262+
f"skipping save at epoch {epoch}.",
263+
stacklevel=2,
264+
)
265+
should_save = False
266+
elif not self._is_improvement(current, self.best):
267+
should_save = False
268+
else:
269+
# Update best value when there's improvement
270+
self.best = current
271+
272+
if should_save:
273+
# Use epoch number as the step for Orbax save
274+
# Keras has already made the save decision - Checkpointer will
275+
# save unconditionally
276+
self._save_checkpoint(step=epoch, logs=logs)
277+
278+
def on_train_end(self, logs=None):
279+
# Close the Checkpointer to ensure all pending saves complete
280+
try:
281+
self.checkpointer.close()
282+
except Exception:
283+
pass # Ignore errors during cleanup
284+
285+
def wait_until_finished(self):
286+
"""Wait for any in-progress checkpoint operations to complete.
287+
This method blocks until all asynchronous checkpoint save operations
288+
have completed. It should be called before attempting to load
289+
checkpoints if there might be pending save operations.
290+
"""
291+
# Wait for any async operations to complete
292+
if hasattr(self.checkpointer, "wait"):
293+
self.checkpointer.wait()
294+
else:
295+
# Fallback for older Orbax versions that don't have wait() method
296+
while self.checkpointer.is_saving_in_progress():
297+
import time
298+
299+
time.sleep(0.1)

0 commit comments

Comments
 (0)