Skip to content

Commit b8407e3

Browse files
Add Data class and content-addressed storage upload
Introduces the Data class for declaring data dependencies (local paths or GCS URIs) and upload_data() for content-hash-based caching in GCS.
1 parent 4300d82 commit b8407e3

File tree

5 files changed

+587
-0
lines changed

5 files changed

+587
-0
lines changed

keras_remote/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "0")
77

88
from keras_remote.core.core import run as run
9+
from keras_remote.data import Data as Data

keras_remote/data.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
"""Data class for declaring data dependencies in remote functions.

Wraps local file/directory paths or GCS URIs. On the remote side, Data
resolves to a plain filesystem path — the user's function only sees paths.
"""
6+
7+
import hashlib
8+
import os
9+
10+
11+
class Data:
  """A reference to data that should be available on the remote pod.

  Wraps a local file/directory path or a GCS URI. When passed as a function
  argument or used in the ``volumes`` decorator parameter, Data is resolved
  to a plain filesystem path on the remote side. The user's function code
  never needs to know about Data — it just receives paths.

  Args:
    path: Local file/directory path (absolute or relative) or GCS URI
      (``gs://bucket/prefix``). ``str``, ``bytes`` and ``os.PathLike``
      values (for example ``pathlib.Path``) are accepted.

  Raises:
    FileNotFoundError: If ``path`` is local and does not exist.

  Examples::

    # Local directory
    Data("./my_dataset/")

    # Local file
    Data("./config.json")

    # GCS URI
    Data("gs://my-bucket/datasets/imagenet/")
  """

  # Number of bytes read per call while streaming a file into the hash.
  _CHUNK_SIZE = 65536  # 64 KB

  def __init__(self, path: "str | bytes | os.PathLike"):
    # Normalise to ``str`` up front so that path-like objects work with the
    # ``startswith`` check in ``is_gcs`` and with ``__repr__``.
    path = os.fsdecode(path)
    self._raw_path = path
    if self.is_gcs:
      # GCS URIs are kept verbatim; no existence check is performed here.
      self._resolved_path = path
    else:
      self._resolved_path = os.path.abspath(os.path.expanduser(path))
      if not os.path.exists(self._resolved_path):
        raise FileNotFoundError(
          f"Data path does not exist: {path} "
          f"(resolved to {self._resolved_path})"
        )

  @property
  def path(self) -> str:
    """The GCS URI as given, or the absolute, user-expanded local path."""
    return self._resolved_path

  @property
  def is_gcs(self) -> bool:
    """Whether the path given at construction starts with ``gs://``."""
    return self._raw_path.startswith("gs://")

  @property
  def is_dir(self) -> bool:
    """Whether this refers to a directory.

    For GCS URIs this is decided purely by a trailing ``/``; for local
    paths the filesystem is consulted.
    """
    if self.is_gcs:
      return self._raw_path.endswith("/")
    return os.path.isdir(self._resolved_path)

  @staticmethod
  def _update_with_file(h, name: str, fpath: str) -> None:
    """Feed one ``(name, contents)`` record of a file into ``h``.

    The name is terminated by a NUL byte and the contents are folded in as
    a fixed-size SHA-256 digest. Without this framing the name and content
    bytes run together, so e.g. a file ``a`` containing ``bc`` and a file
    ``ab`` containing ``c`` would hash identically.
    """
    # Use "/" regardless of platform so the same tree hashes identically
    # everywhere; "surrogateescape" keeps undecodable file names hashable.
    name_bytes = name.replace(os.sep, "/").encode("utf-8", "surrogateescape")
    h.update(name_bytes)
    h.update(b"\0")
    file_digest = hashlib.sha256()
    with open(fpath, "rb") as f:
      for chunk in iter(lambda: f.read(Data._CHUNK_SIZE), b""):
        file_digest.update(chunk)
    h.update(file_digest.digest())

  def content_hash(self) -> str:
    """SHA-256 hash of all file names and contents, in sorted order.

    Includes a type prefix ("dir:" or "file:") to prevent collisions
    between a single file and a directory containing only that file. Each
    file contributes its relative path (or base name for a single file),
    a NUL separator, and the SHA-256 digest of its contents.

    Symlinked directories are not descended into (``followlinks=False``),
    which keeps hashing deterministic and avoids infinite recursion on
    circular links. Symlinks to files are read through like regular files.
    Users with symlinked data directories should pass the resolved target
    path.

    Returns:
      The hex digest as a 64-character string.

    Raises:
      ValueError: If this Data refers to a GCS URI.
    """
    if self.is_gcs:
      raise ValueError("Cannot compute content hash for GCS URI")

    h = hashlib.sha256()
    base = self._resolved_path
    if os.path.isdir(base):
      h.update(b"dir:")
      for root, dirs, files in os.walk(base, followlinks=False):
        # Sorting in place fixes the order in which os.walk descends.
        dirs.sort()
        for fname in sorted(files):
          fpath = os.path.join(root, fname)
          self._update_with_file(h, os.path.relpath(fpath, base), fpath)
    else:
      h.update(b"file:")
      self._update_with_file(h, os.path.basename(base), base)
    return h.hexdigest()

  def __repr__(self):
    return f"Data({self._raw_path!r})"
