Skip to content

Commit fdb7053

Browse files
Addressed reviews
1 parent 60fc365 commit fdb7053

File tree

2 files changed

+163
-24
lines changed

2 files changed

+163
-24
lines changed

keras_remote/utils/packager.py

Lines changed: 103 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,36 @@
1+
"""Packaging utilities for serializing functions, args, and working directories.
2+
3+
Handles zipping the user's working directory, serializing the function
4+
payload with cloudpickle, and extracting/replacing Data objects in
5+
arbitrarily nested arg structures.
6+
"""
7+
18
import os
29
import zipfile
10+
from collections.abc import Callable
11+
from typing import Any
312

413
import cloudpickle
514

615
from keras_remote.data import Data
716

17+
# Type alias for a position path through nested args, e.g. ("arg", 0, "key").
18+
PositionPath = tuple[str | int, ...]
19+
20+
21+
def zip_working_dir(
22+
base_dir: str, output_path: str, exclude_paths: set[str] | None = None
23+
) -> None:
24+
"""Zip a directory into a ZIP archive, excluding common non-source files.
825
9-
def zip_working_dir(base_dir, output_path, exclude_paths=None):
10-
"""Zips the base_dir into output_path, excluding .git, __pycache__,
11-
and any paths in exclude_paths."""
26+
Excludes ``.git``, ``__pycache__``, and any paths in *exclude_paths*
27+
(which may be files or directories).
28+
29+
Args:
30+
base_dir: Root directory to zip.
31+
output_path: Destination path for the ZIP file.
32+
exclude_paths: Absolute paths to skip during archiving.
33+
"""
1234
exclude_paths = exclude_paths or set()
1335
normalized_excludes = {os.path.normpath(p) for p in exclude_paths}
1436

@@ -30,10 +52,28 @@ def zip_working_dir(base_dir, output_path, exclude_paths=None):
3052
zipf.write(file_path, archive_name)
3153

3254

