1313from absl .testing import absltest
1414
1515from keras_remote .runner .remote_runner import (
16+ _DOWNLOAD_BATCH_SIZE ,
1617 _download_data ,
1718 _download_from_gcs ,
1819 _upload_to_gcs ,
@@ -81,6 +82,15 @@ def test_parses_gcs_path(self):
8182
8283
8384class TestDownloadData (absltest .TestCase ):
def setUp(self):
  """Replace ``transfer_manager.download_many_to_path`` with a mock.

  The mock is stored on ``self.mock_download`` so that individual tests can
  inspect how the bulk-download helper was called.
  """
  super().setUp()
  download_patch = mock.patch(
    "keras_remote.runner.remote_runner.transfer_manager"
    ".download_many_to_path",
  )
  # enterContext() starts the patch and registers its exit as a cleanup.
  self.mock_download = self.enterContext(download_patch)
93+
8494 def test_downloads_files_skips_marker (self ):
8595 tmp = _make_temp_path (self )
8696 target = tmp / "output"
@@ -91,9 +101,6 @@ def test_downloads_files_skips_marker(self):
91101
92102 blob_data = MagicMock ()
93103 blob_data .name = "prefix/hash/train.csv"
94- blob_data .download_to_filename = MagicMock (
95- side_effect = lambda p : pathlib .Path (p ).write_text ("train" )
96- )
97104
98105 blob_marker = MagicMock ()
99106 blob_marker .name = "prefix/hash/.cache_marker"
@@ -115,9 +122,18 @@ def test_downloads_files_skips_marker(self):
115122
116123 _download_data (ref , str (target ), mock_client )
117124
118- blob_data .download_to_filename .assert_called_once ()
119- blob_marker .download_to_filename .assert_not_called ()
120- blob_dir .download_to_filename .assert_not_called ()
125+ self .mock_download .assert_called_once ()
126+ blob_names = self .mock_download .call_args [0 ][1 ]
127+ self .assertEqual (blob_names , ["train.csv" ])
128+ self .assertEqual (
129+ self .mock_download .call_args .kwargs ["destination_directory" ],
130+ str (target ),
131+ )
132+ self .assertEqual (
133+ self .mock_download .call_args .kwargs ["blob_name_prefix" ],
134+ "prefix/hash/" ,
135+ )
136+ self .assertTrue (self .mock_download .call_args .kwargs ["raise_exception" ])
121137
122138 def test_creates_subdirectories (self ):
123139 tmp = _make_temp_path (self )
@@ -129,12 +145,6 @@ def test_creates_subdirectories(self):
129145
130146 blob = MagicMock ()
131147 blob .name = "prefix/hash/sub/deep.csv"
132- blob .download_to_filename = MagicMock (
133- side_effect = lambda p : (
134- pathlib .Path (p ).parent .mkdir (parents = True , exist_ok = True )
135- or pathlib .Path (p ).write_text ("data" )
136- )
137- )
138148 mock_bucket .list_blobs .return_value = [blob ]
139149
140150 ref = {
@@ -145,10 +155,57 @@ def test_creates_subdirectories(self):
145155
146156 _download_data (ref , str (target ), mock_client )
147157
148- # The call should include the nested path
149- call_path = blob .download_to_filename .call_args [0 ][0 ]
150- self .assertIn ("sub" , call_path )
151- self .assertTrue (call_path .endswith ("deep.csv" ))
158+ blob_names = self .mock_download .call_args [0 ][1 ]
159+ self .assertEqual (blob_names , ["sub/deep.csv" ])
160+
def test_large_listing_downloads_in_batches(self):
  """Check that an oversized listing is fetched in two chunks.

  The fake bucket lists ``_DOWNLOAD_BATCH_SIZE + 5`` blobs, so the patched
  ``download_many_to_path`` is expected to be called twice: first with a
  full batch of names and then with the five that are left over. The names
  are read from the second positional argument of each call.
  """
  leftover = 5

  def make_blob(index):
    fake = MagicMock()
    # ``name`` has to be assigned after construction; passing it to the
    # MagicMock constructor would name the mock itself instead.
    fake.name = f"prefix/hash/file_{index}.csv"
    return fake

  bucket = MagicMock()
  bucket.list_blobs.return_value = [
    make_blob(i) for i in range(_DOWNLOAD_BATCH_SIZE + leftover)
  ]
  client = MagicMock()
  client.bucket.return_value = bucket

  destination = _make_temp_path(self) / "output"
  data_ref = {
    "__data_ref__": True,
    "gcs_uri": "gs://bucket/prefix/hash",
    "is_dir": True,
  }

  _download_data(data_ref, str(destination), client)

  self.assertEqual(self.mock_download.call_count, 2)
  batch_sizes = [
    len(call.args[1]) for call in self.mock_download.call_args_list
  ]
  self.assertEqual(batch_sizes, [_DOWNLOAD_BATCH_SIZE, 5])
190+
def test_empty_listing_is_noop(self):
  """An empty blob listing must not trigger any bulk download call."""
  bucket = MagicMock()
  bucket.list_blobs.return_value = []  # nothing stored under the prefix
  client = MagicMock()
  client.bucket.return_value = bucket

  destination = _make_temp_path(self) / "output"

  _download_data(
    {
      "__data_ref__": True,
      "gcs_uri": "gs://bucket/prefix/hash",
      "is_dir": True,
    },
    str(destination),
    client,
  )

  self.mock_download.assert_not_called()
152209
153210
154211class TestResolveDataRefs (absltest .TestCase ):
@@ -200,16 +257,6 @@ def test_nested_refs_in_list(self):
200257
201258 def test_single_file_returns_file_path (self ):
202259 tmp = _make_temp_path (self )
203- mock_client = MagicMock ()
204- mock_bucket = MagicMock ()
205- mock_client .bucket .return_value = mock_bucket
206-
207- blob = MagicMock ()
208- blob .name = "prefix/hash/config.json"
209- blob .download_to_filename = MagicMock (
210- side_effect = lambda p : pathlib .Path (p ).write_text ("{}" )
211- )
212- mock_bucket .list_blobs .return_value = [blob ]
213260
214261 ref = {
215262 "__data_ref__" : True ,
@@ -218,14 +265,55 @@ def test_single_file_returns_file_path(self):
218265 "mount_path" : None ,
219266 }
220267
221- with mock .patch (
222- "keras_remote.runner.remote_runner.DATA_DIR" ,
223- str (tmp / "data" ),
268+ def fake_dl (ref , target_dir , client ):
269+ os .makedirs (target_dir , exist_ok = True )
270+ pathlib .Path (os .path .join (target_dir , "config.json" )).write_text ("{}" )
271+
272+ with (
273+ mock .patch (
274+ "keras_remote.runner.remote_runner.DATA_DIR" ,
275+ str (tmp / "data" ),
276+ ),
277+ mock .patch (
278+ "keras_remote.runner.remote_runner._download_data" ,
279+ side_effect = fake_dl ,
280+ ),
224281 ):
225- args , _ = resolve_data_refs ((ref ,), {}, mock_client )
282+ args , _ = resolve_data_refs ((ref ,), {}, MagicMock () )
226283
227284 self .assertTrue (args [0 ].endswith ("config.json" ))
228285
def test_duplicate_uri_downloaded_once(self):
  """Three references to one URI lead to a single ``_download_data`` call.

  The same ref dict is passed twice positionally and once as a keyword
  value; every resolved entry must come back equal.
  """
  data_root = _make_temp_path(self) / "data"
  shared_ref = {
    "__data_ref__": True,
    "gcs_uri": "gs://b/cache/hash",
    "is_dir": True,
    "mount_path": None,
  }

  def make_target_dir(r, target_dir, client):
    # Stand-in for the real download: only creates the destination.
    os.makedirs(target_dir, exist_ok=True)

  data_dir_patch = mock.patch(
    "keras_remote.runner.remote_runner.DATA_DIR",
    str(data_root),
  )
  download_patch = mock.patch(
    "keras_remote.runner.remote_runner._download_data",
    side_effect=make_target_dir,
  )

  with data_dir_patch, download_patch as mock_dl:
    resolved_args, resolved_kwargs = resolve_data_refs(
      (shared_ref, shared_ref), {"d": shared_ref}, MagicMock()
    )

  # One download despite three references to the same URI.
  mock_dl.assert_called_once()
  # Every reference resolves to the same location.
  self.assertEqual(resolved_args[0], resolved_args[1])
  self.assertEqual(resolved_args[0], resolved_kwargs["d"])
316+
229317 def test_non_ref_dict_preserved (self ):
230318 mock_client = MagicMock ()
231319 args , kwargs = resolve_data_refs (({"key" : "value" },), {"x" : 1 }, mock_client )
0 commit comments