Skip to content

Commit 7268485

Browse files
Add Data class and content-addressed storage upload (#65)
1 parent 22252fd commit 7268485

File tree

9 files changed

+724
-37
lines changed

9 files changed

+724
-37
lines changed

keras_remote/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "0")
77

88
from keras_remote.core.core import run as run
9+
from keras_remote.data import Data as Data

keras_remote/backend/execution.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from keras_remote.constants import get_default_zone, zone_to_region
2020
from keras_remote.credentials import ensure_credentials
2121
from keras_remote.infra import container_builder
22+
from keras_remote.infra.infra import get_default_project
2223
from keras_remote.utils import packager, storage
2324

2425

@@ -73,9 +74,7 @@ def from_params(
7374
if not zone:
7475
zone = get_default_zone()
7576
if not project:
76-
project = os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
77-
"GOOGLE_CLOUD_PROJECT"
78-
)
77+
project = get_default_project()
7978
if not project:
8079
raise ValueError(
8180
"project must be specified or set KERAS_REMOTE_PROJECT"

keras_remote/backend/pathways_client.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import time
44

5+
from absl import logging
56
from kubernetes import client
67
from kubernetes.client.rest import ApiException
78

@@ -13,9 +14,6 @@
1314
)
1415
from keras_remote.backend.log_streaming import LogStreamer
1516
from keras_remote.core import accelerators
16-
from keras_remote.infra import infra
17-
18-
logger = infra.logger
1917

