# Maps the path-style repo-type prefix carried on the config to the singular
# ``repo_type`` value passed to HfApi. Prefixes not listed here are treated
# as model repos by the lookup in ``_get_expected_file_count``.
REPO_TYPE_PREFIX_TO_API_TYPE: dict[str, str] = {
    "datasets": "dataset",
    "spaces": "space",
}


def _get_expected_file_count(cfg: DownloadConfig) -> int | None:
    """Return the repo's file count as reported by ``HfApi.list_repo_files()``.

    ``HfFileSystem.find()`` has a pagination bug that silently truncates results
    on large repos. Callers compare the length of their ``find()`` listing with
    this count to detect such truncation.

    Args:
        cfg: Download configuration. Reads ``hf_dataset_id``, ``revision`` and
            ``hf_repo_type_prefix``.

    Returns:
        The number of files listed for ``cfg.hf_dataset_id`` at ``cfg.revision``,
        or ``None`` when the resolved source path is a bucket path (buckets are
        not standard HF repos, so there is nothing to cross-check against).
    """
    # Honor the documented contract instead of relying on every caller to
    # filter bucket paths before calling.
    # NOTE(review): assumes _resolve_hf_source_path(cfg) yields the same path
    # that download_hf tests against HF_BUCKET_PATH_PREFIX — confirm.
    if _resolve_hf_source_path(cfg).startswith(HF_BUCKET_PATH_PREFIX):
        return None

    # Uses HF_TOKEN from the environment when set; otherwise False is passed.
    api = HfApi(token=os.environ.get("HF_TOKEN", False))
    repo_type = REPO_TYPE_PREFIX_TO_API_TYPE.get(cfg.hf_repo_type_prefix, "model")
    # list_repo_files() already returns a list, so no extra list() copy is needed.
    return len(api.list_repo_files(cfg.hf_dataset_id, repo_type=repo_type, revision=cfg.revision))
+ ) else: # Get list of files directly from HfFileSystem matching the pattern files = [] diff --git a/tests/download/test_huggingface.py b/tests/download/test_huggingface.py index 1019c83633..0f7a715088 100644 --- a/tests/download/test_huggingface.py +++ b/tests/download/test_huggingface.py @@ -80,8 +80,10 @@ def test_download_hf_basic(mock_hf_fs, tmp_path): gcs_output_path=str(output_path), ) - # Mock HfFileSystem creation - with patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs): + with ( + patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs), + patch("marin.download.huggingface.download_hf._get_expected_file_count", return_value=None), + ): download_hf(cfg) # Verify files were downloaded @@ -123,7 +125,10 @@ def test_download_hf_appends_sha_when_configured(mock_hf_fs, tmp_path): append_sha_to_path=True, ) - with patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs): + with ( + patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs), + patch("marin.download.huggingface.download_hf._get_expected_file_count", return_value=None), + ): download_hf(cfg) target_output = base_output_path / revision