
Commit 3163230

Adds parallel data upload/download and remote-side deduplication
1 parent b57bf26 commit 3163230
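
This change routes GCS data transfers through google.cloud.storage's transfer_manager: _download_data in the remote runner and _upload_directory in the storage utilities now collect relative blob names and hand them to download_many_to_path / upload_many_from_filenames with thread workers, replacing the previous one-blob-at-a-time loops. On the remote side, resolve_data_refs memoizes resolved GCS URIs, so a data ref that appears several times across args and kwargs is downloaded once and every occurrence resolves to the same local path. Empty listings and empty directories become no-ops, and the tests patch transfer_manager instead of per-blob download/upload methods.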

File tree

4 files changed: +173 −68 lines

keras_remote/runner/remote_runner.py
keras_remote/runner/remote_runner_test.py
keras_remote/utils/storage.py
keras_remote/utils/storage_test.py


keras_remote/runner/remote_runner.py

Lines changed: 27 additions & 9 deletions
@@ -15,6 +15,7 @@
 import cloudpickle
 from absl import logging
 from google.cloud import storage
+from google.cloud.storage import transfer_manager

 # Base temp directory for remote execution artifacts
 TEMP_DIR = tempfile.gettempdir()
@@ -147,6 +148,7 @@ def resolve_volumes(volume_refs, storage_client):
 def resolve_data_refs(args, kwargs, storage_client):
     """Recursively resolve data ref dicts in args/kwargs to local paths."""
     counter = 0
+    resolved_uris = {}

     def _resolve(obj):
         nonlocal counter
@@ -155,14 +157,20 @@ def _resolve(obj):
             # Volume-mounted data refs are handled by Kubernetes, skip download
             if obj.get("mount_path") is not None:
                 return obj["mount_path"]
+            gcs_uri = obj["gcs_uri"]
+            if gcs_uri in resolved_uris:
+                return resolved_uris[gcs_uri]
             local_dir = os.path.join(DATA_DIR, str(counter))
             counter += 1
             _download_data(obj, local_dir, storage_client)
             # Return file path for single files, directory path otherwise
             if not obj["is_dir"]:
                 files = [f for f in os.listdir(local_dir) if f != ".cache_marker"]
                 if len(files) == 1:
-                    return os.path.join(local_dir, files[0])
+                    path = os.path.join(local_dir, files[0])
+                    resolved_uris[gcs_uri] = path
+                    return path
+            resolved_uris[gcs_uri] = local_dir
             return local_dir
         # Recurse into containers to find nested data refs
         if isinstance(obj, dict):
@@ -187,17 +195,27 @@ def _download_data(ref, target_dir, storage_client):
     bucket = storage_client.bucket(bucket_name)

     blobs = bucket.list_blobs(prefix=prefix + "/")
-    count = 0
+    blob_names = []
     for blob in blobs:
         if blob.name.endswith("/") or blob.name.endswith(".cache_marker"):
             continue
-        rel_path = blob.name[len(prefix) + 1 :]
-        local_path = os.path.join(target_dir, rel_path)
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        blob.download_to_filename(local_path)
-        count += 1
-
-    logging.info("Downloaded %d files from %s to %s", count, gcs_uri, target_dir)
+        blob_names.append(blob.name[len(prefix) + 1 :])
+
+    if not blob_names:
+        return
+
+    transfer_manager.download_many_to_path(
+        bucket,
+        blob_names,
+        destination_directory=target_dir,
+        blob_name_prefix=prefix + "/",
+        worker_type=transfer_manager.THREAD,
+        raise_exception=True,
+    )
+
+    logging.info(
+        "Downloaded %d files from %s to %s", len(blob_names), gcs_uri, target_dir
+    )


 def _download_from_gcs(client, gcs_path, local_path):
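
For context, the batched download introduced above can be exercised on its own. A minimal sketch, assuming an authenticated client and an existing bucket; the bucket name, prefix, and destination directory here are illustrative, not taken from the commit:

from google.cloud import storage
from google.cloud.storage import transfer_manager

client = storage.Client()
bucket = client.bucket("example-bucket")  # illustrative bucket name
prefix = "data-cache/abc123"              # illustrative cache prefix

# Collect blob names relative to the prefix, skipping directory
# placeholders and cache markers, mirroring _download_data above.
blob_names = [
    blob.name[len(prefix) + 1 :]
    for blob in bucket.list_blobs(prefix=prefix + "/")
    if not blob.name.endswith("/") and not blob.name.endswith(".cache_marker")
]

if blob_names:
    # Each file is fetched by a thread-pool worker; raise_exception=True
    # surfaces the first failure instead of returning per-file results.
    transfer_manager.download_many_to_path(
        bucket,
        blob_names,
        destination_directory="/tmp/data/0",
        blob_name_prefix=prefix + "/",
        worker_type=transfer_manager.THREAD,
        raise_exception=True,
    )

The library's default worker type is process-based; THREAD sidesteps the extra pickling requirements of process workers, which is presumably why the commit selects it for code that runs inside an already-initialized runner.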

keras_remote/runner/remote_runner_test.py

Lines changed: 87 additions & 30 deletions
@@ -81,6 +81,15 @@ def test_parses_gcs_path(self):


 class TestDownloadData(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.mock_download = self.enterContext(
+            mock.patch(
+                "keras_remote.runner.remote_runner.transfer_manager"
+                ".download_many_to_path",
+            )
+        )
+
     def test_downloads_files_skips_marker(self):
         tmp = _make_temp_path(self)
         target = tmp / "output"
@@ -91,9 +100,6 @@ def test_downloads_files_skips_marker(self):

         blob_data = MagicMock()
         blob_data.name = "prefix/hash/train.csv"
-        blob_data.download_to_filename = MagicMock(
-            side_effect=lambda p: pathlib.Path(p).write_text("train")
-        )

         blob_marker = MagicMock()
         blob_marker.name = "prefix/hash/.cache_marker"
@@ -115,9 +121,18 @@ def test_downloads_files_skips_marker(self):

         _download_data(ref, str(target), mock_client)

-        blob_data.download_to_filename.assert_called_once()
-        blob_marker.download_to_filename.assert_not_called()
-        blob_dir.download_to_filename.assert_not_called()
+        self.mock_download.assert_called_once()
+        blob_names = self.mock_download.call_args[0][1]
+        self.assertEqual(blob_names, ["train.csv"])
+        self.assertEqual(
+            self.mock_download.call_args.kwargs["destination_directory"],
+            str(target),
+        )
+        self.assertEqual(
+            self.mock_download.call_args.kwargs["blob_name_prefix"],
+            "prefix/hash/",
+        )
+        self.assertTrue(self.mock_download.call_args.kwargs["raise_exception"])

     def test_creates_subdirectories(self):
         tmp = _make_temp_path(self)
@@ -129,12 +144,6 @@ def test_creates_subdirectories(self):

         blob = MagicMock()
         blob.name = "prefix/hash/sub/deep.csv"
-        blob.download_to_filename = MagicMock(
-            side_effect=lambda p: (
-                pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)
-                or pathlib.Path(p).write_text("data")
-            )
-        )
         mock_bucket.list_blobs.return_value = [blob]

         ref = {
@@ -145,10 +154,27 @@ def test_creates_subdirectories(self):

         _download_data(ref, str(target), mock_client)

-        # The call should include the nested path
-        call_path = blob.download_to_filename.call_args[0][0]
-        self.assertIn("sub", call_path)
-        self.assertTrue(call_path.endswith("deep.csv"))
+        blob_names = self.mock_download.call_args[0][1]
+        self.assertEqual(blob_names, ["sub/deep.csv"])
+
+    def test_empty_listing_is_noop(self):
+        tmp = _make_temp_path(self)
+        target = tmp / "output"
+
+        mock_client = MagicMock()
+        mock_bucket = MagicMock()
+        mock_client.bucket.return_value = mock_bucket
+        mock_bucket.list_blobs.return_value = []
+
+        ref = {
+            "__data_ref__": True,
+            "gcs_uri": "gs://bucket/prefix/hash",
+            "is_dir": True,
+        }
+
+        _download_data(ref, str(target), mock_client)
+
+        self.mock_download.assert_not_called()


 class TestResolveDataRefs(absltest.TestCase):
@@ -200,16 +226,6 @@ def test_nested_refs_in_list(self):

     def test_single_file_returns_file_path(self):
         tmp = _make_temp_path(self)
-        mock_client = MagicMock()
-        mock_bucket = MagicMock()
-        mock_client.bucket.return_value = mock_bucket
-
-        blob = MagicMock()
-        blob.name = "prefix/hash/config.json"
-        blob.download_to_filename = MagicMock(
-            side_effect=lambda p: pathlib.Path(p).write_text("{}")
-        )
-        mock_bucket.list_blobs.return_value = [blob]

         ref = {
             "__data_ref__": True,
@@ -218,14 +234,55 @@ def test_single_file_returns_file_path(self):
             "mount_path": None,
         }

-        with mock.patch(
-            "keras_remote.runner.remote_runner.DATA_DIR",
-            str(tmp / "data"),
+        def fake_dl(ref, target_dir, client):
+            os.makedirs(target_dir, exist_ok=True)
+            pathlib.Path(os.path.join(target_dir, "config.json")).write_text("{}")
+
+        with (
+            mock.patch(
+                "keras_remote.runner.remote_runner.DATA_DIR",
+                str(tmp / "data"),
+            ),
+            mock.patch(
+                "keras_remote.runner.remote_runner._download_data",
+                side_effect=fake_dl,
+            ),
         ):
-            args, _ = resolve_data_refs((ref,), {}, mock_client)
+            args, _ = resolve_data_refs((ref,), {}, MagicMock())

         self.assertTrue(args[0].endswith("config.json"))

+    def test_duplicate_uri_downloaded_once(self):
+        tmp = _make_temp_path(self)
+
+        ref = {
+            "__data_ref__": True,
+            "gcs_uri": "gs://b/cache/hash",
+            "is_dir": True,
+            "mount_path": None,
+        }
+
+        def fake_dl(r, target_dir, client):
+            os.makedirs(target_dir, exist_ok=True)
+
+        with (
+            mock.patch(
+                "keras_remote.runner.remote_runner.DATA_DIR",
+                str(tmp / "data"),
+            ),
+            mock.patch(
+                "keras_remote.runner.remote_runner._download_data",
+                side_effect=fake_dl,
+            ) as mock_dl,
+        ):
+            args, kwargs = resolve_data_refs((ref, ref), {"d": ref}, MagicMock())
+
+        # Downloaded only once despite three references
+        mock_dl.assert_called_once()
+        # All resolved paths point to the same directory
+        self.assertEqual(args[0], args[1])
+        self.assertEqual(args[0], kwargs["d"])
+
     def test_non_ref_dict_preserved(self):
         mock_client = MagicMock()
         args, kwargs = resolve_data_refs(({"key": "value"},), {"x": 1}, mock_client)
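
The tests above never touch the network: they patch the transfer_manager entry point and assert on the captured call. A standalone sketch of the same pattern, assuming Python 3.11+ (for TestCase.enterContext) and the google-cloud-storage package; the test body is illustrative:

import unittest
from unittest import mock

from google.cloud.storage import transfer_manager


class TransferManagerPatchTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        # enterContext starts the patcher and registers its cleanup.
        self.mock_download = self.enterContext(
            mock.patch.object(transfer_manager, "download_many_to_path")
        )

    def test_captures_call(self):
        # Code under test would call transfer_manager here; we call it
        # directly to show what the assertions below inspect.
        transfer_manager.download_many_to_path(
            "bucket", ["a.csv"], destination_directory="/tmp/x"
        )
        self.mock_download.assert_called_once()
        # Positional args: (bucket, blob_names); kwargs hold the rest.
        self.assertEqual(self.mock_download.call_args[0][1], ["a.csv"])
        self.assertEqual(
            self.mock_download.call_args.kwargs["destination_directory"], "/tmp/x"
        )


if __name__ == "__main__":
    unittest.main()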

keras_remote/utils/storage.py

Lines changed: 17 additions & 2 deletions
@@ -7,6 +7,7 @@

 from absl import logging
 from google.cloud import storage
+from google.cloud.storage import transfer_manager

 from keras_remote.data import Data
 from keras_remote.infra.infra import get_default_project
@@ -205,9 +206,23 @@ def _upload_directory(
     bucket: storage.Bucket, local_dir: str, gcs_prefix: str
 ) -> None:
     """Upload a local directory to GCS preserving structure."""
+    filenames = []
     for root, _dirs, files in os.walk(local_dir):
         for fname in files:
             local_path = os.path.join(root, fname)
             rel_path = os.path.relpath(local_path, local_dir).replace(os.sep, "/")
-            blob = bucket.blob(f"{gcs_prefix}/{rel_path}")
-            blob.upload_from_filename(local_path)
+            filenames.append(rel_path)
+
+    if not filenames:
+        return
+
+    logging.info("Uploading %d files to GCS...", len(filenames))
+
+    transfer_manager.upload_many_from_filenames(
+        bucket,
+        filenames,
+        source_directory=local_dir,
+        blob_name_prefix=f"{gcs_prefix}/",
+        worker_type=transfer_manager.THREAD,
+        raise_exception=True,
+    )
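
As with the download path, the batched upload above can be sketched standalone, assuming an authenticated client; the bucket, prefix, and local directory below are illustrative:

import os

from google.cloud import storage
from google.cloud.storage import transfer_manager

client = storage.Client()
bucket = client.bucket("example-bucket")  # illustrative bucket name
local_dir = "/tmp/dataset"                # illustrative source directory

# Gather paths relative to local_dir, normalized to "/" separators so
# the resulting GCS object names are stable across platforms.
filenames = []
for root, _dirs, files in os.walk(local_dir):
    for fname in files:
        rel = os.path.relpath(os.path.join(root, fname), local_dir)
        filenames.append(rel.replace(os.sep, "/"))

if filenames:
    # One thread-pool worker per file; raise_exception=True fails fast
    # instead of returning a list of per-file results.
    transfer_manager.upload_many_from_filenames(
        bucket,
        filenames,
        source_directory=local_dir,
        blob_name_prefix="data-cache/abc123/",
        worker_type=transfer_manager.THREAD,
        raise_exception=True,
    )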

keras_remote/utils/storage_test.py

Lines changed: 42 additions & 27 deletions
@@ -202,33 +202,33 @@ def track_blob(name):
         self.assertIn(marker_name, blobs)
         blobs[marker_name].upload_from_string.assert_called_once_with("")

-    def test_cache_miss_uploads_directory(self):
+    @mock.patch(
+        "keras_remote.utils.storage.transfer_manager.upload_many_from_filenames",
+    )
+    def test_cache_miss_uploads_directory(self, mock_upload):
         tmp = _make_temp_path(self)
         d_dir = tmp / "dataset"
         d_dir.mkdir()
         (d_dir / "train.csv").write_text("train")
         (d_dir / "val.csv").write_text("val")
         d = Data(str(d_dir))
+        content_hash = d.content_hash()

         mock_bucket = self.mock_gcs.bucket.return_value
-        blobs = {}
-
-        def track_blob(name):
-            b = MagicMock()
-            blobs[name] = b
-            if name.endswith(".cache_marker"):
-                b.exists.return_value = False
-            return b
-
-        mock_bucket.blob.side_effect = track_blob
+        marker_blob = MagicMock()
+        marker_blob.exists.return_value = False
+        mock_bucket.blob.return_value = marker_blob

         result = upload_data("jobs-bucket", d, project="proj")

-        self.assertIn("gs://jobs-bucket/default/data-cache/", result)
-        # Both files + marker should have blobs
-        blob_names = list(blobs.keys())
-        csv_blobs = [n for n in blob_names if n.endswith(".csv")]
-        self.assertEqual(len(csv_blobs), 2)
+        expected_prefix = f"default/data-cache/{content_hash}"
+        self.assertEqual(result, f"gs://jobs-bucket/{expected_prefix}")
+        # Directory upload via transfer_manager
+        mock_upload.assert_called_once()
+        filenames = sorted(mock_upload.call_args[0][1])
+        self.assertEqual(filenames, ["train.csv", "val.csv"])
+        # Marker written after upload
+        marker_blob.upload_from_string.assert_called_once_with("")

     def test_custom_namespace(self):
         tmp = _make_temp_path(self)
@@ -273,7 +273,16 @@ def test_empty_directory(self):
         self.assertEqual(_compute_total_size(str(d)), 0)


-class TestUploadDirectory(_GcsTestBase):
+class TestUploadDirectory(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.mock_upload = self.enterContext(
+            mock.patch(
+                "keras_remote.utils.storage.transfer_manager"
+                ".upload_many_from_filenames",
+            )
+        )
+
     def test_preserves_structure(self):
         tmp = _make_temp_path(self)
         d = tmp / "dataset"
@@ -283,21 +292,27 @@ def test_preserves_structure(self):
         (sub / "b.csv").write_text("b")

         mock_bucket = MagicMock()
-        uploaded = {}

-        def track_blob(name):
-            b = MagicMock()
-            uploaded[name] = b
-            return b
+        _upload_directory(mock_bucket, str(d), "prefix/hash")

-        mock_bucket.blob.side_effect = track_blob
+        self.mock_upload.assert_called_once()
+        call_kwargs = self.mock_upload.call_args
+        filenames = sorted(call_kwargs[0][1])  # second positional arg
+        self.assertEqual(filenames, ["a.csv", "sub/b.csv"])
+        self.assertEqual(call_kwargs.kwargs["source_directory"], str(d))
+        self.assertEqual(call_kwargs.kwargs["blob_name_prefix"], "prefix/hash/")
+        self.assertTrue(call_kwargs.kwargs["raise_exception"])
+
+    def test_empty_directory_is_noop(self):
+        tmp = _make_temp_path(self)
+        d = tmp / "empty_dataset"
+        d.mkdir()
+
+        mock_bucket = MagicMock()

         _upload_directory(mock_bucket, str(d), "prefix/hash")

-        self.assertIn("prefix/hash/a.csv", uploaded)
-        self.assertIn("prefix/hash/sub/b.csv", uploaded)
-        for blob in uploaded.values():
-            blob.upload_from_filename.assert_called_once()
+        self.mock_upload.assert_not_called()


 if __name__ == "__main__":
