90 changes: 82 additions & 8 deletions keras/src/callbacks/orbax_checkpoint.py
@@ -62,6 +62,11 @@ class OrbaxCheckpoint(MonitorCallback):
This callback saves the model's weights and optimizer state asynchronously
using Orbax, allowing training to continue without blocking for I/O.
**Multi-host Support**: When running in a multi-host distributed training
environment with the JAX backend, this callback automatically coordinates
checkpointing across all hosts to ensure consistency and proper
synchronization. Multi-host checkpointing is only supported with the JAX
backend.

Example:
```python
...
```
@@ -138,6 +143,9 @@ def __init__(
self._current_epoch = 0 # Keep track of epoch
self._total_batches_seen = 0 # Global batch counter for step tracking

# Multi-host support
self._multihost_initialized = self._is_multihost_initialized()

if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
raise ValueError(
f"Unrecognized save_freq: {self.save_freq}. "
@@ -167,6 +175,62 @@ def __init__(
preservation_policy=preservation_policy,
)

def _is_multihost_initialized(self):
"""Check if multi-host environment is initialized."""
# Multi-host checkpointing is only supported on JAX backend
if backend.backend() != "jax":
return False

try:
import orbax.checkpoint as ocp

return ocp.multihost.is_initialized()
except (ImportError, AttributeError):
return False
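
As a rough standalone sketch (not part of this PR), the same detection idea can be expressed directly against JAX's public API: `jax.process_count()` reports the number of participating processes once `jax.distributed.initialize()` has run on each host. The helper name below is hypothetical.

```python
# Hedged sketch, independent of the Orbax helper above: treat a JAX
# process count greater than one as a multi-host setup. Assumes
# jax.distributed.initialize() has already been called on each host.
import jax


def running_multihost() -> bool:
    return jax.process_count() > 1
```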

def _is_primary_host(self):
"""Check if this is the primary host for coordination."""
if not self._multihost_initialized:
return True # Single host is always primary
import orbax.checkpoint as ocp

return ocp.multihost.is_primary_host()

def _sync_processes(self, key=None):
"""Synchronize all processes across hosts."""
if not self._multihost_initialized:
return # No-op for single host

import orbax.checkpoint as ocp

sync_key = key or f"checkpoint_sync_{id(self)}"
ocp.multihost.sync_global_processes(sync_key)
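
For orientation, a minimal barrier along these lines can also be written with JAX's public `multihost_utils`, which provides the same kind of named global synchronization that Orbax exposes here. This is an illustrative sketch, not the PR's implementation; the `barrier` helper is hypothetical.

```python
# Illustrative barrier sketch using JAX's public utilities; every
# process must call it with the same `name`, otherwise the barrier
# cannot match up across hosts.
import jax
from jax.experimental import multihost_utils


def barrier(name: str) -> None:
    if jax.process_count() > 1:
        multihost_utils.sync_global_devices(name)


barrier("checkpoint_save_start_0")
```

Since JAX matches barriers by name, every process has to compute an identical key string for a given synchronization point.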

def is_multihost_enabled(self):
"""Return True if multi-host checkpointing is enabled and initialized.
This method can be used to check if the callback is operating in
a multi-host distributed training environment. Multi-host checkpointing
is only supported on JAX backend.
Returns:
bool: True if multi-host support is active, False otherwise.
"""
return self._multihost_initialized

def is_primary_host(self):
"""Return True if this process is the primary host in multi-host setup.
In multi-host environments, only the primary host typically handles
logging and coordination tasks. Multi-host checkpointing is only
supported on JAX backend.
Returns:
bool: True if this is the primary host, False otherwise.
Always returns True in single-host environments.
"""
return self._is_primary_host()
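
As a small illustration of the logging pattern this enables (a sketch, not code from this PR), `jax.process_index()` is 0 on the primary host, so verbose output can be restricted to a single process. The `log_on_primary` helper is hypothetical.

```python
# Sketch of primary-host-only logging; jax.process_index() is 0 on the
# primary host in a multi-host JAX job (and in single-process runs).
import jax


def log_on_primary(message: str) -> None:
    if jax.process_index() == 0:
        print(message)


log_on_primary("checkpoint scheduled")
```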

def _should_save_on_batch(self, batch):
"""Check if we should save on this batch."""
if self.save_freq == "epoch":
@@ -186,7 +250,7 @@ def _should_save_on_batch(self, batch):
return False

def _save_checkpoint(self, step, logs=None):
"""Save a checkpoint at the given step."""
"""Save a checkpoint at the given step with multi-host coordination."""

# --- Prepare Composite State (Backend-Agnostic) ---
state_tree = _get_state_tree(self.model)
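
For context, here is a hedged sketch of the kind of nested state being gathered at this point, assuming the private `_get_state_tree` helper is roughly equivalent to Keras 3's public `Model.get_state_tree()` (an assumption, not something this diff shows).

```python
# Hedged sketch: inspect the nested state tree Keras exposes publicly.
# Assumes keras.Model.get_state_tree() as the public counterpart of the
# private helper used above.
import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(2)])
model.compile(optimizer="adam", loss="mse")
model.fit(np.ones((4, 3)), np.ones((4, 2)), epochs=1, verbose=0)

state = model.get_state_tree()
# Typically contains trainable_variables, non_trainable_variables,
# optimizer_variables, and metrics_variables sub-trees.
print(list(state.keys()))
```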
@@ -204,11 +268,13 @@ def _save_checkpoint(self, step, logs=None):
else:
composite_state = state_tree

- # --- Save Logic (V1 API) ---
+ # --- Multi-host Coordination ---
+ # All processes participate in distributed checkpointing
- # Checkpointer is configured to save unconditionally when
- # save_pytree is called
- if self.verbose > 0:
+ # Synchronize before saving to ensure consistency
+ self._sync_processes(f"checkpoint_save_start_{step}")
+
+ # --- Save Logic (V1 API) ---
+ if self.verbose > 0 and self._is_primary_host():
print_msg(
f"OrbaxCheckpoint: Triggering async save for step {step}..."
)
@@ -221,6 +287,9 @@ def _save_checkpoint(self, step, logs=None):
else:
self.checkpointer.save_pytree(step, composite_state)

# Synchronize after saving to ensure all processes complete
self._sync_processes(f"checkpoint_save_end_{step}")

def on_train_batch_end(self, batch, logs=None):
if self._should_save_on_batch(batch):
# Handle save_best_only logic for batch-level saving
@@ -282,13 +351,15 @@ def on_train_end(self, logs=None):
except Exception:
pass # Ignore errors during cleanup

# Multi-host synchronization: ensure all hosts complete cleanup
self._sync_processes("checkpoint_cleanup")

def wait_until_finished(self):
"""Wait for any in-progress checkpoint operations to complete.
This method blocks until all asynchronous checkpoint save operations
- have completed. It should be called before attempting to load
- checkpoints if there might be pending save operations.
+ have completed across all hosts in a multi-host setup.
"""
- # Wait for any async operations to complete
+ # Wait for any async operations to complete on this host
if hasattr(self.checkpointer, "wait"):
self.checkpointer.wait()
else:
@@ -297,3 +368,6 @@ def wait_until_finished(self):
import time

time.sleep(0.1)

# Multi-host synchronization: ensure all hosts complete
self._sync_processes("checkpoint_wait_complete")