Skip to content

Commit a4569fe

Browse files
Integrate Data API into execution pipeline (#66)
1 parent 7268485 commit a4569fe

File tree

7 files changed

+915
-46
lines changed

7 files changed

+915
-46
lines changed

keras_remote/backend/execution.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from keras_remote.backend import gke_client, pathways_client
1919
from keras_remote.constants import get_default_zone, zone_to_region
2020
from keras_remote.credentials import ensure_credentials
21+
from keras_remote.data import _make_data_ref
2122
from keras_remote.infra import container_builder
2223
from keras_remote.infra.infra import get_default_project
2324
from keras_remote.utils import packager, storage
@@ -47,6 +48,9 @@ class JobContext:
4748
region: str = field(init=False)
4849
display_name: str = field(init=False)
4950

51+
# Data volumes {mount_path: Data}
52+
volumes: Optional[dict] = None
53+
5054
# Artifact paths (set during prepare phase)
5155
payload_path: Optional[str] = None
5256
context_path: Optional[str] = None
@@ -69,6 +73,7 @@ def from_params(
6973
zone: Optional[str],
7074
project: Optional[str],
7175
env_vars: dict,
76+
volumes: Optional[dict] = None,
7277
) -> "JobContext":
7378
"""Factory method with default resolution for zone/project."""
7479
if not zone:
@@ -90,13 +95,14 @@ def from_params(
9095
container_image=container_image,
9196
zone=zone,
9297
project=project,
98+
volumes=volumes,
9399
)
94100

95101

96102
class BaseK8sBackend:
97103
"""Base class for Kubernetes-based backends."""
98104

99-
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
105+
def __init__(self, cluster: str, namespace: str = "default"):
100106
self.cluster = cluster
101107
self.namespace = namespace
102108

@@ -202,6 +208,14 @@ def _find_requirements(start_dir: str) -> Optional[str]:
202208
return None
203209

204210

211+
def _maybe_exclude(data_path, caller_path, exclude_paths):
212+
"""Add data_path to exclude_paths if it's inside the caller directory."""
213+
data_abs = os.path.normpath(data_path)
214+
caller_abs = os.path.normpath(caller_path)
215+
if data_abs.startswith(caller_abs + os.sep) or data_abs == caller_abs:
216+
exclude_paths.add(data_abs)
217+
218+
205219
def _prepare_artifacts(
206220
ctx: JobContext, tmpdir: str, caller_frame_depth: int = 3
207221
) -> None:
@@ -211,21 +225,58 @@ def _prepare_artifacts(
211225
# Get caller directory
212226
frame = inspect.stack()[caller_frame_depth]
213227
module = inspect.getmodule(frame[0])
214-
if module:
228+
caller_path: str
229+
if module and module.__file__:
215230
caller_path = os.path.dirname(os.path.abspath(module.__file__))
216231
else:
217232
caller_path = os.getcwd()
218233

219-
# Serialize function + args
234+
# Process Data objects
235+
exclude_paths: set[str] = set()
236+
ref_map = {} # id(Data) -> ref dict (for arg replacement)
237+
volume_refs = [] # list of ref dicts (for volumes)
238+
239+
# Process volumes
240+
if ctx.volumes:
241+
for mount_path, data_obj in ctx.volumes.items():
242+
gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
243+
volume_refs.append(
244+
_make_data_ref(gcs_uri, data_obj.is_dir, mount_path=mount_path)
245+
)
246+
if not data_obj.is_gcs:
247+
_maybe_exclude(data_obj.path, caller_path, exclude_paths)
248+
249+
# Process Data in function args
250+
data_refs = packager.extract_data_refs(ctx.args, ctx.kwargs)
251+
for data_obj, _position in data_refs:
252+
gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
253+
ref_map[id(data_obj)] = _make_data_ref(gcs_uri, data_obj.is_dir)
254+
if not data_obj.is_gcs:
255+
_maybe_exclude(data_obj.path, caller_path, exclude_paths)
256+
257+
# Replace Data with refs in args/kwargs
258+
if ref_map:
259+
ctx.args, ctx.kwargs = packager.replace_data_with_refs(
260+
ctx.args, ctx.kwargs, ref_map
261+
)
262+
263+
# Serialize function + args (with volume refs)
220264
ctx.payload_path = os.path.join(tmpdir, "payload.pkl")
221265
packager.save_payload(
222-
ctx.func, ctx.args, ctx.kwargs, ctx.env_vars, ctx.payload_path
266+
ctx.func,
267+
ctx.args,
268+
ctx.kwargs,
269+
ctx.env_vars,
270+
ctx.payload_path,
271+
volumes=volume_refs or None,
223272
)
224273
logging.info("Payload serialized to %s", ctx.payload_path)
225274

226-
# Zip working directory
275+
# Zip working directory (excluding Data paths)
227276
ctx.context_path = os.path.join(tmpdir, "context.zip")
228-
packager.zip_working_dir(caller_path, ctx.context_path)
277+
packager.zip_working_dir(
278+
caller_path, ctx.context_path, exclude_paths=exclude_paths
279+
)
229280
logging.info("Context packaged to %s", ctx.context_path)
230281

231282
# Find requirements.txt
@@ -257,6 +308,8 @@ def _build_container(ctx: JobContext) -> None:
257308

258309
def _upload_artifacts(ctx: JobContext) -> None:
259310
"""Phase 3: Upload artifacts to Cloud Storage."""
311+
if ctx.payload_path is None or ctx.context_path is None:
312+
raise ValueError("payload_path and context_path must be set before upload")
260313
logging.info("Uploading artifacts to Cloud Storage (job: %s)...", ctx.job_id)
261314
storage.upload_artifacts(
262315
bucket_name=ctx.bucket_name,

keras_remote/core/core.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from keras_remote.constants import DEFAULT_CLUSTER_NAME
1111
from keras_remote.core import accelerators
12+
from keras_remote.data import Data
1213

1314

1415
def run(
@@ -20,6 +21,7 @@ def run(
2021
cluster=None,
2122
backend=None,
2223
namespace="default",
24+
volumes=None,
2325
):
2426
"""Execute function on remote TPU/GPU.
2527
@@ -33,7 +35,25 @@ def run(
3335
cluster: GKE cluster name (default: from KERAS_REMOTE_CLUSTER)
3436
backend: Backend to use ('gke' or 'pathways')
3537
namespace: Kubernetes namespace (default: 'default')
38+
volumes: Dict mapping absolute mount paths to Data objects, e.g.
39+
``{"/data": Data("./dataset/")}``. Data is downloaded to these
40+
paths on the pod before function execution.
3641
"""
42+
# Validate volumes
43+
if volumes is not None:
44+
if not isinstance(volumes, dict):
45+
raise TypeError(f"volumes must be a dict, got {type(volumes).__name__}")
46+
for mount_path, data_obj in volumes.items():
47+
if not isinstance(mount_path, str) or not mount_path.startswith("/"):
48+
raise ValueError(
49+
f"Volume mount path must be an absolute path "
50+
f"(start with '/'), got: {mount_path!r}"
51+
)
52+
if not isinstance(data_obj, Data):
53+
raise TypeError(
54+
f"Volume value for {mount_path!r} must be a Data "
55+
f"instance, got {type(data_obj).__name__}"
56+
)
3757

3858
def decorator(func):
3959
@functools.wraps(func)
@@ -79,6 +99,7 @@ def wrapper(*args, **kwargs):
7999
cluster,
80100
namespace,
81101
env_vars,
102+
volumes,
82103
)
83104
elif resolved_backend == "pathways":
84105
return _execute_on_pathways(
@@ -92,6 +113,7 @@ def wrapper(*args, **kwargs):
92113
cluster,
93114
namespace,
94115
env_vars,
116+
volumes,
95117
)
96118
else:
97119
raise ValueError(
@@ -114,6 +136,7 @@ def _execute_on_gke(
114136
cluster,
115137
namespace,
116138
env_vars,
139+
volumes,
117140
):
118141
"""Execute function on GKE cluster with GPU/TPU nodes."""
119142
# Get GKE-specific defaults
@@ -123,7 +146,15 @@ def _execute_on_gke(
123146
namespace = os.environ.get("KERAS_REMOTE_GKE_NAMESPACE", "default")
124147

125148
ctx = JobContext.from_params(
126-
func, args, kwargs, accelerator, container_image, zone, project, env_vars
149+
func,
150+
args,
151+
kwargs,
152+
accelerator,
153+
container_image,
154+
zone,
155+
project,
156+
env_vars,
157+
volumes=volumes,
127158
)
128159
return execute_remote(ctx, GKEBackend(cluster=cluster, namespace=namespace))
129160

@@ -139,6 +170,7 @@ def _execute_on_pathways(
139170
cluster,
140171
namespace,
141172
env_vars,
173+
volumes,
142174
):
143175
"""Execute function on GKE cluster via ML Pathways."""
144176
if not cluster:
@@ -147,7 +179,15 @@ def _execute_on_pathways(
147179
namespace = os.environ.get("KERAS_REMOTE_GKE_NAMESPACE", "default")
148180

149181
ctx = JobContext.from_params(
150-
func, args, kwargs, accelerator, container_image, zone, project, env_vars
182+
func,
183+
args,
184+
kwargs,
185+
accelerator,
186+
container_image,
187+
zone,
188+
project,
189+
env_vars,
190+
volumes=volumes,
151191
)
152192
return execute_remote(
153193
ctx, PathwaysBackend(cluster=cluster, namespace=namespace)

keras_remote/core/core_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def func():
3333

3434
func()
3535
call_args = mock_exec.call_args
36-
env_vars = call_args[0][-1] # last positional arg
36+
env_vars = call_args[0][-2]  # second-to-last positional arg (volumes is last)
3737
self.assertEqual(env_vars, {"MY_VAR": "my_val"})
3838

3939
def test_wildcard_pattern(self):
@@ -53,7 +53,7 @@ def func():
5353
pass
5454

5555
func()
56-
env_vars = mock_exec.call_args[0][-1]
56+
env_vars = mock_exec.call_args[0][-2]
5757
self.assertIn("PREFIX_A", env_vars)
5858
self.assertIn("PREFIX_B", env_vars)
5959
self.assertNotIn("OTHER", env_vars)
@@ -70,7 +70,7 @@ def func():
7070
pass
7171

7272
func()
73-
env_vars = mock_exec.call_args[0][-1]
73+
env_vars = mock_exec.call_args[0][-2]
7474
self.assertEqual(env_vars, {})
7575

7676
def test_none_capture(self):
@@ -81,7 +81,7 @@ def func():
8181
pass
8282

8383
func()
84-
env_vars = mock_exec.call_args[0][-1]
84+
env_vars = mock_exec.call_args[0][-2]
8585
self.assertEqual(env_vars, {})
8686

8787
def test_mixed_exact_and_wildcard(self):
@@ -104,7 +104,7 @@ def func():
104104
pass
105105

106106
func()
107-
env_vars = mock_exec.call_args[0][-1]
107+
env_vars = mock_exec.call_args[0][-2]
108108
self.assertEqual(
109109
env_vars, {"EXACT_VAR": "exact", "WILD_A": "a", "WILD_B": "b"}
110110
)

keras_remote/runner/remote_runner.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# Base temp directory for remote execution artifacts
2020
TEMP_DIR = tempfile.gettempdir()
21+
DATA_DIR = os.path.join(TEMP_DIR, "data")
2122

2223

2324
def main():
@@ -87,6 +88,12 @@ def run_gcs_mode():
8788
logging.info("Setting %d environment variables", len(env_vars))
8889
os.environ.update(env_vars)
8990

91+
# Resolve Data references
92+
volumes = payload.get("volumes", [])
93+
if volumes:
94+
resolve_volumes(volumes, storage_client)
95+
args, kwargs = resolve_data_refs(args, kwargs, storage_client)
96+
9097
# Execute function and capture result
9198
logging.info("Executing %s()", func.__name__)
9299
result = None
@@ -129,6 +136,70 @@ def run_gcs_mode():
129136
sys.exit(1)
130137

131138

139+
def resolve_volumes(volume_refs, storage_client):
    """Materialize each volume ref by downloading it to its mount path."""
    for volume in volume_refs:
        target = volume["mount_path"]
        logging.info("Resolving volume: %s -> %s", volume["gcs_uri"], target)
        _download_data(volume, target, storage_client)
145+
146+
147+
def resolve_data_refs(args, kwargs, storage_client):
    """Recursively resolve data ref dicts in args/kwargs to local paths."""
    next_index = 0

    def _walk(node):
        nonlocal next_index
        if isinstance(node, dict):
            if node.get("__data_ref__"):
                # Volume-mounted refs are materialized by Kubernetes; just
                # substitute the mount path instead of downloading.
                mount = node.get("mount_path")
                if mount is not None:
                    return mount
                local_dir = os.path.join(DATA_DIR, str(next_index))
                next_index += 1
                _download_data(node, local_dir, storage_client)
                # Single-file data resolves to the file itself; anything
                # else resolves to the directory.
                if not node["is_dir"]:
                    entries = [
                        name
                        for name in os.listdir(local_dir)
                        if name != ".cache_marker"
                    ]
                    if len(entries) == 1:
                        return os.path.join(local_dir, entries[0])
                return local_dir
            # Plain dict: resolve nested refs in its values.
            return {key: _walk(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(_walk(element) for element in node)
        return node

    resolved_args = tuple(_walk(arg) for arg in args)
    resolved_kwargs = {name: _walk(value) for name, value in kwargs.items()}
    return resolved_args, resolved_kwargs
177+
178+
179+
def _download_data(ref, target_dir, storage_client):
180+
"""Download data from a GCS URI to a local directory."""
181+
os.makedirs(target_dir, exist_ok=True)
182+
gcs_uri = ref["gcs_uri"]
183+
184+
parts = gcs_uri.replace("gs://", "").split("/", 1)
185+
bucket_name = parts[0]
186+
prefix = parts[1].rstrip("/") if len(parts) > 1 else ""
187+
bucket = storage_client.bucket(bucket_name)
188+
189+
blobs = bucket.list_blobs(prefix=prefix + "/")
190+
count = 0
191+
for blob in blobs:
192+
if blob.name.endswith("/") or blob.name.endswith(".cache_marker"):
193+
continue
194+
rel_path = blob.name[len(prefix) + 1 :]
195+
local_path = os.path.join(target_dir, rel_path)
196+
os.makedirs(os.path.dirname(local_path), exist_ok=True)
197+
blob.download_to_filename(local_path)
198+
count += 1
199+
200+
logging.info("Downloaded %d files from %s to %s", count, gcs_uri, target_dir)
201+
202+
132203
def _download_from_gcs(client, gcs_path, local_path):
133204
"""Download file from GCS.
134205

0 commit comments

Comments
 (0)