Skip to content

Commit 8d91b67

Browse files
Address reviews
1 parent c11bf88 commit 8d91b67

File tree

5 files changed

+100
-23
lines changed

5 files changed

+100
-23
lines changed

keras_remote/backend/pathways_client.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import time
44

5+
from absl import logging
56
from kubernetes import client
67
from kubernetes.client.rest import ApiException
78

@@ -13,9 +14,6 @@
1314
)
1415
from keras_remote.backend.log_streaming import LogStreamer
1516
from keras_remote.core import accelerators
16-
from keras_remote.infra import infra
17-
18-
logger = infra.logger
1917

2018
LWS_GROUP = "leaderworkerset.x-k8s.io"
2119
LWS_VERSION = "v1"
@@ -40,7 +38,7 @@ def _get_lws_version(group=LWS_GROUP):
4038
# If we didn't find the group, raise ApiException to fallback
4139
raise ApiException(status=404, reason=f"API group {group} not found")
4240
except ApiException:
43-
logger.warning(
41+
logging.warning(
4442
"Failed to retrieve LWS API version from cluster. Defaulting to '%s'",
4543
LWS_VERSION,
4644
)
@@ -108,8 +106,8 @@ def submit_pathways_job(
108106
plural=LWS_PLURAL,
109107
body=lws_manifest,
110108
)
111-
logger.info(f"Submitted Pathways job (LWS): {job_name}")
112-
logger.info(
109+
logging.info(f"Submitted Pathways job (LWS): {job_name}")
110+
logging.info(
113111
"View job with: kubectl get %s %s -n %s", LWS_PLURAL, job_name, namespace
114112
)
115113
return created_lws
@@ -150,11 +148,11 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
150148
try:
151149
pod = core_v1.read_namespaced_pod(leader_pod_name, namespace)
152150
if not logged_running:
153-
logger.info(f"Found pod: {leader_pod_name}")
151+
logging.info(f"Found pod: {leader_pod_name}")
154152
logged_running = True
155153

156154
if pod.status.phase == "Succeeded":
157-
logger.info(f"[REMOTE] Job {job_name} completed successfully")
155+
logging.info(f"[REMOTE] Job {job_name} completed successfully")
158156
return "success"
159157

160158
if pod.status.phase == "Failed":
@@ -163,7 +161,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
163161

164162
elif pod.status.phase == "Pending":
165163
_check_pod_scheduling(core_v1, job_name, namespace)
166-
logger.debug("Pod is Pending...")
164+
logging.debug("Pod is Pending...")
167165

168166
elif pod.status.phase == "Running":
169167
streamer.start(leader_pod_name)
@@ -183,7 +181,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
183181
# Check current state
184182
if container_status.state.terminated:
185183
if container_status.state.terminated.exit_code == 0:
186-
logger.info(f"[REMOTE] Job {job_name} completed successfully")
184+
logging.info(f"[REMOTE] Job {job_name} completed successfully")
187185
return "success"
188186
else:
189187
_print_pod_logs(core_v1, job_name, namespace)
@@ -195,7 +193,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
195193
# Check last state (in case it restarted)
196194
if container_status.last_state.terminated:
197195
if container_status.last_state.terminated.exit_code == 0:
198-
logger.info(
196+
logging.info(
199197
f"[REMOTE] Job {job_name} completed successfully (restarted)"
200198
)
201199
return "success"
@@ -223,13 +221,13 @@ def cleanup_job(job_name, namespace="default"):
223221
plural=LWS_PLURAL,
224222
name=job_name,
225223
)
226-
logger.info(f"Deleted LeaderWorkerSet: {job_name}")
224+
logging.info(f"Deleted LeaderWorkerSet: {job_name}")
227225
except ApiException as e:
228226
if e.status == 404:
229227
# Job already deleted
230228
pass
231229
else:
232-
logger.warning(
230+
logging.warning(
233231
"Failed to delete LeaderWorkerSet %s: %s",
234232
job_name,
235233
e.reason,

keras_remote/data.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
import hashlib
88
import os
9+
import posixpath
10+
11+
from absl import logging
912

1013

1114
class Data:
@@ -20,6 +23,14 @@ class Data:
2023
path: Local file/directory path (absolute or relative) or GCS URI
2124
(``gs://bucket/prefix``).
2225
26+
.. note::
27+
28+
For GCS URIs, a trailing slash indicates a directory (prefix).
29+
``Data("gs://my-bucket/dataset/")`` is treated as a directory,
30+
while ``Data("gs://my-bucket/dataset")`` is treated as a single
31+
object. If you intend to reference a GCS directory, always
32+
include the trailing slash.
33+
2334
Examples::
2435
2536
# Local directory
@@ -28,14 +39,20 @@ class Data:
2839
# Local file
2940
Data("./config.json")
3041
31-
# GCS URI
42+
# GCS directory — trailing slash required
3243
Data("gs://my-bucket/datasets/imagenet/")
44+
45+
# GCS single object
46+
Data("gs://my-bucket/datasets/weights.h5")
3347
"""
3448

3549
def __init__(self, path: str):
50+
if not path:
51+
raise ValueError("Data path must not be empty")
3652
self._raw_path = path
3753
if self.is_gcs:
3854
self._resolved_path = path
55+
_warn_if_missing_trailing_slash(path)
3956
else:
4057
self._resolved_path = os.path.abspath(os.path.expanduser(path))
4158
if not os.path.exists(self._resolved_path):
@@ -63,10 +80,11 @@ def content_hash(self) -> str:
6380
6481
Includes a type prefix ("dir:" or "file:") to prevent collisions
6582
between a single file and a directory containing only that file.
66-
Symlinks are not followed (followlinks=False) to ensure
67-
deterministic hashing and prevent circular symlink infinite
68-
recursion. Users with symlinked data should pass the resolved
69-
target path.
83+
84+
Symlinked directories are not recursed into (followlinks=False)
85+
to prevent infinite recursion from circular symlinks. Symlinked
86+
files are read and their resolved contents are hashed, so the
87+
hash reflects the actual data visible at runtime.
7088
"""
7189
if self.is_gcs:
7290
raise ValueError("Cannot compute content hash for GCS URI")
@@ -80,15 +98,18 @@ def content_hash(self) -> str:
8098
fpath = os.path.join(root, fname)
8199
relpath = os.path.relpath(fpath, self._resolved_path)
82100
h.update(relpath.encode("utf-8"))
101+
h.update(b"\0")
83102
with open(fpath, "rb") as f:
84103
while True:
85104
chunk = f.read(65536) # 64 KB chunks
86105
if not chunk:
87106
break
88107
h.update(chunk)
108+
h.update(b"\0")
89109
else:
90110
h.update(b"file:")
91111
h.update(os.path.basename(self._resolved_path).encode("utf-8"))
112+
h.update(b"\0")
92113
with open(self._resolved_path, "rb") as f:
93114
while True:
94115
chunk = f.read(65536)
@@ -101,6 +122,23 @@ def __repr__(self):
101122
return f"Data({self._raw_path!r})"
102123

103124

125+
def _warn_if_missing_trailing_slash(path: str) -> None:
126+
"""Log a warning if a GCS path looks like a directory but has no trailing slash."""
127+
if path.endswith("/"):
128+
return
129+
gcs_path = path.split("//", 1)[1] # strip gs://
130+
last_segment = posixpath.basename(gcs_path)
131+
if last_segment and "." not in last_segment:
132+
logging.warning(
133+
"GCS path %r does not end with '/' but the last segment "
134+
"(%r) has no file extension. If this is a directory "
135+
"(prefix), add a trailing slash: %r",
136+
path,
137+
last_segment,
138+
path + "/",
139+
)
140+
141+
104142
def _make_data_ref(
105143
gcs_uri: str, is_dir: bool, mount_path: str | None = None
106144
) -> dict[str, object]:

keras_remote/data_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def test_gcs_uri_file(self):
4949
self.assertTrue(d.is_gcs)
5050
self.assertFalse(d.is_dir)
5151

52+
def test_empty_path_raises(self):
    """Constructing Data with an empty path string is rejected."""
    self.assertRaises(ValueError, Data, "")
55+
5256
def test_nonexistent_path_raises(self):
5357
with self.assertRaises(FileNotFoundError) as cm:
5458
Data("/nonexistent/path/to/data")
@@ -156,6 +160,46 @@ def test_nested_directory_hash(self):
156160
self.assertIsInstance(h, str)
157161
self.assertEqual(len(h), 64)
158162

163+
def test_filename_content_boundary(self):
    """Filename/content collisions must produce different hashes.

    Without a delimiter, file "a" with content "bc" and file "ab" with
    content "c" would both hash the byte sequence "abc".
    """
    root = _make_temp_path(self)
    digests = []
    # Two directories whose (name + content) byte concatenations collide.
    for dirname, fname, content in (("dir1", "a", "bc"), ("dir2", "ab", "c")):
        subdir = root / dirname
        subdir.mkdir()
        (subdir / fname).write_text(content)
        digests.append(Data(str(subdir)).content_hash())
    self.assertNotEqual(digests[0], digests[1])
181+
182+
def test_file_boundary_across_entries(self):
    """Consecutive file entries must not collide.

    Without a delimiter between entries, two files ["x" -> "y", "z" -> ""]
    and ["x" -> "", "yz" -> ""] would produce the same hash input.
    """
    root = _make_temp_path(self)

    first = root / "dir1"
    first.mkdir()
    for fname, content in (("x", "y"), ("z", "")):
        (first / fname).write_text(content)

    second = root / "dir2"
    second.mkdir()
    for fname, content in (("x", ""), ("yz", "")):
        (second / fname).write_text(content)

    hash_first = Data(str(first)).content_hash()
    hash_second = Data(str(second)).content_hash()
    self.assertNotEqual(hash_first, hash_second)
202+
159203
def test_path_included_in_hash(self):
160204
"""Files with same content but different names produce different
161205
hashes."""

keras_remote/infra/infra.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import logging
21
import os
32

4-
logging.basicConfig(level=logging.INFO, format="%(message)s")
5-
logger = logging.getLogger("keras_remote")
6-
73

84
def get_default_project() -> str | None:
95
"""Get project ID from KERAS_REMOTE_PROJECT or GOOGLE_CLOUD_PROJECT."""

keras_remote/utils/storage.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def upload_data(
139139
return data.path
140140

141141
content_hash = data.content_hash()
142+
namespace_prefix = namespace_prefix.strip("/")
142143
cache_prefix = f"{namespace_prefix}/data-cache/{content_hash}"
143144

144145
project = project or get_default_project()
@@ -207,6 +208,6 @@ def _upload_directory(
207208
for root, _dirs, files in os.walk(local_dir):
208209
for fname in files:
209210
local_path = os.path.join(root, fname)
210-
rel_path = os.path.relpath(local_path, local_dir)
211+
rel_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
211212
blob = bucket.blob(f"{gcs_prefix}/{rel_path}")
212213
blob.upload_from_filename(local_path)

0 commit comments

Comments
 (0)