diff --git a/lib/marin/src/marin/processing/tokenize/tokenize.py b/lib/marin/src/marin/processing/tokenize/tokenize.py
index b69fe403af..581f84f341 100644
--- a/lib/marin/src/marin/processing/tokenize/tokenize.py
+++ b/lib/marin/src/marin/processing/tokenize/tokenize.py
@@ -16,8 +16,10 @@ import time
 from collections.abc import Iterator, Sequence
 
+import braceexpand
 import draccus
-from rigging.filesystem import open_url
+import fsspec
+from rigging.filesystem import open_url, url_to_fs
 from datasets import load_dataset_builder
 from fray.v2 import ResourceConfig
 from levanter.data.text import (
@@ -32,10 +34,11 @@ from levanter.store.cache import consolidate_shard_caches
 from levanter.store.tree_store import TreeStore
 from zephyr import Dataset, ZephyrContext, zephyr_worker_ctx
-from zephyr.readers import load_file
+from zephyr.dataset import FileEntry
+from zephyr.readers import InputFileSpec, load_file
 
-from marin.execution.executor import ExecutorStep, InputName, VersionedValue
-from marin.utils import fsspec_exists, fsspec_glob, fsspec_isdir, fsspec_size
+from marin.execution.executor import InputName, VersionedValue
+from marin.utils import fsspec_exists, fsspec_isdir
 from rigging.log_setup import configure_logging
 
 logger = logging.getLogger(__name__)
@@ -205,60 +208,73 @@ def _validate_train_urls(train_paths: list[str | InputName], warn):
         )
 
 
-def _get_files_by_extensions(input_paths: list[str], extensions: list[str]) -> list[str]:
-    """
-    Get a list of all filepaths with the specified extension from the input paths.
-    """
-    output_paths = []
-    for path in input_paths:
-        assert path != "/"
-        if path.endswith("/") or fsspec_isdir(path):
-            logger.info(f"Getting all {extensions} files in {path}")
-            for ex in extensions:
-                output_paths.extend(fsspec_glob(os.path.join(path, f"**/*.{ex}")))
-        else:
-            output_paths.extend(fsspec_glob(path))
+_TOKENIZE_EXTENSIONS = ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet"]
 
-    return output_paths
+# NOTE(chris): Marin's `default_download` writes a `provenance.json` sidecar next to
+# downloaded HF data. Downstream TokenizeConfig jobs glob those directories and must
+# exclude sidecars so we don't train on provenance records. Applied uniformly to both
+# splits and both config types — HF hub datasets don't ship sidecars named this way,
+# so the filter is a no-op on the HfTokenizeConfig path.
+_MARIN_SIDECAR_NAMES = frozenset({"provenance.json"})
 
-def _get_filepaths_to_tokenize(input_paths: list[str]) -> list[str]:
+
+def _drop_sidecars(files: list[FileEntry]) -> list[FileEntry]:
+    return [f for f in files if os.path.basename(f.path) not in _MARIN_SIDECAR_NAMES]
+
+
+def _glob_with_sizes(patterns: list[str]) -> list[FileEntry]:
+    """Glob patterns and return FileEntry objects (spec + size).
+
+    Uses fsspec glob(detail=True) which returns file metadata from the same
+    list-objects API call — no per-file stat RPCs needed. Works for gs://, hf://, s3://, local.
     """
-    Get all file paths to tokenize from the input paths.
-    Handles json/jsonl.{gz,zst,zstd}, and parquet.
+    results: list[FileEntry] = []
+    for pattern in patterns:
+        pattern = re.sub(r"(?
+def _expand_tokenize_paths(input_paths) -> list[str]:
+    """Expand input paths into glob patterns for tokenizable file types.
+
+    Directories get expanded to recursive globs for each supported extension.
+    Concrete paths/patterns pass through unchanged.
""" if isinstance(input_paths, VersionedValue): input_paths = input_paths.value - if len(input_paths) == 0: - return [] - elif any(isinstance(x, InputName | ExecutorStep) for x in input_paths): - return input_paths - - out = _get_files_by_extensions(input_paths, ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet", "json"]) - out = [x for x in out if "provenance.json" not in x] - - if not len(out): - raise ValueError( - f"No valid jsonl or parquet files found in {input_paths}. " - "Please provide a path to a directory containing jsonl or parquet files." - ) - - return out + patterns: list[str] = [] + for path in input_paths: + assert path != "/" + if path.endswith("/") or fsspec_isdir(path): + logger.info(f"Getting all {_TOKENIZE_EXTENSIONS} files in {path}") + for ex in _TOKENIZE_EXTENSIONS: + patterns.append(os.path.join(path, f"**/*.{ex}")) + else: + patterns.append(path) + return patterns -def _bundle_files_by_size(file_infos, max_bytes: int): +def _bundle_files_by_size(files: list[FileEntry], max_bytes: int): """Bundle files into groups, with each group having a total size less than max_bytes.""" - current_group = [] + current_group: list[str] = [] current_size = 0 - for info in file_infos: - if current_size + info["size"] >= max_bytes and current_group: + for f in files: + if current_size + f.size >= max_bytes and current_group: yield current_group current_group = [] current_size = 0 - current_group.append(info["filename"]) - current_size += info["size"] + current_group.append(f.path) + current_size += f.size if current_group: yield current_group @@ -312,10 +328,8 @@ def tokenize(config: TokenizeConfigBase): """ if isinstance(config, TokenizeConfig): - train_paths = _get_filepaths_to_tokenize(config.train_paths) if config.train_paths else [] - validation_paths = _get_filepaths_to_tokenize(config.validation_paths) if config.validation_paths else [] - # Validate expanded paths to catch validation/test files that were inside directories - _validate_train_urls(train_paths, warn=config.allow_test_in_train) + train_patterns = _expand_tokenize_paths(config.train_paths) if config.train_paths else [] + validation_patterns = _expand_tokenize_paths(config.validation_paths) if config.validation_paths else [] elif isinstance(config, HfTokenizeConfig): logger.info(f"Loading dataset metadata for {config.id}" + (f" (config: {config.name})" if config.name else "")) @@ -328,45 +342,42 @@ def tokenize(config: TokenizeConfigBase): "This might be a dataset that requires custom loading logic." 
             )
 
-        train_paths = data_files.get("train", [])
-        validation_paths = data_files.get("validation", data_files.get("test", []))
-
-        if train_paths:
-            logger.info(f"Found {len(train_paths)} training files in {config.id}")
-        if validation_paths:
-            logger.info(f"Found {len(validation_paths)} validation files in {config.id}")
+        train_patterns = list(data_files.get("train", []))
+        validation_patterns = list(data_files.get("validation", data_files.get("test", [])))
 
     else:
        raise ValueError(f"Unknown config type: {type(config)}")
 
-    if not train_paths and not validation_paths:
+    # Resolve patterns → concrete files with sizes (single list-objects call per pattern)
+    train_files = _drop_sidecars(_glob_with_sizes(train_patterns))
+    validation_files = _drop_sidecars(_glob_with_sizes(validation_patterns))
+
+    if isinstance(config, TokenizeConfig):
+        _validate_train_urls([f.path for f in train_files], warn=config.allow_test_in_train)
+
+    if train_files:
+        logger.info(f"Found {len(train_files)} training files")
+    if validation_files:
+        logger.info(f"Found {len(validation_files)} validation files")
+
+    if train_patterns and not train_files:
+        raise ValueError(f"No training files matched configured patterns: {train_patterns}")
+    if validation_patterns and not validation_files:
+        raise ValueError(f"No validation files matched configured patterns: {validation_patterns}")
+    if not train_files and not validation_files:
         raise ValueError("No input files specified. Nothing to do.")
 
-    def local_preprocess_paths(paths: list[str]) -> list[list[str]]:
-        """Scan file sizes locally and bundle into groups for distributed processing."""
-        filescan_start = time.monotonic()
-        # Sort for deterministic batching, then chunk into groups of 64.
-        paths = sorted(paths)
-        batched_paths = [paths[i : i + 64] for i in range(0, len(paths), 64)]
-        scan_ctx = ZephyrContext(
-            max_workers=32,
-            resources=ResourceConfig(cpu=1, ram="1g"),
-            name="tokenize-filescan",
-        )
-        file_stats = scan_ctx.execute(
-            Dataset.from_list(batched_paths).flat_map(
-                lambda batch: [{"filename": p, "size": fsspec_size(p)} for p in batch]
-            ),
-            verbose=False,
-        ).results
-        total_input_bytes = sum(f["size"] for f in file_stats)
+    def local_preprocess_paths(files: list[FileEntry]) -> list[list[str]]:
+        """Bundle files into size-balanced groups for distributed processing."""
+        files = sorted(files, key=lambda f: f.path)
+        total_input_bytes = sum(f.size for f in files)
         if config.num_shards is not None:
             target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.num_shards)
         else:
             target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.max_workers)
-        file_groups = list(_bundle_files_by_size(file_stats, target_group_bytes))
+        file_groups = list(_bundle_files_by_size(files, target_group_bytes))
         logger.info(
-            f"Grouped {len(paths):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
-            f"(target {target_group_bytes / 1e9:.2f} GB/group) in {time.monotonic() - filescan_start:.1f}s."
+            f"Grouped {len(files):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
+            f"(target {target_group_bytes / 1e9:.2f} GB/group)."
         )
         return file_groups
 
@@ -451,8 +462,8 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s
     )
 
     # TODO (rav): both train and val could run at the same time
-    if train_paths and not split_already_done("train"):
-        train_groups = local_preprocess_paths(train_paths)
+    if train_files and not split_already_done("train"):
+        train_groups = local_preprocess_paths(train_files)
         ctx = ZephyrContext(
             resources=config.worker_resources,
             max_workers=min(config.max_workers, len(train_groups)),
@@ -460,8 +471,8 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s
         )
         run_pipeline(ctx, train_groups, "train")
 
-    if validation_paths and not split_already_done("validation"):
-        validation_groups = local_preprocess_paths(validation_paths)
+    if validation_files and not split_already_done("validation"):
+        validation_groups = local_preprocess_paths(validation_files)
         ctx = ZephyrContext(
             resources=config.worker_resources,
             max_workers=min(config.max_workers, len(validation_groups)),
diff --git a/lib/zephyr/src/zephyr/dataset.py b/lib/zephyr/src/zephyr/dataset.py
index 22ff7a713f..f69d31562d 100644
--- a/lib/zephyr/src/zephyr/dataset.py
+++ b/lib/zephyr/src/zephyr/dataset.py
@@ -16,10 +16,65 @@ from rigging.filesystem import url_to_fs
 
 from zephyr.expr import Expr
+from zephyr.readers import InputFileSpec
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass(frozen=True)
+class GlobSource:
+    """Lazy file source resolved at plan time via a bulk list-objects call.
+
+    Stores the glob pattern and defers expansion to compute_plan(), where
+    fsspec glob(detail=True) returns paths and sizes in a single RPC.
+    """
+
+    pattern: str
+    empty_glob_ok: bool = False
+
+
+@dataclass(frozen=True)
+class FileEntry:
+    """A discovered input file: read-spec plus metadata from the bulk listing.
+
+    ``spec`` is the pure read-specification (how to read the file); ``size`` is
+    discovered metadata (from glob ``detail=True``). Keeping these separate means
+    readers only depend on ``InputFileSpec`` while planners can still size shards.
+    """
+
+    spec: InputFileSpec
+    size: int
+
+    @property
+    def path(self) -> str:
+        return self.spec.path
+
+
+def resolve_glob(source: GlobSource) -> list[FileEntry]:
+    """Expand a GlobSource into FileEntry objects with sizes.
+
+    Uses fsspec glob(detail=True) which returns file metadata from the same
+    list-objects API call — no extra per-file stat RPCs.
+    """
+    pattern = re.sub(r"(?
         >>> output_files = ctx.execute(ds).results
         """
-        # Normalize double slashes while preserving protocol (e.g., gs://, s3://, http://)
-        pattern = re.sub(r"(?
     def map(self, fn) -> Dataset[R]:
         """Map a function over the dataset.
diff --git a/lib/zephyr/src/zephyr/plan.py b/lib/zephyr/src/zephyr/plan.py
index 842435529e..22b4b0c628 100644
--- a/lib/zephyr/src/zephyr/plan.py
+++ b/lib/zephyr/src/zephyr/plan.py
@@ -29,8 +29,10 @@ from zephyr.dataset import (
     Dataset,
+    FileEntry,
     FilterOp,
     FlatMapOp,
+    GlobSource,
     GroupByOp,
     JoinOp,
     LoadFileOp,
@@ -43,6 +45,7 @@
     TakePerShardOp,
     WindowOp,
     WriteOp,
+    resolve_glob,
 )
 from zephyr.expr import Expr
 from zephyr.readers import InputFileSpec
@@ -470,14 +473,14 @@ def _fuse_operations(operations: list) -> list[PhysicalStage]:
 
 
 def _compute_file_pushdown(
-    paths: list[str],
+    files: list[FileEntry],
     load_op: LoadFileOp,
     operations: list,
 ) -> tuple[list[SourceItem], list]:
     """Create source items for file pipeline with pushdown optimizations applied.
 
     Args:
-        paths: List of file paths to load
+        files: List of FileEntry objects (path + size from bulk listing)
         load_op: The LoadFileOp specifying format and default columns
         operations: Full operations list (first op is LoadFileOp)
 
@@ -510,13 +513,13 @@ def _compute_file_pushdown(
         SourceItem(
             shard_idx=i,
             data=InputFileSpec(
-                path=path,
+                path=entry.path,
                 format=load_op.format,
                 columns=select_columns,
                 filter_expr=filter_expr,
             ),
         )
-        for i, path in enumerate(paths)
+        for i, entry in enumerate(files)
     ]
 
     # Build final operations list: LoadFileOp + remaining ops
@@ -528,15 +531,30 @@ def compute_plan(dataset: Dataset) -> PhysicalPlan:
     """Compute physical execution plan from logical dataset."""
     operations = list(dataset.operations)
-
-    if operations and isinstance(operations[0], LoadFileOp):
+    source = dataset.source
+
+    # Resolve lazy glob sources into concrete FileEntry objects (with sizes).
+    if isinstance(source, GlobSource):
+        file_entries = resolve_glob(source)
+        if operations and isinstance(operations[0], LoadFileOp):
+            source_items, operations = _compute_file_pushdown(
+                file_entries,
+                operations[0],
+                operations[1:],
+            )
+        else:
+            # from_files() without load_file() — source items are plain paths
+            source_items = [SourceItem(shard_idx=i, data=entry.path) for i, entry in enumerate(file_entries)]
+    elif operations and isinstance(operations[0], LoadFileOp):
+        # Non-glob source (e.g. from_list of paths) — wrap as FileEntry without sizes
+        entries = [FileEntry(spec=InputFileSpec(path=p), size=0) for p in source]
         source_items, operations = _compute_file_pushdown(
-            list(dataset.source),
+            entries,
             operations[0],
             operations[1:],
         )
     else:
-        source_list = list(dataset.source)
+        source_list = list(source)
         source_items = [SourceItem(shard_idx=i, data=item) for i, item in enumerate(source_list)]
 
     stages = _fuse_operations(operations)
diff --git a/lib/zephyr/src/zephyr/readers.py b/lib/zephyr/src/zephyr/readers.py
index c49d72e7eb..a9ee8ef415 100644
--- a/lib/zephyr/src/zephyr/readers.py
+++ b/lib/zephyr/src/zephyr/readers.py
@@ -121,6 +121,9 @@ def iter_parquet_row_groups(
 class InputFileSpec:
     """Specification for reading a file or portion of a file.
 
+    Pure read-spec: everything here is caller-supplied. Discovered metadata
+    (e.g. file size from a bulk listing) lives on ``FileEntry`` instead.
+
+
     Attributes:
         path: Path to the file
         format: File format ("parquet", "jsonl", or "auto" to detect)
diff --git a/lib/zephyr/tests/test_dataset.py b/lib/zephyr/tests/test_dataset.py
index 2aed9f1f89..da37750c8f 100644
--- a/lib/zephyr/tests/test_dataset.py
+++ b/lib/zephyr/tests/test_dataset.py
@@ -13,7 +13,7 @@ from fray.v2.local_backend import LocalClient
 from zephyr import Dataset, load_file, load_parquet
 from zephyr._test_helpers import SampleDataclass
-from zephyr.dataset import FilterOp, MapOp, WindowOp
+from zephyr.dataset import FilterOp, GlobSource, MapOp, WindowOp, resolve_glob
 from zephyr.execution import ZephyrContext
 from zephyr.writers import write_parquet_file
 
@@ -280,28 +280,24 @@ def test_operations_are_dataclasses():
 
 def test_from_files_basic(tmp_path):
     """Test basic file globbing."""
-    # Create test input files
     input_dir = tmp_path / "input"
     input_dir.mkdir()
     (input_dir / "file1.txt").write_text("data1")
     (input_dir / "file2.txt").write_text("data2")
     (input_dir / "file3.txt").write_text("data3")
 
-    # Create dataset
     ds = Dataset.from_files(f"{input_dir}/*.txt")
-    files = list(ds.source)  # Access source directly without backend execution
-
-    assert len(files) == 3
-    assert all(isinstance(f, str) for f in files)
+    assert isinstance(ds.source, GlobSource)
 
-    # Check that all input files are from input_dir
-    assert all(str(input_dir) in f for f in files)
-    assert all(f.endswith(".txt") for f in files)
+    entries = resolve_glob(ds.source)
+    assert len(entries) == 3
+    assert all(str(input_dir) in e.path for e in entries)
+    assert all(e.path.endswith(".txt") for e in entries)
+    assert all(e.size > 0 for e in entries)
 
 
 def test_from_files_nested(tmp_path):
     """Test from_files with nested directories."""
-    # Create nested structure
     input_dir = tmp_path / "input"
     (input_dir / "subdir1").mkdir(parents=True)
     (input_dir / "subdir2").mkdir(parents=True)
@@ -310,15 +306,12 @@
     (input_dir / "subdir1" / "file2.txt").write_text("data2")
     (input_dir / "subdir2" / "file3.txt").write_text("data3")
 
-    # Use ** pattern to match nested files
     ds = Dataset.from_files(f"{input_dir}/**/*.txt")
-    files = list(ds.source)
-
-    assert len(files) == 3
+    entries = resolve_glob(ds.source)
 
-    # Check that nested structure is in file paths
-    assert any("subdir1" in path for path in files)
-    assert any("subdir2" in path for path in files)
+    assert len(entries) == 3
+    assert any("subdir1" in e.path for e in entries)
+    assert any("subdir2" in e.path for e in entries)
 
 
 def test_from_files_empty_glob_ok(tmp_path):
@@ -326,10 +319,9 @@
     input_dir = tmp_path / "input"
     input_dir.mkdir()
 
-    # No error when empty_glob_ok=True
     ds = Dataset.from_files(f"{input_dir}/*.txt", empty_glob_ok=True)
-    files = list(ds.source)
-    assert len(files) == 0
+    entries = resolve_glob(ds.source)
+    assert len(entries) == 0
 
 
 def test_from_files_empty_glob_error(tmp_path):
@@ -337,9 +329,9 @@
     input_dir = tmp_path / "input"
     input_dir.mkdir()
 
-    # Should raise FileNotFoundError
+    ds = Dataset.from_files(f"{input_dir}/*.txt", empty_glob_ok=False)
     with pytest.raises(FileNotFoundError, match="No files found"):
-        Dataset.from_files(f"{input_dir}/*.txt", empty_glob_ok=False)
+        resolve_glob(ds.source)
 
 
 def test_from_files_with_map(tmp_path, zephyr_ctx):
diff --git a/tests/processing/tokenize/test_tokenize.py b/tests/processing/tokenize/test_tokenize.py
index 9177f7e2c7..1ce2fbf862 100644
--- a/tests/processing/tokenize/test_tokenize.py
+++ b/tests/processing/tokenize/test_tokenize.py
@@ -18,6 +18,8 @@
     _compute_target_group_bytes,
     tokenize,
 )
+from zephyr.dataset import FileEntry
+from zephyr.readers import InputFileSpec
 
 # Dummy values for other required TokenizeConfig fields
 DUMMY_CACHE_PATH = "/dummy/cache"
@@ -139,14 +141,18 @@ def test_compute_target_group_bytes(total_bytes, max_workers, expected):
     assert _compute_target_group_bytes(total_bytes, max_workers) == expected
 
 
+def _fe(path: str, size: int) -> FileEntry:
+    return FileEntry(spec=InputFileSpec(path=path), size=size)
+
+
 def test_bundle_files_produces_expected_groups():
     """Auto-computed grouping should produce approximately max_workers groups."""
-    file_infos = [{"filename": f"file_{i}.jsonl", "size": 500_000_000} for i in range(20)]
-    total_bytes = sum(f["size"] for f in file_infos)  # 10 GB total
+    files = [_fe(f"file_{i}.jsonl", 500_000_000) for i in range(20)]
+    total_bytes = sum(f.size for f in files)  # 10 GB total
     max_workers = 4
 
     target = _compute_target_group_bytes(total_bytes, max_workers)  # 2.5 GB per group
-    groups = list(_bundle_files_by_size(file_infos, target))
+    groups = list(_bundle_files_by_size(files, target))
 
     # _bundle_files_by_size yields a group when adding the next file would reach
     # the target (uses >=). With target=2.5 GB and 500 MB files, each group fits
     # 4 files (2 GB < 2.5 GB), yielding 5 groups.
@@ -157,13 +163,13 @@ def test_bundle_files_single_large_file():
     """A single file larger than target_group_bytes gets its own group."""
-    file_infos = [
-        {"filename": "big.jsonl", "size": 5_000_000_000},
-        {"filename": "small1.jsonl", "size": 100_000_000},
-        {"filename": "small2.jsonl", "size": 100_000_000},
+    files = [
+        _fe("big.jsonl", 5_000_000_000),
+        _fe("small1.jsonl", 100_000_000),
+        _fe("small2.jsonl", 100_000_000),
     ]
     target = 1_000_000_000  # 1 GB
-    groups = list(_bundle_files_by_size(file_infos, target))
+    groups = list(_bundle_files_by_size(files, target))
 
     assert groups[0] == ["big.jsonl"]
     assert groups[1] == ["small1.jsonl", "small2.jsonl"]
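
Illustrative usage sketch (not part of the patch). It strings together the lazy-glob flow this change introduces, mirroring the updated tests: `Dataset.from_files` now stores a `GlobSource`, `resolve_glob` expands it into `FileEntry` objects with sizes, and the tokenize pipeline bundles those entries into size-bounded groups. The temp directory, file contents, and the tiny 32-byte bundle target are made up; the `marin.processing.tokenize.tokenize` import path is assumed from the file layout, and `_bundle_files_by_size` is a private helper.

import pathlib
import tempfile

from zephyr import Dataset
from zephyr.dataset import GlobSource, resolve_glob

# Assumed import path, inferred from lib/marin/src/marin/processing/tokenize/tokenize.py.
from marin.processing.tokenize.tokenize import _bundle_files_by_size

# Create a few small local files to glob over (contents and sizes are illustrative).
tmp = pathlib.Path(tempfile.mkdtemp())
for i in range(3):
    (tmp / f"file{i}.jsonl").write_text('{"text": "hello"}\n')

# from_files() no longer globs eagerly; the dataset's source is just the pattern.
ds = Dataset.from_files(f"{tmp}/*.jsonl")
assert isinstance(ds.source, GlobSource)

# Plan-time expansion: one listing call yields paths and sizes as FileEntry objects.
entries = resolve_glob(ds.source)
for entry in entries:
    print(entry.path, entry.size)

# The tokenize pipeline then groups entries by cumulative size into lists of paths.
groups = list(_bundle_files_by_size(sorted(entries, key=lambda e: e.path), 32))
print(groups)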