Skip to content

Commit a9a5260

Browse files
author
Orbax Authors
committed
#p2p Implement CheckpointManager.reload() properly
PiperOrigin-RevId: 876307566
1 parent b192e7f commit a9a5260

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/checkpoint_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Composite Checkpoint Manager handling P2P syncing with optional Persistent Fallback."""
15+
"""Composite Checkpoint Manager with P2P syncing and Persistent Fallback."""
1616

1717
import shutil
1818
import threading
@@ -178,7 +178,7 @@ def get_all_steps_from_peers(self) -> list[int]:
178178
return self._peer_selector.get_all_steps()
179179

180180
def has_shard_for_step(self, step: int) -> bool:
181-
"""Checks if this process's shard for a given step exists in the P2P network."""
181+
"""Checks if this process's shard for a step exists in P2P."""
182182
assert self._peer_selector is not None
183183
return (
184184
self._peer_selector.get_source_peer(step, self._process_index)
@@ -212,7 +212,7 @@ def close(self):
212212
class CheckpointManager(
213213
abstract_checkpoint_manager.AbstractCheckpointManager, epy.ContextManager
214214
):
215-
"""Orchestrates P2P local checkpointing with optional persistent storage failover.
215+
"""P2P local checkpointing with persistent storage failover.
216216
217217
Restoration Strategy:
218218
1. Check Local Disk (Fastest)
@@ -333,7 +333,15 @@ def best_step(self) -> int | None:
333333

334334
@override
335335
def reload(self):
336+
"""Reloads the checkpoint manager and its components.
337+
338+
This method refreshes the local and persistent managers and marks the P2P
339+
registry as stale, forcing a re-sync on the next access.
340+
"""
341+
self._p2p.mark_registry_stale()
336342
self._local_manager.reload()
343+
if self._persistent_manager:
344+
self._persistent_manager.reload()
337345

338346
@override
339347
def reached_preemption(self, step: int) -> bool:

checkpoint/orbax/checkpoint/experimental/emergency/p2p/persistent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,5 +215,9 @@ def wait_until_finished(self):
215215
def check_for_errors(self):
216216
self._manager.check_for_errors()
217217

218+
def reload(self):
219+
"""Reloads the internal checkpoint manager."""
220+
self._manager.reload()
221+
218222
def close(self):
219223
self._manager.close()

checkpoint/orbax/checkpoint/experimental/emergency/p2p/protocol.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import socket
2020
import struct
21+
import typing
2122
from typing import Any, Final
2223
from absl import logging
2324
from etils import epath
@@ -60,6 +61,16 @@ def from_dict(cls, data: dict[str, Any]) -> Self:
6061
)
6162

6263

64+
class ManifestEntry(typing.TypedDict):
65+
"""Type definition for a single file entry in a manifest."""
66+
67+
rel_path: str
68+
size: int
69+
70+
71+
Manifest = list[ManifestEntry]
72+
73+
6374
def optimize_socket(sock: socket.socket) -> None:
6475
try:
6576
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def stop(self):
143143
self._thread.join(timeout=2.0)
144144
self._thread = None
145145

146-
def handle_get_manifest(
147-
self, payload: dict[str, Any]
148-
) -> list[dict[str, Any]]:
146+
def handle_get_manifest(self, payload: dict[str, Any]) -> protocol.Manifest:
149147
"""Handles GET_MANIFEST request.
150148
151149
Args:
@@ -194,6 +192,7 @@ def handle_get_manifest(
194192
step,
195193
req_process_index,
196194
)
195+
197196
return files
198197

199198
def handle_download(self, sock, payload: dict[str, Any]):
@@ -237,7 +236,7 @@ def fetch_shard_from_peer(
237236
"""
238237
logging.info('Requesting manifest from %s:%d for step %d', ip, port, step)
239238

240-
manifest = protocol.TCPClient.request(
239+
manifest: protocol.Manifest = protocol.TCPClient.request(
241240
ip,
242241
port,
243242
protocol.OP_GET_MANIFEST,

0 commit comments

Comments
 (0)