Skip to content

Commit 5d2e304

Browse files
Parallelize GCS data transfers and deduplicate downloads (#82)
1 parent 05c1811 commit 5d2e304

File tree

4 files changed

+217
-67
lines changed

4 files changed

+217
-67
lines changed

keras_remote/runner/remote_runner.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
import cloudpickle
1616
from absl import logging
1717
from google.cloud import storage
18+
from google.cloud.storage import transfer_manager
19+
20+
_DOWNLOAD_BATCH_SIZE = 10000
1821

1922
# Base temp directory for remote execution artifacts
2023
TEMP_DIR = tempfile.gettempdir()
@@ -147,6 +150,7 @@ def resolve_volumes(volume_refs, storage_client):
147150
def resolve_data_refs(args, kwargs, storage_client):
148151
"""Recursively resolve data ref dicts in args/kwargs to local paths."""
149152
counter = 0
153+
resolved_uris = {}
150154

151155
def _resolve(obj):
152156
nonlocal counter
@@ -155,14 +159,20 @@ def _resolve(obj):
155159
# Volume-mounted data refs are handled by Kubernetes, skip download
156160
if obj.get("mount_path") is not None:
157161
return obj["mount_path"]
162+
gcs_uri = obj["gcs_uri"]
163+
if gcs_uri in resolved_uris:
164+
return resolved_uris[gcs_uri]
158165
local_dir = os.path.join(DATA_DIR, str(counter))
159166
counter += 1
160167
_download_data(obj, local_dir, storage_client)
161168
# Return file path for single files, directory path otherwise
162169
if not obj["is_dir"]:
163170
files = [f for f in os.listdir(local_dir) if f != ".cache_marker"]
164171
if len(files) == 1:
165-
return os.path.join(local_dir, files[0])
172+
path = os.path.join(local_dir, files[0])
173+
resolved_uris[gcs_uri] = path
174+
return path
175+
resolved_uris[gcs_uri] = local_dir
166176
return local_dir
167177
# Recurse into containers to find nested data refs
168178
if isinstance(obj, dict):
@@ -187,17 +197,39 @@ def _download_data(ref, target_dir, storage_client):
187197
bucket = storage_client.bucket(bucket_name)
188198

189199
blobs = bucket.list_blobs(prefix=prefix + "/")
190-
count = 0
200+
total_downloaded = 0
201+
batch = []
191202
for blob in blobs:
192203
if blob.name.endswith("/") or blob.name.endswith(".cache_marker"):
193204
continue
194-
rel_path = blob.name[len(prefix) + 1 :]
195-
local_path = os.path.join(target_dir, rel_path)
196-
os.makedirs(os.path.dirname(local_path), exist_ok=True)
197-
blob.download_to_filename(local_path)
198-
count += 1
205+
batch.append(blob.name[len(prefix) + 1 :])
206+
if len(batch) >= _DOWNLOAD_BATCH_SIZE:
207+
transfer_manager.download_many_to_path(
208+
bucket,
209+
batch,
210+
destination_directory=target_dir,
211+
blob_name_prefix=prefix + "/",
212+
worker_type=transfer_manager.THREAD,
213+
raise_exception=True,
214+
)
215+
total_downloaded += len(batch)
216+
batch = []
217+
218+
if batch:
219+
transfer_manager.download_many_to_path(
220+
bucket,
221+
batch,
222+
destination_directory=target_dir,
223+
blob_name_prefix=prefix + "/",
224+
worker_type=transfer_manager.THREAD,
225+
raise_exception=True,
226+
)
227+
total_downloaded += len(batch)
199228

200-
logging.info("Downloaded %d files from %s to %s", count, gcs_uri, target_dir)
229+
if total_downloaded:
230+
logging.info(
231+
"Downloaded %d files from %s to %s", total_downloaded, gcs_uri, target_dir
232+
)
201233

202234

203235
def _download_from_gcs(client, gcs_path, local_path):

keras_remote/runner/remote_runner_test.py

Lines changed: 118 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from absl.testing import absltest
1414

1515
from keras_remote.runner.remote_runner import (
16+
_DOWNLOAD_BATCH_SIZE,
1617
_download_data,
1718
_download_from_gcs,
1819
_upload_to_gcs,
@@ -81,6 +82,15 @@ def test_parses_gcs_path(self):
8182

8283

8384
class TestDownloadData(absltest.TestCase):
85+
def setUp(self):
    """Replace the runner's bulk-download call with a mock for each test.

    The patch is registered through ``enterContext`` and the resulting mock
    is kept on ``self.mock_download`` so tests can inspect the calls made
    by ``_download_data``.
    """
    super().setUp()
    patch_target = (
        "keras_remote.runner.remote_runner.transfer_manager"
        ".download_many_to_path"
    )
    self.mock_download = self.enterContext(mock.patch(patch_target))
93+
8494
def test_downloads_files_skips_marker(self):
8595
tmp = _make_temp_path(self)
8696
target = tmp / "output"
@@ -91,9 +101,6 @@ def test_downloads_files_skips_marker(self):
91101

92102
blob_data = MagicMock()
93103
blob_data.name = "prefix/hash/train.csv"
94-
blob_data.download_to_filename = MagicMock(
95-
side_effect=lambda p: pathlib.Path(p).write_text("train")
96-
)
97104

98105
blob_marker = MagicMock()
99106
blob_marker.name = "prefix/hash/.cache_marker"
@@ -115,9 +122,18 @@ def test_downloads_files_skips_marker(self):
115122

116123
_download_data(ref, str(target), mock_client)
117124

118-
blob_data.download_to_filename.assert_called_once()
119-
blob_marker.download_to_filename.assert_not_called()
120-
blob_dir.download_to_filename.assert_not_called()
125+
self.mock_download.assert_called_once()
126+
blob_names = self.mock_download.call_args[0][1]
127+
self.assertEqual(blob_names, ["train.csv"])
128+
self.assertEqual(
129+
self.mock_download.call_args.kwargs["destination_directory"],
130+
str(target),
131+
)
132+
self.assertEqual(
133+
self.mock_download.call_args.kwargs["blob_name_prefix"],
134+
"prefix/hash/",
135+
)
136+
self.assertTrue(self.mock_download.call_args.kwargs["raise_exception"])
121137

122138
def test_creates_subdirectories(self):
123139
tmp = _make_temp_path(self)
@@ -129,12 +145,6 @@ def test_creates_subdirectories(self):
129145

130146
blob = MagicMock()
131147
blob.name = "prefix/hash/sub/deep.csv"
132-
blob.download_to_filename = MagicMock(
133-
side_effect=lambda p: (
134-
pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)
135-
or pathlib.Path(p).write_text("data")
136-
)
137-
)
138148
mock_bucket.list_blobs.return_value = [blob]
139149

140150
ref = {
@@ -145,10 +155,57 @@ def test_creates_subdirectories(self):
145155

146156
_download_data(ref, str(target), mock_client)
147157

148-
# The call should include the nested path
149-
call_path = blob.download_to_filename.call_args[0][0]
150-
self.assertIn("sub", call_path)
151-
self.assertTrue(call_path.endswith("deep.csv"))
158+
blob_names = self.mock_download.call_args[0][1]
159+
self.assertEqual(blob_names, ["sub/deep.csv"])
160+
161+
def test_large_listing_downloads_in_batches(self):
    """A listing larger than one batch is split into two download calls."""
    destination = _make_temp_path(self) / "output"

    def make_blob(index):
        # ``name`` must be assigned after construction: passing ``name=`` to
        # MagicMock would name the mock itself instead of setting the attribute.
        fake_blob = MagicMock()
        fake_blob.name = f"prefix/hash/file_{index}.csv"
        return fake_blob

    # Five more blobs than fit in a single batch forces a second, partial batch.
    listing = [make_blob(i) for i in range(_DOWNLOAD_BATCH_SIZE + 5)]

    bucket_mock = MagicMock()
    bucket_mock.list_blobs.return_value = listing
    client_mock = MagicMock()
    client_mock.bucket.return_value = bucket_mock

    data_ref = {
        "__data_ref__": True,
        "gcs_uri": "gs://bucket/prefix/hash",
        "is_dir": True,
    }

    _download_data(data_ref, str(destination), client_mock)

    self.assertEqual(self.mock_download.call_count, 2)
    recorded_calls = self.mock_download.call_args_list
    # Index [0] selects the positional args; the blob-name list is the second one.
    full_batch = recorded_calls[0][0][1]
    remainder_batch = recorded_calls[1][0][1]
    self.assertEqual(len(full_batch), _DOWNLOAD_BATCH_SIZE)
    self.assertEqual(len(remainder_batch), 5)
190+
191+
def test_empty_listing_is_noop(self):
    """No download call is issued when the bucket listing has no blobs."""
    destination = _make_temp_path(self) / "output"

    bucket_mock = MagicMock()
    bucket_mock.list_blobs.return_value = []
    client_mock = MagicMock()
    client_mock.bucket.return_value = bucket_mock

    data_ref = {
        "__data_ref__": True,
        "gcs_uri": "gs://bucket/prefix/hash",
        "is_dir": True,
    }

    _download_data(data_ref, str(destination), client_mock)

    self.mock_download.assert_not_called()
152209

153210

154211
class TestResolveDataRefs(absltest.TestCase):
@@ -200,16 +257,6 @@ def test_nested_refs_in_list(self):
200257

201258
def test_single_file_returns_file_path(self):
202259
tmp = _make_temp_path(self)
203-
mock_client = MagicMock()
204-
mock_bucket = MagicMock()
205-
mock_client.bucket.return_value = mock_bucket
206-
207-
blob = MagicMock()
208-
blob.name = "prefix/hash/config.json"
209-
blob.download_to_filename = MagicMock(
210-
side_effect=lambda p: pathlib.Path(p).write_text("{}")
211-
)
212-
mock_bucket.list_blobs.return_value = [blob]
213260

214261
ref = {
215262
"__data_ref__": True,
@@ -218,14 +265,55 @@ def test_single_file_returns_file_path(self):
218265
"mount_path": None,
219266
}
220267

221-
with mock.patch(
222-
"keras_remote.runner.remote_runner.DATA_DIR",
223-
str(tmp / "data"),
268+
def fake_dl(ref, target_dir, client):
269+
os.makedirs(target_dir, exist_ok=True)
270+
pathlib.Path(os.path.join(target_dir, "config.json")).write_text("{}")
271+
272+
with (
273+
mock.patch(
274+
"keras_remote.runner.remote_runner.DATA_DIR",
275+
str(tmp / "data"),
276+
),
277+
mock.patch(
278+
"keras_remote.runner.remote_runner._download_data",
279+
side_effect=fake_dl,
280+
),
224281
):
225-
args, _ = resolve_data_refs((ref,), {}, mock_client)
282+
args, _ = resolve_data_refs((ref,), {}, MagicMock())
226283

227284
self.assertTrue(args[0].endswith("config.json"))
228285

286+
def test_duplicate_uri_downloaded_once(self):
    """The same GCS URI referenced three times triggers a single download."""
    data_root = _make_temp_path(self) / "data"

    shared_ref = {
        "__data_ref__": True,
        "gcs_uri": "gs://b/cache/hash",
        "is_dir": True,
        "mount_path": None,
    }

    def create_target_only(r, target_dir, client):
        # Stand-in for the real download: just make the destination exist.
        os.makedirs(target_dir, exist_ok=True)

    data_dir_patch = mock.patch(
        "keras_remote.runner.remote_runner.DATA_DIR",
        str(data_root),
    )
    download_patch = mock.patch(
        "keras_remote.runner.remote_runner._download_data",
        side_effect=create_target_only,
    )
    with data_dir_patch, download_patch as mock_dl:
        # The ref appears twice positionally and once as a keyword value.
        args, kwargs = resolve_data_refs(
            (shared_ref, shared_ref), {"d": shared_ref}, MagicMock()
        )

    # Downloaded only once despite three references
    mock_dl.assert_called_once()
    # All resolved paths point to the same directory
    first_path, second_path = args[0], args[1]
    self.assertEqual(first_path, second_path)
    self.assertEqual(first_path, kwargs["d"])
316+
229317
def test_non_ref_dict_preserved(self):
230318
mock_client = MagicMock()
231319
args, kwargs = resolve_data_refs(({"key": "value"},), {"x": 1}, mock_client)

keras_remote/utils/storage.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from absl import logging
99
from google.cloud import storage
10+
from google.cloud.storage import transfer_manager
1011

1112
from keras_remote.data import Data
1213
from keras_remote.infra.infra import get_default_project
def _upload_directory(
    bucket: storage.Bucket, local_dir: str, gcs_prefix: str
) -> None:
    """Upload every file under ``local_dir`` to GCS in one batched transfer.

    Args:
        bucket: Destination bucket, handed straight to the transfer manager.
        local_dir: Root of the local tree to upload; it is walked recursively.
        gcs_prefix: Prefix for the uploaded objects. A trailing ``/`` is
            appended before it is passed as ``blob_name_prefix``.

    Returns:
        None. When ``local_dir`` contains no files, nothing is logged or
        uploaded.
    """
    # Collect paths relative to ``local_dir`` with "/" separators so the same
    # strings can be used as object-name suffixes on any platform.
    relative_paths = []
    for dirpath, _subdirs, names in os.walk(local_dir):
        relative_paths.extend(
            os.path.relpath(os.path.join(dirpath, name), local_dir).replace(
                os.sep, "/"
            )
            for name in names
        )

    if relative_paths:
        logging.info("Uploading %d files to GCS...", len(relative_paths))

        # raise_exception=True is requested so a failed upload raises instead
        # of being returned in the result list.
        transfer_manager.upload_many_from_filenames(
            bucket,
            relative_paths,
            source_directory=local_dir,
            blob_name_prefix=f"{gcs_prefix}/",
            worker_type=transfer_manager.THREAD,
            raise_exception=True,
        )

0 commit comments

Comments
 (0)