diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 686faa03d36..37a17b56b37 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -7,6 +7,7 @@ import huggingface_hub from fsspec.core import url_to_fs +from fsspec.utils import glob_translate from huggingface_hub import HfFileSystem from packaging import version from tqdm.contrib.concurrent import thread_map @@ -475,8 +476,31 @@ def get_data_patterns(base_path: str, download_config: Optional[DownloadConfig] In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS. """ - resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config) try: + resolver: Callable[[str], list[str]] + if is_relative_path(base_path) or is_local_path(base_path): + # Local fast-path: resolve local files once and match candidate patterns in-memory. + # This avoids repeated expensive recursive globs over the same directory tree. + all_local_files = resolve_pattern("**", base_path=base_path, download_config=download_config) + + def resolver(pattern: str) -> list[str]: + if is_relative_path(pattern): + absolute_pattern = xjoin(base_path, pattern) + elif is_local_path(pattern): + absolute_pattern = pattern + else: + raise FileNotFoundError(f"Unable to find '{pattern}'") + absolute_pattern, _ = _prepare_path_and_storage_options( + absolute_pattern, download_config=download_config + ) + regex = re.compile(glob_translate(absolute_pattern)) + matched_paths = [path for path in all_local_files if regex.match(path)] + if not matched_paths: + raise FileNotFoundError(f"Unable to find '{absolute_pattern}'") + return matched_paths + + else: + resolver = partial(resolve_pattern, base_path=base_path, download_config=download_config) return _get_data_files_patterns(resolver) except FileNotFoundError: raise EmptyDatasetError(f"The directory at {base_path} doesn't contain any data files") from None diff --git a/tests/test_data_files.py b/tests/test_data_files.py index 74f48dbd2d5..7c7165a3500 100644 --- a/tests/test_data_files.py +++ b/tests/test_data_files.py @@ -678,3 +678,11 @@ def test_get_data_patterns_from_directory_with_the_word_data_twice(tmp_path): data_file.touch() data_file_patterns = get_data_patterns(repo_dir.as_posix()) assert data_file_patterns == {"train": ["data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*"]} + + +def test_get_data_patterns_local_path_uses_single_resolve(complex_data_dir): + with patch("datasets.data_files.resolve_pattern", wraps=resolve_pattern) as mocked_resolve_pattern: + patterns = get_data_patterns(complex_data_dir) + assert list(patterns.keys()) == ["train", "test"] + assert patterns["train"] and patterns["test"] + assert mocked_resolve_pattern.call_count == 1