33-
def save_payload(func, args, kwargs, env_vars, output_path, volumes=None):
34-
"""Uses cloudpickle to serialize the function, args, kwargs, and
35-
env_vars."""
36-
payload = {
55+
def save_payload(
56+
func: Callable,
57+
args: tuple,
58+
kwargs: dict[str, Any],
59+
env_vars: dict[str, str],
60+
output_path: str,
61+
volumes: list[dict[str, Any]] | None = None,
62+
) -> None:
63+
"""Serialize a function call payload with cloudpickle.
64+
65+
The resulting pickle file contains a dict with keys ``func``, ``args``,
66+
``kwargs``, ``env_vars``, and optionally ``volumes``.
67+
68+
Args:
69+
func: The user function to execute remotely.
70+
args: Positional arguments (Data objects should already be replaced).
71+
kwargs: Keyword arguments.
72+
env_vars: Environment variables to set on the remote pod.
73+
output_path: Destination path for the pickle file.
74+
volumes: Optional list of volume data-ref dicts.
75+
"""
76+
payload: dict[str, Any] = {
3777
"func": func,
3878
"args": args,
3979
"kwargs": kwargs,
@@ -45,50 +85,89 @@ def save_payload(func, args, kwargs, env_vars, output_path, volumes=None):
4585
cloudpickle.dump(payload, f)
4686

4787

48-
def extract_data_refs(args, kwargs):
88+
def extract_data_refs(
89+
args: tuple, kwargs: dict[str, Any]
90+
) -> list[tuple[Data, PositionPath]]:
4991
"""Scan args and kwargs for Data objects at any nesting depth.
5092
51-
Returns list of (data_obj, position_path) tuples.
93+
Returns a list of ``(data_obj, position_path)`` tuples. The position
94+
path encodes where each Data object was found, e.g.
95+
``("arg", 0)`` or ``("kwarg", "config", "data")``.
96+
97+
Circular references are handled safely via an ``id()``-based visited
98+
set.
5299
"""
53-
refs = []
100+
refs: list[tuple[Data, PositionPath]] = []
54101
for i, arg in enumerate(args):
55102
_scan_for_data(arg, ("arg", i), refs)
56103
for key, val in kwargs.items():
57104
_scan_for_data(val, ("kwarg", key), refs)
58105
return refs
59106

60107

61-
def _scan_for_data(obj, path, refs):
108+
def _scan_for_data(
109+
obj: Any,
110+
path: PositionPath,
111+
refs: list[tuple[Data, PositionPath]],
112+
visited: set[int] | None = None,
113+
) -> None:
114+
"""Recursively collect Data objects from a nested structure."""
115+
if visited is None:
116+
visited = set()
117+
obj_id = id(obj)
118+
if obj_id in visited:
119+
return
120+
visited.add(obj_id)
62121
if isinstance(obj, Data):
63122
refs.append((obj, path))
64-
elif isinstance(obj, (list, tuple)):
123+
elif isinstance(obj, (list, tuple, set, frozenset)):
65124
for i, item in enumerate(obj):
66-
_scan_for_data(item, path + (i,), refs)
125+
_scan_for_data(item, path + (i,), refs, visited)
67126
elif isinstance(obj, dict):
68127
for key, val in obj.items():
69-
_scan_for_data(val, path + (key,), refs)
128+
_scan_for_data(val, path + (key,), refs, visited)
70129

71130

72-
def replace_data_with_refs(args, kwargs, ref_map):
73-
"""Replace Data objects with serializable ref dicts.
131+
def replace_data_with_refs(
132+
args: tuple,
133+
kwargs: dict[str, Any],
134+
ref_map: dict[int, dict[str, Any]],
135+
) -> tuple[tuple, dict[str, Any]]:
136+
"""Replace Data objects in args/kwargs with serializable ref dicts.
74137
75138
Args:
76-
ref_map: dict mapping id(Data) -> ref dict
139+
args: Positional arguments, possibly containing Data objects.
140+
kwargs: Keyword arguments, possibly containing Data objects.
141+
ref_map: Mapping from ``id(Data)`` to the replacement ref dict.
142+
77143
Returns:
78-
(new_args, new_kwargs) -- new tuples/dicts with Data replaced
144+
``(new_args, new_kwargs)`` with all matched Data objects replaced.
79145
"""
80146
new_args = tuple(_replace_in_value(a, ref_map) for a in args)
81147
new_kwargs = {k: _replace_in_value(v, ref_map) for k, v in kwargs.items()}
82148
return new_args, new_kwargs
83149

84150

85-
def _replace_in_value(obj, ref_map):
86-
if isinstance(obj, Data) and id(obj) in ref_map:
87-
return ref_map[id(obj)]
151+
def _replace_in_value(
152+
obj: Any,
153+
ref_map: dict[int, dict[str, Any]],
154+
visited: set[int] | None = None,
155+
) -> Any:
156+
"""Recursively replace Data objects with their ref dicts."""
157+
if visited is None:
158+
visited = set()
159+
obj_id = id(obj)
160+
if obj_id in visited:
161+
return obj
162+
visited.add(obj_id)
163+
if isinstance(obj, Data) and obj_id in ref_map:
164+
return ref_map[obj_id]
88165
elif isinstance(obj, list):
89-
return [_replace_in_value(item, ref_map) for item in obj]
166+
return [_replace_in_value(item, ref_map, visited) for item in obj]
90167
elif isinstance(obj, tuple):
91-
return tuple(_replace_in_value(item, ref_map) for item in obj)
168+
return tuple(_replace_in_value(item, ref_map, visited) for item in obj)
169+
elif isinstance(obj, (set, frozenset)):
170+
return [_replace_in_value(item, ref_map, visited) for item in obj]
92171
elif isinstance(obj, dict):
93-
return {k: _replace_in_value(v, ref_map) for k, v in obj.items()}
172+
return {k: _replace_in_value(v, ref_map, visited) for k, v in obj.items()}
94173
return obj

keras_remote/utils/packager_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,34 @@ def test_no_data_objects(self):
312312
refs = extract_data_refs((1, "hello"), {"lr": 0.01})
313313
self.assertEqual(len(refs), 0)
314314

315+
def test_nested_in_set(self):
316+
tmp = _make_temp_path(self)
317+
f = tmp / "data.csv"
318+
f.write_text("data")
319+
d = Data(str(f))
320+
321+
refs = extract_data_refs(({d, "other"},), {})
322+
self.assertEqual(len(refs), 1)
323+
self.assertIs(refs[0][0], d)
324+
325+
def test_nested_in_frozenset(self):
326+
tmp = _make_temp_path(self)
327+
f = tmp / "data.csv"
328+
f.write_text("data")
329+
d = Data(str(f))
330+
331+
refs = extract_data_refs((frozenset({d}),), {})
332+
self.assertEqual(len(refs), 1)
333+
self.assertIs(refs[0][0], d)
334+
335+
def test_circular_reference_does_not_recurse(self):
336+
"""Circular structures in args should not cause infinite recursion."""
337+
circular = {"key": "value"}
338+
circular["self"] = circular
339+
340+
refs = extract_data_refs((circular,), {})
341+
self.assertEqual(len(refs), 0)
342+
315343

316344
class TestReplaceDataWithRefs(absltest.TestCase):
317345
def test_replaces_direct_arg(self):
@@ -357,6 +385,38 @@ def test_preserves_non_data(self):
357385
self.assertEqual(new_args, (1, "hello", [1, 2]))
358386
self.assertEqual(new_kwargs, {"x": 3})
359387

388+
def test_replaces_in_set(self):
389+
tmp = _make_temp_path(self)
390+
f = tmp / "data.csv"
391+
f.write_text("data")
392+
d = Data(str(f))
393+
ref = {"__data_ref__": True, "gcs_uri": "gs://b/p"}
394+
ref_map = {id(d): ref}
395+
396+
new_args, _ = replace_data_with_refs(({d},), {}, ref_map)
397+
self.assertIsInstance(new_args[0], list)
398+
self.assertIn(ref, new_args[0])
399+
400+
def test_replaces_in_frozenset(self):
401+
tmp = _make_temp_path(self)
402+
f = tmp / "data.csv"
403+
f.write_text("data")
404+
d = Data(str(f))
405+
ref = {"__data_ref__": True, "gcs_uri": "gs://b/p"}
406+
ref_map = {id(d): ref}
407+
408+
new_args, _ = replace_data_with_refs((frozenset({d}),), {}, ref_map)
409+
self.assertIsInstance(new_args[0], list)
410+
self.assertIn(ref, new_args[0])
411+
412+
def test_circular_reference_does_not_recurse(self):
413+
"""Circular structures should not cause infinite recursion."""
414+
circular = [1, 2]
415+
circular.append(circular)
416+
417+
new_args, _ = replace_data_with_refs((circular,), {}, {})
418+
self.assertIsInstance(new_args[0], list)
419+
360420

361421
if __name__ == "__main__":
362422
absltest.main()

0 commit comments

Comments
 (0)