1212
1313from marin .download .huggingface .download_hf import (
1414 DownloadConfig ,
15- _get_expected_file_count ,
1615 _relative_path_in_source ,
1716 download_hf ,
1817 stream_file_to_fsspec ,
@@ -62,82 +61,6 @@ def mock_open(path, mode="rb", **_kwargs):
6261 return _create
6362
6463
65- def test_get_expected_file_count_returns_count ():
66- """_get_expected_file_count returns the number of files from HfApi."""
67- cfg = DownloadConfig (
68- hf_dataset_id = "test-org/test-dataset" ,
69- revision = "abc1234" ,
70- )
71- repo_files = ["data/file1.txt" , "data/file2.txt" , "README.md" ]
72- mock_api = MagicMock ()
73- mock_api .list_repo_files .return_value = repo_files
74-
75- with patch ("marin.download.huggingface.download_hf.HfApi" , return_value = mock_api ):
76- result = _get_expected_file_count (cfg )
77-
78- mock_api .list_repo_files .assert_called_once_with ("test-org/test-dataset" , repo_type = "dataset" , revision = "abc1234" )
79- assert result == 3
80-
81-
82- def test_download_hf_cross_references_find_with_list_repo_files (mock_hf_fs , tmp_path ):
83- """download_hf uses hf_fs.find() but cross-references count with list_repo_files()."""
84- test_files = {
85- "datasets/test-org/test-dataset/data/file1.txt" : b"Content 1" ,
86- }
87- hf_fs = mock_hf_fs (test_files )
88- # list_repo_files returns the same count as find — no truncation
89- repo_files = ["data/file1.txt" ]
90- mock_api = MagicMock ()
91- mock_api .list_repo_files .return_value = repo_files
92-
93- output_path = tmp_path / "output"
94- output_path .mkdir ()
95- cfg = DownloadConfig (
96- hf_dataset_id = "test-org/test-dataset" ,
97- revision = "abc1234" ,
98- gcs_output_path = str (output_path ),
99- )
100-
101- with (
102- patch ("marin.download.huggingface.download_hf.HfFileSystem" , return_value = hf_fs ),
103- patch ("marin.download.huggingface.download_hf.HfApi" , return_value = mock_api ),
104- ):
105- download_hf (cfg )
106-
107- # find() SHOULD be called — it's the primary listing method
108- hf_fs .find .assert_called_once ()
109- # list_repo_files is called for cross-reference
110- mock_api .list_repo_files .assert_called_once ()
111- assert (output_path / "data" / "file1.txt" ).exists ()
112-
113-
114- def test_download_hf_raises_on_truncated_find (mock_hf_fs , tmp_path ):
115- """download_hf raises RuntimeError when find() returns fewer files than list_repo_files()."""
116- test_files = {
117- "datasets/test-org/test-dataset/data/file1.txt" : b"Content 1" ,
118- }
119- hf_fs = mock_hf_fs (test_files )
120- # list_repo_files reports more files than find() returned — truncation detected
121- repo_files = ["data/file1.txt" , "data/file2.txt" , "data/file3.txt" ]
122- mock_api = MagicMock ()
123- mock_api .list_repo_files .return_value = repo_files
124-
125- output_path = tmp_path / "output"
126- output_path .mkdir ()
127- cfg = DownloadConfig (
128- hf_dataset_id = "test-org/test-dataset" ,
129- revision = "abc1234" ,
130- gcs_output_path = str (output_path ),
131- )
132-
133- with (
134- patch ("marin.download.huggingface.download_hf.HfFileSystem" , return_value = hf_fs ),
135- patch ("marin.download.huggingface.download_hf.HfApi" , return_value = mock_api ),
136- ):
137- with pytest .raises (RuntimeError , match = "pagination bug" ):
138- download_hf (cfg )
139-
140-
14164def test_download_hf_basic (mock_hf_fs , tmp_path ):
14265 """Test basic HF download functionality."""
14366 test_files = {
@@ -147,8 +70,6 @@ def test_download_hf_basic(mock_hf_fs, tmp_path):
14770 }
14871
14972 hf_fs = mock_hf_fs (test_files )
150- mock_api = MagicMock ()
151- mock_api .list_repo_files .return_value = ["data/file1.txt" , "data/file2.txt" , "README.md" ]
15273
15374 output_path = tmp_path / "output"
15475 output_path .mkdir ()
@@ -161,7 +82,7 @@ def test_download_hf_basic(mock_hf_fs, tmp_path):
16182
16283 with (
16384 patch ("marin.download.huggingface.download_hf.HfFileSystem" , return_value = hf_fs ),
164- patch ("marin.download.huggingface.download_hf.HfApi " , return_value = mock_api ),
85+ patch ("marin.download.huggingface.download_hf._get_expected_file_count " , return_value = None ),
16586 ):
16687 download_hf (cfg )
16788
@@ -193,8 +114,6 @@ def test_download_hf_appends_sha_when_configured(mock_hf_fs, tmp_path):
193114 }
194115
195116 hf_fs = mock_hf_fs (test_files )
196- mock_api = MagicMock ()
197- mock_api .list_repo_files .return_value = ["data/file1.txt" ]
198117
199118 base_output_path = tmp_path / "output"
200119 revision = "abc1234"
@@ -208,7 +127,7 @@ def test_download_hf_appends_sha_when_configured(mock_hf_fs, tmp_path):
208127
209128 with (
210129 patch ("marin.download.huggingface.download_hf.HfFileSystem" , return_value = hf_fs ),
211- patch ("marin.download.huggingface.download_hf.HfApi " , return_value = mock_api ),
130+ patch ("marin.download.huggingface.download_hf._get_expected_file_count " , return_value = None ),
212131 ):
213132 download_hf (cfg )
214133
0 commit comments