Skip to content

Commit 09c7847

Browse files
Author: Orbax Authors (committed)
#p2p Implement CheckpointManager.reload() properly
PiperOrigin-RevId: 876307566
1 parent 6f4d675 commit 09c7847

File tree

2 files changed

+15 -3 lines changed

2 files changed

+15 -3 lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Composite Checkpoint Manager handling P2P syncing with optional Persistent Fallback."""
15+
"""Composite Checkpoint Manager with P2P syncing and Persistent Fallback."""
1616

1717
import shutil
1818
import threading
@@ -178,7 +178,7 @@ def get_all_steps_from_peers(self) -> list[int]:
178178
return self._peer_selector.get_all_steps()
179179

180180
def has_shard_for_step(self, step: int) -> bool:
181-
"""Checks if this process's shard for a given step exists in the P2P network."""
181+
"""Checks if this process's shard for a step exists in P2P."""
182182
assert self._peer_selector is not None
183183
return (
184184
self._peer_selector.get_source_peer(step, self._process_index)
@@ -212,7 +212,7 @@ def close(self):
212212
class CheckpointManager(
213213
abstract_checkpoint_manager.AbstractCheckpointManager, epy.ContextManager
214214
):
215-
"""Orchestrates P2P local checkpointing with optional persistent storage failover.
215+
"""P2P local checkpointing with persistent storage failover.
216216
217217
Restoration Strategy:
218218
1. Check Local Disk (Fastest)
@@ -333,7 +333,15 @@ def best_step(self) -> int | None:
333333

334334
@override
335335
def reload(self):
336+
"""Reloads the checkpoint manager and its components.
337+
338+
This method refreshes the local and persistent managers and marks the P2P
339+
registry as stale, forcing a re-sync on the next access.
340+
"""
341+
self._p2p.mark_registry_stale()
336342
self._local_manager.reload()
343+
if self._persistent_manager:
344+
self._persistent_manager.reload()
337345

338346
@override
339347
def reached_preemption(self, step: int) -> bool:

checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,9 @@ def wait_until_finished(self):
215215
def check_for_errors(self):
216216
self._manager.check_for_errors()
217217

218+
def reload(self):
219+
"""Reloads the internal checkpoint manager."""
220+
self._manager.reload()
221+
218222
def close(self):
219223
self._manager.close()

0 commit comments

Comments
 (0)