102+
103+
104+
def _make_data_ref(gcs_uri, is_dir, mount_path=None):
105+
"""Create a serializable data reference dict.
106+
107+
These dicts replace Data objects in the payload before serialization.
108+
The remote runner identifies them by the __data_ref__ key.
109+
"""
110+
return {
111+
"__data_ref__": True,
112+
"gcs_uri": gcs_uri,
113+
"is_dir": is_dir,
114+
"mount_path": mount_path,
115+
}
116+
117+
118+
def is_data_ref(obj) -> bool:
  """Check if an object is a serialized data reference.

  Returns ``True`` only for ``dict`` instances whose ``"__data_ref__"``
  entry is exactly ``True``; other truthy marker values (``1``, ``"yes"``)
  and non-dict objects yield ``False``.
  """
  return isinstance(obj, dict) and obj.get("__data_ref__") is True

keras_remote/data_test.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
"""Tests for keras_remote.data — Data class and helpers."""
2+
3+
import os
4+
import pathlib
5+
import tempfile
6+
7+
from absl.testing import absltest
8+
9+
from keras_remote.data import Data, _make_data_ref, is_data_ref
10+
11+
12+
def _make_temp_path(test_case):
13+
"""Create a temp directory that is cleaned up after the test."""
14+
td = tempfile.TemporaryDirectory()
15+
test_case.addCleanup(td.cleanup)
16+
return pathlib.Path(td.name)
17+
18+
19+
class TestDataConstructor(absltest.TestCase):
  """Construction of Data from local paths and GCS URIs."""

  def test_local_file(self):
    tmp = _make_temp_path(self)
    f = tmp / "data.csv"
    f.write_text("a,b\n1,2\n")
    d = Data(str(f))

    self.assertEqual(d.path, str(f))
    self.assertFalse(d.is_gcs)
    self.assertFalse(d.is_dir)

  def test_local_directory(self):
    tmp = _make_temp_path(self)
    d_dir = tmp / "dataset"
    d_dir.mkdir()
    (d_dir / "train.csv").write_text("data")
    d = Data(str(d_dir))

    self.assertEqual(d.path, str(d_dir))
    self.assertFalse(d.is_gcs)
    self.assertTrue(d.is_dir)

  def test_gcs_uri_directory(self):
    d = Data("gs://my-bucket/data/")
    self.assertEqual(d.path, "gs://my-bucket/data/")
    self.assertTrue(d.is_gcs)
    self.assertTrue(d.is_dir)

  def test_gcs_uri_file(self):
    d = Data("gs://my-bucket/data/file.csv")
    self.assertTrue(d.is_gcs)
    self.assertFalse(d.is_dir)

  def test_nonexistent_path_raises(self):
    with self.assertRaises(FileNotFoundError) as cm:
      Data("/nonexistent/path/to/data")
    self.assertIn("/nonexistent/path/to/data", str(cm.exception))

  def test_relative_path_resolved(self):
    tmp = _make_temp_path(self)
    f = tmp / "file.txt"
    f.write_text("content")
    # Work from inside the temp dir so a genuinely relative path is passed.
    # This cleanup is registered after the temp-dir removal, so (cleanups
    # run last-in, first-out) the cwd is restored before the dir is deleted.
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir(tmp)

    d = Data("file.txt")
    self.assertTrue(os.path.isabs(d.path))
    # realpath on both sides: the cwd may be reported via a resolved
    # symlink (e.g. a symlinked temp root).
    self.assertEqual(os.path.realpath(d.path), os.path.realpath(str(f)))

  def test_repr(self):
    d = Data("gs://bucket/path/")
    self.assertEqual(repr(d), "Data('gs://bucket/path/')")
