1+ """Packaging utilities for serializing functions, args, and working directories.
2+
3+ Handles zipping the user's working directory, serializing the function
4+ payload with cloudpickle, and extracting/replacing Data objects in
5+ arbitrarily nested arg structures.
6+ """
7+
18import os
29import zipfile
10+ from collections .abc import Callable
11+ from typing import Any
312
413import cloudpickle
514
615from keras_remote .data import Data
716
17+ # Type alias for a position path through nested args, e.g. ("arg", 0, "key").
18+ PositionPath = tuple [str | int , ...]
19+
20+
21+ def zip_working_dir (
22+ base_dir : str , output_path : str , exclude_paths : set [str ] | None = None
23+ ) -> None :
24+ """Zip a directory into a ZIP archive, excluding common non-source files.
825
9- def zip_working_dir (base_dir , output_path , exclude_paths = None ):
10- """Zips the base_dir into output_path, excluding .git, __pycache__,
11- and any paths in exclude_paths."""
26+ Excludes ``.git``, ``__pycache__``, and any paths in *exclude_paths*
27+ (which may be files or directories).
28+
29+ Args:
30+ base_dir: Root directory to zip.
31+ output_path: Destination path for the ZIP file.
32+ exclude_paths: Absolute paths to skip during archiving.
33+ """
1234 exclude_paths = exclude_paths or set ()
1335 normalized_excludes = {os .path .normpath (p ) for p in exclude_paths }
1436
@@ -30,10 +52,28 @@ def zip_working_dir(base_dir, output_path, exclude_paths=None):
3052 zipf .write (file_path , archive_name )
3153
3254
33- def save_payload (func , args , kwargs , env_vars , output_path , volumes = None ):
34- """Uses cloudpickle to serialize the function, args, kwargs, and
35- env_vars."""
36- payload = {
55+ def save_payload (
56+ func : Callable ,
57+ args : tuple ,
58+ kwargs : dict [str , Any ],
59+ env_vars : dict [str , str ],
60+ output_path : str ,
61+ volumes : list [dict [str , Any ]] | None = None ,
62+ ) -> None :
63+ """Serialize a function call payload with cloudpickle.
64+
65+ The resulting pickle file contains a dict with keys ``func``, ``args``,
66+ ``kwargs``, ``env_vars``, and optionally ``volumes``.
67+
68+ Args:
69+ func: The user function to execute remotely.
70+ args: Positional arguments (Data objects should already be replaced).
71+ kwargs: Keyword arguments.
72+ env_vars: Environment variables to set on the remote pod.
73+ output_path: Destination path for the pickle file.
74+ volumes: Optional list of volume data-ref dicts.
75+ """
76+ payload : dict [str , Any ] = {
3777 "func" : func ,
3878 "args" : args ,
3979 "kwargs" : kwargs ,
@@ -45,50 +85,89 @@ def save_payload(func, args, kwargs, env_vars, output_path, volumes=None):
4585 cloudpickle .dump (payload , f )
4686
4787
48- def extract_data_refs (args , kwargs ):
88+ def extract_data_refs (
89+ args : tuple , kwargs : dict [str , Any ]
90+ ) -> list [tuple [Data , PositionPath ]]:
4991 """Scan args and kwargs for Data objects at any nesting depth.
5092
51- Returns list of (data_obj, position_path) tuples.
93+ Returns a list of ``(data_obj, position_path)`` tuples. The position
94+ path encodes where each Data object was found, e.g.
95+ ``("arg", 0)`` or ``("kwarg", "config", "data")``.
96+
97+ Circular references are handled safely via an ``id()``-based visited
98+ set.
5299 """
53- refs = []
100+ refs : list [ tuple [ Data , PositionPath ]] = []
54101 for i , arg in enumerate (args ):
55102 _scan_for_data (arg , ("arg" , i ), refs )
56103 for key , val in kwargs .items ():
57104 _scan_for_data (val , ("kwarg" , key ), refs )
58105 return refs
59106
60107
61- def _scan_for_data (obj , path , refs ):
108+ def _scan_for_data (
109+ obj : Any ,
110+ path : PositionPath ,
111+ refs : list [tuple [Data , PositionPath ]],
112+ visited : set [int ] | None = None ,
113+ ) -> None :
114+ """Recursively collect Data objects from a nested structure."""
115+ if visited is None :
116+ visited = set ()
117+ obj_id = id (obj )
118+ if obj_id in visited :
119+ return
120+ visited .add (obj_id )
62121 if isinstance (obj , Data ):
63122 refs .append ((obj , path ))
64- elif isinstance (obj , (list , tuple )):
123+ elif isinstance (obj , (list , tuple , set , frozenset )):
65124 for i , item in enumerate (obj ):
66- _scan_for_data (item , path + (i ,), refs )
125+ _scan_for_data (item , path + (i ,), refs , visited )
67126 elif isinstance (obj , dict ):
68127 for key , val in obj .items ():
69- _scan_for_data (val , path + (key ,), refs )
128+ _scan_for_data (val , path + (key ,), refs , visited )
70129
71130
72- def replace_data_with_refs (args , kwargs , ref_map ):
73- """Replace Data objects with serializable ref dicts.
131+ def replace_data_with_refs (
132+ args : tuple ,
133+ kwargs : dict [str , Any ],
134+ ref_map : dict [int , dict [str , Any ]],
135+ ) -> tuple [tuple , dict [str , Any ]]:
136+ """Replace Data objects in args/kwargs with serializable ref dicts.
74137
75138 Args:
76- ref_map: dict mapping id(Data) -> ref dict
139+ args: Positional arguments, possibly containing Data objects.
140+ kwargs: Keyword arguments, possibly containing Data objects.
141+ ref_map: Mapping from ``id(Data)`` to the replacement ref dict.
142+
77143 Returns:
78- (new_args, new_kwargs) -- new tuples/dicts with Data replaced
144+ `` (new_args, new_kwargs)`` with all matched Data objects replaced.
79145 """
80146 new_args = tuple (_replace_in_value (a , ref_map ) for a in args )
81147 new_kwargs = {k : _replace_in_value (v , ref_map ) for k , v in kwargs .items ()}
82148 return new_args , new_kwargs
83149
84150
85- def _replace_in_value (obj , ref_map ):
86- if isinstance (obj , Data ) and id (obj ) in ref_map :
87- return ref_map [id (obj )]
151+ def _replace_in_value (
152+ obj : Any ,
153+ ref_map : dict [int , dict [str , Any ]],
154+ visited : set [int ] | None = None ,
155+ ) -> Any :
156+ """Recursively replace Data objects with their ref dicts."""
157+ if visited is None :
158+ visited = set ()
159+ obj_id = id (obj )
160+ if obj_id in visited :
161+ return obj
162+ visited .add (obj_id )
163+ if isinstance (obj , Data ) and obj_id in ref_map :
164+ return ref_map [obj_id ]
88165 elif isinstance (obj , list ):
89- return [_replace_in_value (item , ref_map ) for item in obj ]
166+ return [_replace_in_value (item , ref_map , visited ) for item in obj ]
90167 elif isinstance (obj , tuple ):
91- return tuple (_replace_in_value (item , ref_map ) for item in obj )
168+ return tuple (_replace_in_value (item , ref_map , visited ) for item in obj )
169+ elif isinstance (obj , (set , frozenset )):
170+ return [_replace_in_value (item , ref_map , visited ) for item in obj ]
92171 elif isinstance (obj , dict ):
93- return {k : _replace_in_value (v , ref_map ) for k , v in obj .items ()}
172+ return {k : _replace_in_value (v , ref_map , visited ) for k , v in obj .items ()}
94173 return obj
0 commit comments