Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions lib/marin/src/marin/download/huggingface/download_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import draccus
import huggingface_hub
from huggingface_hub import HfFileSystem
from huggingface_hub import HfApi, HfFileSystem
from iris.marin_fs import open_url, url_to_fs
from huggingface_hub.errors import HfHubHTTPError
from packaging.version import Version
Expand Down Expand Up @@ -85,6 +85,27 @@ def _resolve_hf_source_path(cfg: DownloadConfig) -> str:
return _strip_hf_protocol(source_path)


REPO_TYPE_PREFIX_TO_API_TYPE: dict[str, str] = {
"datasets": "dataset",
"spaces": "space",
}


def _get_expected_file_count(cfg: DownloadConfig) -> int | None:
"""Return the expected file count from HfApi.list_repo_files() for truncation detection.

HfFileSystem.find() has a pagination bug that silently truncates results on large repos.
This function uses HfApi.list_repo_files() (which handles pagination correctly) to get the
true file count so callers can detect truncation.

Returns None for bucket paths, which are not standard HF repos.
"""
api = HfApi(token=os.environ.get("HF_TOKEN", False))
repo_type = REPO_TYPE_PREFIX_TO_API_TYPE.get(cfg.hf_repo_type_prefix, "model")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Normalize repo-type prefix before HfApi lookup

_list_repo_files maps cfg.hf_repo_type_prefix by exact key, so values like "datasets/" or "spaces/" fall through to the default "model". This is a regression from the previous hf_fs.find path-based behavior: configs that include the trailing slash (which is still documented in DownloadConfig) will now call list_repo_files(..., repo_type="model") for dataset/space repos and fail to list files (typically 404/not found). Strip trailing / (or otherwise normalize) before the lookup to preserve existing inputs.

Useful? React with 👍 / 👎.

repo_files = list(api.list_repo_files(cfg.hf_dataset_id, repo_type=repo_type, revision=cfg.revision))
return len(repo_files)


def _assert_bucket_support_available(source_path: str) -> None:
if not source_path.startswith(HF_BUCKET_PATH_PREFIX):
return
Expand Down Expand Up @@ -274,8 +295,19 @@ def download_hf(cfg: DownloadConfig) -> None:
_assert_bucket_support_available(hf_source_path)

if not cfg.hf_urls_glob:
# We get all the files using find
files = hf_fs.find(hf_source_path, revision=cfg.revision)

# Cross-reference file count with HfApi.list_repo_files() to detect
# silent truncation from a pagination bug in HfFileSystem.find().
if not hf_source_path.startswith(HF_BUCKET_PATH_PREFIX):
expected_count = _get_expected_file_count(cfg)
if expected_count is not None and len(files) != expected_count:
raise RuntimeError(
f"HfFileSystem.find() returned {len(files)} files but HfApi.list_repo_files() "
f"reports {expected_count} files for {cfg.hf_dataset_id}@{cfg.revision}. "
f"This is likely due to a pagination bug in HfFileSystem.find(). "
f"Consider upgrading huggingface_hub or filing a bug report."
)
else:
# Get list of files directly from HfFileSystem matching the pattern
files = []
Expand Down
11 changes: 8 additions & 3 deletions tests/download/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def test_download_hf_basic(mock_hf_fs, tmp_path):
gcs_output_path=str(output_path),
)

# Mock HfFileSystem creation
with patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs):
with (
patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs),
patch("marin.download.huggingface.download_hf._get_expected_file_count", return_value=None),
):
download_hf(cfg)

# Verify files were downloaded
Expand Down Expand Up @@ -123,7 +125,10 @@ def test_download_hf_appends_sha_when_configured(mock_hf_fs, tmp_path):
append_sha_to_path=True,
)

with patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs):
with (
patch("marin.download.huggingface.download_hf.HfFileSystem", return_value=hf_fs),
patch("marin.download.huggingface.download_hf._get_expected_file_count", return_value=None),
):
download_hf(cfg)

target_output = base_output_path / revision
Expand Down
Loading