2018
LWS_GROUP = "leaderworkerset.x-k8s.io"
2119
LWS_VERSION = "v1"
@@ -40,7 +38,7 @@ def _get_lws_version(group=LWS_GROUP):
4038
# If we didn't find the group, raise ApiException to fallback
4139
raise ApiException(status=404, reason=f"API group {group} not found")
4240
except ApiException:
43-
logger.warning(
41+
logging.warning(
4442
"Failed to retrieve LWS API version from cluster. Defaulting to '%s'",
4543
LWS_VERSION,
4644
)
@@ -108,8 +106,8 @@ def submit_pathways_job(
108106
plural=LWS_PLURAL,
109107
body=lws_manifest,
110108
)
111-
logger.info(f"Submitted Pathways job (LWS): {job_name}")
112-
logger.info(
109+
logging.info(f"Submitted Pathways job (LWS): {job_name}")
110+
logging.info(
113111
"View job with: kubectl get %s %s -n %s", LWS_PLURAL, job_name, namespace
114112
)
115113
return created_lws
@@ -150,11 +148,11 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
150148
try:
151149
pod = core_v1.read_namespaced_pod(leader_pod_name, namespace)
152150
if not logged_running:
153-
logger.info(f"Found pod: {leader_pod_name}")
151+
logging.info(f"Found pod: {leader_pod_name}")
154152
logged_running = True
155153

156154
if pod.status.phase == "Succeeded":
157-
logger.info(f"[REMOTE] Job {job_name} completed successfully")
155+
logging.info(f"[REMOTE] Job {job_name} completed successfully")
158156
return "success"
159157

160158
if pod.status.phase == "Failed":
@@ -163,7 +161,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
163161

164162
elif pod.status.phase == "Pending":
165163
_check_pod_scheduling(core_v1, job_name, namespace)
166-
logger.debug("Pod is Pending...")
164+
logging.debug("Pod is Pending...")
167165

168166
elif pod.status.phase == "Running":
169167
streamer.start(leader_pod_name)
@@ -183,7 +181,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
183181
# Check current state
184182
if container_status.state.terminated:
185183
if container_status.state.terminated.exit_code == 0:
186-
logger.info(f"[REMOTE] Job {job_name} completed successfully")
184+
logging.info(f"[REMOTE] Job {job_name} completed successfully")
187185
return "success"
188186
else:
189187
_print_pod_logs(core_v1, job_name, namespace)
@@ -195,7 +193,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
195193
# Check last state (in case it restarted)
196194
if container_status.last_state.terminated:
197195
if container_status.last_state.terminated.exit_code == 0:
198-
logger.info(
196+
logging.info(
199197
f"[REMOTE] Job {job_name} completed successfully (restarted)"
200198
)
201199
return "success"
@@ -223,13 +221,13 @@ def cleanup_job(job_name, namespace="default"):
223221
plural=LWS_PLURAL,
224222
name=job_name,
225223
)
226-
logger.info(f"Deleted LeaderWorkerSet: {job_name}")
224+
logging.info(f"Deleted LeaderWorkerSet: {job_name}")
227225
except ApiException as e:
228226
if e.status == 404:
229227
# Job already deleted
230228
pass
231229
else:
232-
logger.warning(
230+
logging.warning(
233231
"Failed to delete LeaderWorkerSet %s: %s",
234232
job_name,
235233
e.reason,

keras_remote/data/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from keras_remote.data.data import Data as Data
2+
from keras_remote.data.data import _make_data_ref as _make_data_ref
3+
from keras_remote.data.data import (
4+
_warn_if_missing_trailing_slash as _warn_if_missing_trailing_slash,
5+
)
6+
from keras_remote.data.data import is_data_ref as is_data_ref

keras_remote/data/data.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
"""Data class for declaring data dependencies in remote functions.
2+
3+
Wraps local file/directory paths or GCS URIs. On the remote side, Data
4+
resolves to a plain filesystem path — the user's function only sees paths.
5+
"""
6+
7+
import hashlib
8+
import os
9+
import posixpath
10+
11+
from absl import logging
12+
13+
14+
class Data:
    """A reference to data that should be available on the remote pod.

    Wraps a local file/directory path or a GCS URI. When passed as a function
    argument or used in the ``volumes`` decorator parameter, Data is resolved
    to a plain filesystem path on the remote side. The user's function code
    never needs to know about Data — it just receives paths.

    Args:
        path: Local file/directory path (absolute or relative) or GCS URI
            (``gs://bucket/prefix``).

    Raises:
        ValueError: If ``path`` is empty.
        FileNotFoundError: If a local ``path`` does not exist.

    .. note::

        For GCS URIs, a trailing slash indicates a directory (prefix).
        ``Data("gs://my-bucket/dataset/")`` is treated as a directory,
        while ``Data("gs://my-bucket/dataset")`` is treated as a single
        object. If you intend to reference a GCS directory, always
        include the trailing slash.

    Examples::

        # Local directory
        Data("./my_dataset/")

        # Local file
        Data("./config.json")

        # GCS directory — trailing slash required
        Data("gs://my-bucket/datasets/imagenet/")

        # GCS single object
        Data("gs://my-bucket/datasets/weights.h5")
    """

    def __init__(self, path: str):
        if not path:
            raise ValueError("Data path must not be empty")
        # _raw_path must be assigned before touching self.is_gcs, which
        # reads it.
        self._raw_path = path
        if self.is_gcs:
            # GCS URIs are passed through as-is; only local paths resolve.
            self._resolved_path = path
            _warn_if_missing_trailing_slash(path)
        else:
            self._resolved_path = os.path.abspath(os.path.expanduser(path))
            if not os.path.exists(self._resolved_path):
                raise FileNotFoundError(
                    f"Data path does not exist: {path} "
                    f"(resolved to {self._resolved_path})"
                )

    @property
    def path(self) -> str:
        """Resolved path: absolute local path, or the GCS URI verbatim."""
        return self._resolved_path

    @property
    def is_gcs(self) -> bool:
        """True if this Data wraps a ``gs://`` URI rather than a local path."""
        return self._raw_path.startswith("gs://")

    @property
    def is_dir(self) -> bool:
        """True if this Data refers to a directory (GCS: trailing slash)."""
        if self.is_gcs:
            return self._raw_path.endswith("/")
        return os.path.isdir(self._resolved_path)

    @staticmethod
    def _hash_file_contents(h, fpath: str) -> None:
        """Stream ``fpath``'s bytes into hash object ``h`` in 64 KB chunks."""
        with open(fpath, "rb") as f:
            while chunk := f.read(65536):
                h.update(chunk)

    def content_hash(self) -> str:
        """SHA-256 hash of all file contents, sorted by relative path.

        Includes a type prefix ("dir:" or "file:") to prevent collisions
        between a single file and a directory containing only that file.
        For a single file, the basename is also hashed, so renaming the
        file changes the hash.

        Symlinked directories are not recursed into (followlinks=False)
        to prevent infinite recursion from circular symlinks. Symlinked
        files are read and their resolved contents are hashed, so the
        hash reflects the actual data visible at runtime.

        Returns:
            Hex digest string.

        Raises:
            ValueError: If this Data wraps a GCS URI (nothing local to hash).
        """
        if self.is_gcs:
            raise ValueError("Cannot compute content hash for GCS URI")

        h = hashlib.sha256()
        if os.path.isdir(self._resolved_path):
            h.update(b"dir:")
            for root, dirs, files in os.walk(self._resolved_path, followlinks=False):
                # Sort in place so os.walk visits subdirectories in a
                # deterministic order across filesystems.
                dirs.sort()
                for fname in sorted(files):
                    fpath = os.path.join(root, fname)
                    # NOTE(review): relpath uses os.sep, so directory hashes
                    # are not portable between POSIX and Windows — presumably
                    # fine since hashing happens on one client; confirm.
                    relpath = os.path.relpath(fpath, self._resolved_path)
                    h.update(relpath.encode("utf-8"))
                    h.update(b"\0")  # separator: path vs. content
                    self._hash_file_contents(h, fpath)
                    h.update(b"\0")  # separator: content vs. next entry
        else:
            h.update(b"file:")
            h.update(os.path.basename(self._resolved_path).encode("utf-8"))
            h.update(b"\0")
            self._hash_file_contents(h, self._resolved_path)
        return h.hexdigest()

    def __repr__(self):
        return f"Data({self._raw_path!r})"
123+
124+
125+
def _warn_if_missing_trailing_slash(path: str) -> None:
126+
"""Log a warning if a GCS path looks like a directory but has no trailing slash."""
127+
if path.endswith("/"):
128+
return
129+
gcs_path = path.split("//", 1)[1] # strip gs://
130+
last_segment = posixpath.basename(gcs_path)
131+
if last_segment and "." not in last_segment:
132+
logging.warning(
133+
"GCS path %r does not end with '/' but the last segment "
134+
"(%r) has no file extension. If this is a directory "
135+
"(prefix), add a trailing slash: %r",
136+
path,
137+
last_segment,
138+
path + "/",
139+
)
140+
141+
142+
def _make_data_ref(
143+
gcs_uri: str, is_dir: bool, mount_path: str | None = None
144+
) -> dict[str, object]:
145+
"""Create a serializable data reference dict.
146+
147+
These dicts replace Data objects in the payload before serialization.
148+
The remote runner identifies them by the __data_ref__ key.
149+
"""
150+
return {
151+
"__data_ref__": True,
152+
"gcs_uri": gcs_uri,
153+
"is_dir": is_dir,
154+
"mount_path": mount_path,
155+
}
156+
157+
158+
def is_data_ref(obj: object) -> bool:
    """Check if an object is a serialized data reference."""
    if not isinstance(obj, dict):
        return False
    # Identity check: the marker must be the literal True, not merely truthy.
    return obj.get("__data_ref__") is True

0 commit comments

Comments
 (0)