68+
69+
70+
class TestContentHash(absltest.TestCase):
  """Behaviour of Data.content_hash for files and directories."""

  def test_deterministic_file_hash(self):
    tmp = _make_temp_path(self)
    f = tmp / "data.csv"
    f.write_text("hello,world\n")

    d1 = Data(str(f))
    d2 = Data(str(f))
    self.assertEqual(d1.content_hash(), d2.content_hash())

  def test_different_content_different_hash(self):
    tmp = _make_temp_path(self)
    f1 = tmp / "a.csv"
    f1.write_text("content_a")
    f2 = tmp / "b.csv"
    f2.write_text("content_b")

    self.assertNotEqual(
      Data(str(f1)).content_hash(), Data(str(f2)).content_hash()
    )

  def test_deterministic_dir_hash(self):
    tmp = _make_temp_path(self)
    d = tmp / "dataset"
    d.mkdir()
    (d / "train.csv").write_text("train data")
    (d / "val.csv").write_text("val data")

    d1 = Data(str(d))
    d2 = Data(str(d))
    self.assertEqual(d1.content_hash(), d2.content_hash())

  def test_dir_content_change_changes_hash(self):
    tmp = _make_temp_path(self)
    d = tmp / "dataset"
    d.mkdir()
    (d / "train.csv").write_text("original")

    h1 = Data(str(d)).content_hash()
    (d / "train.csv").write_text("modified")
    h2 = Data(str(d)).content_hash()

    self.assertNotEqual(h1, h2)

  def test_file_vs_dir_different_hash(self):
    """A single file and a directory containing only that file should
    produce different hashes due to the type prefix."""
    tmp = _make_temp_path(self)

    # Single file
    f = tmp / "file.csv"
    f.write_text("same content")

    # Directory containing only that file
    d = tmp / "dir"
    d.mkdir()
    (d / "file.csv").write_text("same content")

    file_hash = Data(str(f)).content_hash()
    dir_hash = Data(str(d)).content_hash()
    self.assertNotEqual(file_hash, dir_hash)

  def test_empty_directory(self):
    tmp = _make_temp_path(self)
    d = tmp / "empty"
    d.mkdir()

    # Should not raise, should return a valid hash
    h = Data(str(d)).content_hash()
    self.assertIsInstance(h, str)
    self.assertEqual(len(h), 64)  # SHA-256 hex digest

  def test_gcs_uri_raises(self):
    d = Data("gs://bucket/data/")
    with self.assertRaises(ValueError):
      d.content_hash()

  def test_nested_directory_hash(self):
    tmp = _make_temp_path(self)
    d = tmp / "nested"
    sub = d / "sub"
    sub.mkdir(parents=True)
    (d / "a.txt").write_text("a")
    (sub / "b.txt").write_text("b")

    h = Data(str(d)).content_hash()
    self.assertIsInstance(h, str)
    self.assertEqual(len(h), 64)

    # Files below the top level must contribute to the hash, too.
    (sub / "b.txt").write_text("changed")
    self.assertNotEqual(h, Data(str(d)).content_hash())

  def test_path_included_in_hash(self):
    """Files with same content but different names produce different
    hashes."""
    tmp = _make_temp_path(self)
    d1 = tmp / "dir1"
    d1.mkdir()
    (d1 / "alpha.csv").write_text("same")

    d2 = tmp / "dir2"
    d2.mkdir()
    (d2 / "beta.csv").write_text("same")

    self.assertNotEqual(
      Data(str(d1)).content_hash(), Data(str(d2)).content_hash()
    )
174+
175+
176+
class TestMakeDataRef(absltest.TestCase):
  """Shape of the dict produced by _make_data_ref."""

  def test_basic_ref(self):
    ref = _make_data_ref("gs://b/prefix", True)
    self.assertTrue(ref["__data_ref__"])
    self.assertEqual(ref["gcs_uri"], "gs://b/prefix")
    self.assertTrue(ref["is_dir"])
    self.assertIsNone(ref["mount_path"])

  def test_with_mount_path(self):
    ref = _make_data_ref("gs://b/p", False, mount_path="/data")
    self.assertEqual(ref["mount_path"], "/data")
    self.assertFalse(ref["is_dir"])
188+
189+
190+
class TestIsDataRef(absltest.TestCase):
  """Recognition of serialized data reference dicts."""

  def test_valid_ref(self):
    ref = {"__data_ref__": True, "gcs_uri": "gs://b/p", "is_dir": True}
    self.assertTrue(is_data_ref(ref))

  def test_plain_dict(self):
    self.assertFalse(is_data_ref({"key": "value"}))

  def test_truthy_marker_is_not_ref(self):
    # The marker must be exactly True, not merely truthy.
    self.assertFalse(is_data_ref({"__data_ref__": 1}))
    self.assertFalse(is_data_ref({"__data_ref__": "yes"}))

  def test_non_dict(self):
    self.assertFalse(is_data_ref("string"))
    self.assertFalse(is_data_ref(42))
    self.assertFalse(is_data_ref(None))
202+
203+
204+
if __name__ == "__main__":
205+
absltest.main()

0 commit comments

Comments
 (0)