Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import huggingface_hub
from fsspec.core import url_to_fs
from fsspec.utils import glob_translate
from huggingface_hub import HfFileSystem
from packaging import version
from tqdm.contrib.concurrent import thread_map
Expand Down Expand Up @@ -475,8 +476,31 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig]

In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
"""
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
try:
resolver: Callable[[str], list[str]]
if is_relative_path(base_path) or is_local_path(base_path):
# Local fast-path: resolve local files once and match candidate patterns in-memory.
# This avoids repeated expensive recursive globs over the same directory tree.
all_local_files = resolve_pattern("**", base_path=base_path, download_config=download_config)

def resolver(pattern: str) -> list[str]:
if is_relative_path(pattern):
absolute_pattern = xjoin(base_path, pattern)
elif is_local_path(pattern):
absolute_pattern = pattern
else:
raise FileNotFoundError(f"Unable to find '{pattern}'")
absolute_pattern, _ = _prepare_path_and_storage_options(
absolute_pattern, download_config=download_config
)
regex = re.compile(glob_translate(absolute_pattern))
matched_paths = [path for path in all_local_files if regex.match(path)]
if not matched_paths:
raise FileNotFoundError(f"Unable to find '{absolute_pattern}'")
return matched_paths

else:
resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config)
return _get_data_files_patterns(resolver)
except FileNotFoundError:
raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None
Expand Down
8 changes: 8 additions & 0 deletions tests/test_data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,3 +678,11 @@ def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path):
data_file.touch()
data_file_patterns = get_data_patterns(repo_dir.as_posix())
assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]}


def test_get_data_patterns_local_path_uses_single_resolve(complex_data_dir):
with patch("datasets.data_files.resolve_pattern", wraps=resolve_pattern) as mocked_resolve_pattern:
patterns = get_data_patterns(complex_data_dir)
assert list(patterns.keys()) == ["train", "test"]
assert patterns["train"] and patterns["test"]
assert mocked_resolve_pattern.call_count == 1