-
Notifications
You must be signed in to change notification settings - Fork 104
[zephyr/tokenize] Use bulk list-objects for file sizes, delete filescan job #4658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
cfd9053
98604d2
4ad070f
cbd0aee
ed8da76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,8 +16,10 @@ | |
| import time | ||
| from collections.abc import Iterator, Sequence | ||
|
|
||
| import braceexpand | ||
| import draccus | ||
| from rigging.filesystem import open_url | ||
| import fsspec | ||
| from rigging.filesystem import open_url, url_to_fs | ||
| from datasets import load_dataset_builder | ||
| from fray.v2 import ResourceConfig | ||
| from levanter.data.text import ( | ||
|
|
@@ -34,8 +36,8 @@ | |
| from zephyr import Dataset, ZephyrContext, zephyr_worker_ctx | ||
| from zephyr.readers import load_file | ||
|
|
||
| from marin.execution.executor import ExecutorStep, InputName, VersionedValue | ||
| from marin.utils import fsspec_exists, fsspec_glob, fsspec_isdir, fsspec_size | ||
| from marin.execution.executor import InputName, VersionedValue | ||
| from marin.utils import fsspec_exists, fsspec_isdir | ||
| from rigging.log_setup import configure_logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -205,46 +207,47 @@ def _validate_train_urls(train_paths: list[str | InputName], warn): | |
| ) | ||
|
|
||
|
|
||
| def _get_files_by_extensions(input_paths: list[str], extensions: list[str]) -> list[str]: | ||
| """ | ||
| Get a list of all filepaths with the specified extension from the input paths. | ||
| """ | ||
| output_paths = [] | ||
| for path in input_paths: | ||
| assert path != "/" | ||
| if path.endswith("/") or fsspec_isdir(path): | ||
| logger.info(f"Getting all {extensions} files in {path}") | ||
| for ex in extensions: | ||
| output_paths.extend(fsspec_glob(os.path.join(path, f"**/*.{ex}"))) | ||
| else: | ||
| output_paths.extend(fsspec_glob(path)) | ||
| _TOKENIZE_EXTENSIONS = ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet", "json"] | ||
|
|
||
| return output_paths | ||
|
|
||
| def _glob_with_sizes(patterns: list[str]) -> list[dict]: | ||
| """Glob patterns and return [{"filename": path, "size": bytes}]. | ||
|
|
||
| def _get_filepaths_to_tokenize(input_paths: list[str]) -> list[str]: | ||
| Uses fsspec glob(detail=True) which returns file metadata from the same | ||
| list-objects API call — no per-file stat RPCs needed. Works for gs://, hf://, s3://, local. | ||
| """ | ||
| Get all file paths to tokenize from the input paths. | ||
| Handles json/jsonl.{gz,zst,zstd}, and parquet. | ||
| results: list[dict] = [] | ||
| for pattern in patterns: | ||
| pattern = re.sub(r"(?<!:)//+", "/", pattern) | ||
| fs, _ = url_to_fs(pattern) | ||
| protocol = fsspec.core.split_protocol(pattern)[0] | ||
| for expanded in braceexpand.braceexpand(pattern): | ||
| detail = fs.glob(expanded, detail=True) | ||
| for path, info in detail.items(): | ||
| full = f"{protocol}://{path}" if protocol else path | ||
| results.append({"filename": full, "size": info.get("size", 0)}) | ||
| return results | ||
|
|
||
|
|
||
| def _expand_tokenize_paths(input_paths: list[str]) -> list[str]: | ||
| """Expand input paths into glob patterns for tokenizable file types. | ||
|
|
||
| Directories get expanded to recursive globs for each supported extension. | ||
| Concrete paths/patterns pass through unchanged. | ||
| """ | ||
| if isinstance(input_paths, VersionedValue): | ||
| input_paths = input_paths.value | ||
|
|
||
| if len(input_paths) == 0: | ||
| return [] | ||
| elif any(isinstance(x, InputName | ExecutorStep) for x in input_paths): | ||
| return input_paths | ||
|
|
||
| out = _get_files_by_extensions(input_paths, ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet", "json"]) | ||
| out = [x for x in out if "provenance.json" not in x] | ||
|
|
||
| if not len(out): | ||
| raise ValueError( | ||
| f"No valid jsonl or parquet files found in {input_paths}. " | ||
| "Please provide a path to a directory containing jsonl or parquet files." | ||
| ) | ||
|
|
||
| return out | ||
| patterns: list[str] = [] | ||
| for path in input_paths: | ||
| assert path != "/" | ||
| if path.endswith("/") or fsspec_isdir(path): | ||
| logger.info(f"Getting all {_TOKENIZE_EXTENSIONS} files in {path}") | ||
| for ex in _TOKENIZE_EXTENSIONS: | ||
| patterns.append(os.path.join(path, f"**/*.{ex}")) | ||
| else: | ||
| patterns.append(path) | ||
| return patterns | ||
|
|
||
|
|
||
| def _bundle_files_by_size(file_infos, max_bytes: int): | ||
|
|
@@ -312,10 +315,8 @@ def tokenize(config: TokenizeConfigBase): | |
| """ | ||
|
|
||
| if isinstance(config, TokenizeConfig): | ||
| train_paths = _get_filepaths_to_tokenize(config.train_paths) if config.train_paths else [] | ||
| validation_paths = _get_filepaths_to_tokenize(config.validation_paths) if config.validation_paths else [] | ||
| # Validate expanded paths to catch validation/test files that were inside directories | ||
| _validate_train_urls(train_paths, warn=config.allow_test_in_train) | ||
| train_patterns = _expand_tokenize_paths(config.train_paths) if config.train_paths else [] | ||
| validation_patterns = _expand_tokenize_paths(config.validation_paths) if config.validation_paths else [] | ||
| elif isinstance(config, HfTokenizeConfig): | ||
| logger.info(f"Loading dataset metadata for {config.id}" + (f" (config: {config.name})" if config.name else "")) | ||
|
|
||
|
|
@@ -328,45 +329,40 @@ def tokenize(config: TokenizeConfigBase): | |
| "This might be a dataset that requires custom loading logic." | ||
| ) | ||
|
|
||
| train_paths = data_files.get("train", []) | ||
| validation_paths = data_files.get("validation", data_files.get("test", [])) | ||
|
|
||
| if train_paths: | ||
| logger.info(f"Found {len(train_paths)} training files in {config.id}") | ||
| if validation_paths: | ||
| logger.info(f"Found {len(validation_paths)} validation files in {config.id}") | ||
| train_patterns = list(data_files.get("train", [])) | ||
| validation_patterns = list(data_files.get("validation", data_files.get("test", []))) | ||
| else: | ||
| raise ValueError(f"Unknown config type: {type(config)}") | ||
|
|
||
| if not train_paths and not validation_paths: | ||
| # Resolve patterns → concrete files with sizes (single list-objects call per pattern) | ||
| train_file_stats = _glob_with_sizes(train_patterns) | ||
| train_file_stats = [f for f in train_file_stats if "provenance.json" not in f["filename"]] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. "provenance.json" special-cases seems brittle
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yeah i'll figure out a more appropriate setup
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the old code has this as well. I don't see a great way around it unfortunately. we should really be writing our own metadata into isolated directories like dir/.marin/provenance.json, but of course we have a lot of leftover datasets with this in it. I changed it to use a constant set of known metadata files & filter those out (and we don't look for plain .json anymore which should avoid most of the issues...) |
||
| validation_file_stats = _glob_with_sizes(validation_patterns) | ||
|
rjpower marked this conversation as resolved.
Outdated
|
||
| validation_file_stats = [f for f in validation_file_stats if "provenance.json" not in f["filename"]] | ||
|
|
||
| if isinstance(config, TokenizeConfig): | ||
| _validate_train_urls([f["filename"] for f in train_file_stats], warn=config.allow_test_in_train) | ||
|
|
||
| if train_file_stats: | ||
| logger.info(f"Found {len(train_file_stats)} training files") | ||
| if validation_file_stats: | ||
| logger.info(f"Found {len(validation_file_stats)} validation files") | ||
|
|
||
| if not train_file_stats and not validation_file_stats: | ||
| raise ValueError("No input files specified. Nothing to do.") | ||
|
rjpower marked this conversation as resolved.
Outdated
|
||
|
|
||
| def local_preprocess_paths(paths: list[str]) -> list[list[str]]: | ||
| """Scan file sizes locally and bundle into groups for distributed processing.""" | ||
| filescan_start = time.monotonic() | ||
| # Sort for deterministic batching, then chunk into groups of 64. | ||
| paths = sorted(paths) | ||
| batched_paths = [paths[i : i + 64] for i in range(0, len(paths), 64)] | ||
| scan_ctx = ZephyrContext( | ||
| max_workers=32, | ||
| resources=ResourceConfig(cpu=1, ram="1g"), | ||
| name="tokenize-filescan", | ||
| ) | ||
| file_stats = scan_ctx.execute( | ||
| Dataset.from_list(batched_paths).flat_map( | ||
| lambda batch: [{"filename": p, "size": fsspec_size(p)} for p in batch] | ||
| ), | ||
| verbose=False, | ||
| ).results | ||
| def local_preprocess_paths(file_stats: list[dict]) -> list[list[str]]: | ||
| """Bundle files into size-balanced groups for distributed processing.""" | ||
| file_stats = sorted(file_stats, key=lambda f: f["filename"]) | ||
| total_input_bytes = sum(f["size"] for f in file_stats) | ||
| if config.num_shards is not None: | ||
| target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.num_shards) | ||
| else: | ||
| target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.max_workers) | ||
| file_groups = list(_bundle_files_by_size(file_stats, target_group_bytes)) | ||
| logger.info( | ||
| f"Grouped {len(paths):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups " | ||
| f"(target {target_group_bytes / 1e9:.2f} GB/group) in {time.monotonic() - filescan_start:.1f}s." | ||
| f"Grouped {len(file_stats):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups " | ||
| f"(target {target_group_bytes / 1e9:.2f} GB/group)." | ||
| ) | ||
| return file_groups | ||
|
|
||
|
|
@@ -451,17 +447,17 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s | |
| ) | ||
|
|
||
| # TODO (rav): both train and val could run at the same time | ||
| if train_paths and not split_already_done("train"): | ||
| train_groups = local_preprocess_paths(train_paths) | ||
| if train_file_stats and not split_already_done("train"): | ||
| train_groups = local_preprocess_paths(train_file_stats) | ||
| ctx = ZephyrContext( | ||
| resources=config.worker_resources, | ||
| max_workers=min(config.max_workers, len(train_groups)), | ||
| name="tokenize-train", | ||
| ) | ||
| run_pipeline(ctx, train_groups, "train") | ||
|
|
||
| if validation_paths and not split_already_done("validation"): | ||
| validation_groups = local_preprocess_paths(validation_paths) | ||
| if validation_file_stats and not split_already_done("validation"): | ||
| validation_groups = local_preprocess_paths(validation_file_stats) | ||
| ctx = ZephyrContext( | ||
| resources=config.worker_resources, | ||
| max_workers=min(config.max_workers, len(validation_groups)), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,6 +132,7 @@ class InputFileSpec: | |
|
|
||
| path: str | ||
| format: Literal["parquet", "jsonl", "vortex", "auto"] = "auto" | ||
| size: int | None = None | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. docstring not updated.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. moved off of the filespec |
||
| columns: list[str] | None = None | ||
| row_start: int | None = None | ||
| row_end: int | None = None | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
is plain `.json` safe here?

There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
we can probably remove — probably less safe than useful, I guess; too many chances for weird things