Skip to content

Commit 6f4d675

Browse files
author
Orbax Authors
committed
#p2p Refactor service and protocol for better type safety
PiperOrigin-RevId: 876445289
1 parent 76398a9 commit 6f4d675

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/p2p/protocol.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import socket
2020
import struct
21+
import typing
2122
from typing import Any, Final
2223
from absl import logging
2324
from etils import epath
@@ -60,6 +61,16 @@ def from_dict(cls, data: dict[str, Any]) -> Self:
6061
)
6162

6263

64+
class ManifestEntry(typing.TypedDict):
65+
"""Type definition for a single file entry in a manifest."""
66+
67+
rel_path: str
68+
size: int
69+
70+
71+
Manifest = list[ManifestEntry]
72+
73+
6374
def optimize_socket(sock: socket.socket) -> None:
6475
try:
6576
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

checkpoint/orbax/checkpoint/experimental/emergency/p2p/service.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,7 @@ def stop(self):
143143
self._thread.join(timeout=2.0)
144144
self._thread = None
145145

146-
def handle_get_manifest(
147-
self, payload: dict[str, Any]
148-
) -> list[dict[str, Any]]:
146+
def handle_get_manifest(self, payload: dict[str, Any]) -> protocol.Manifest:
149147
"""Handles GET_MANIFEST request.
150148
151149
Args:
@@ -194,6 +192,7 @@ def handle_get_manifest(
194192
step,
195193
req_process_index,
196194
)
195+
197196
return files
198197

199198
def handle_download(self, sock, payload: dict[str, Any]):
@@ -237,7 +236,7 @@ def fetch_shard_from_peer(
237236
"""
238237
logging.info('Requesting manifest from %s:%d for step %d', ip, port, step)
239238

240-
manifest = protocol.TCPClient.request(
239+
manifest: protocol.Manifest = protocol.TCPClient.request(
241240
ip,
242241
port,
243242
protocol.OP_GET_MANIFEST,

0 commit comments

Comments
 (0)