Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 66 additions & 70 deletions lib/marin/src/marin/processing/tokenize/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import time
from collections.abc import Iterator, Sequence

import braceexpand
import draccus
from rigging.filesystem import open_url
import fsspec
from rigging.filesystem import open_url, url_to_fs
from datasets import load_dataset_builder
from fray.v2 import ResourceConfig
from levanter.data.text import (
Expand All @@ -34,8 +36,8 @@
from zephyr import Dataset, ZephyrContext, zephyr_worker_ctx
from zephyr.readers import load_file

from marin.execution.executor import ExecutorStep, InputName, VersionedValue
from marin.utils import fsspec_exists, fsspec_glob, fsspec_isdir, fsspec_size
from marin.execution.executor import InputName, VersionedValue
from marin.utils import fsspec_exists, fsspec_isdir
from rigging.log_setup import configure_logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -205,46 +207,47 @@ def _validate_train_urls(train_paths: list[str | InputName], warn):
)


def _get_files_by_extensions(input_paths: list[str], extensions: list[str]) -> list[str]:
"""
Get a list of all filepaths with the specified extension from the input paths.
"""
output_paths = []
for path in input_paths:
assert path != "/"
if path.endswith("/") or fsspec_isdir(path):
logger.info(f"Getting all {extensions} files in {path}")
for ex in extensions:
output_paths.extend(fsspec_glob(os.path.join(path, f"**/*.{ex}")))
else:
output_paths.extend(fsspec_glob(path))
_TOKENIZE_EXTENSIONS = ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet", "json"]

return output_paths

def _glob_with_sizes(patterns: list[str]) -> list[dict]:
"""Glob patterns and return [{"filename": path, "size": bytes}].

def _get_filepaths_to_tokenize(input_paths: list[str]) -> list[str]:
Uses fsspec glob(detail=True) which returns file metadata from the same
list-objects API call — no per-file stat RPCs needed. Works for gs://, hf://, s3://, local.
"""
Get all file paths to tokenize from the input paths.
Handles json/jsonl.{gz,zst,zstd}, and parquet.
results: list[dict] = []
for pattern in patterns:
pattern = re.sub(r"(?<!:)//+", "/", pattern)
fs, _ = url_to_fs(pattern)
protocol = fsspec.core.split_protocol(pattern)[0]
for expanded in braceexpand.braceexpand(pattern):
detail = fs.glob(expanded, detail=True)
for path, info in detail.items():
full = f"{protocol}://{path}" if protocol else path
results.append({"filename": full, "size": info.get("size", 0)})
return results


def _expand_tokenize_paths(input_paths: list[str]) -> list[str]:
"""Expand input paths into glob patterns for tokenizable file types.

Directories get expanded to recursive globs for each supported extension.
Concrete paths/patterns pass through unchanged.
"""
if isinstance(input_paths, VersionedValue):
input_paths = input_paths.value

if len(input_paths) == 0:
return []
elif any(isinstance(x, InputName | ExecutorStep) for x in input_paths):
return input_paths

out = _get_files_by_extensions(input_paths, ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet", "json"])
out = [x for x in out if "provenance.json" not in x]

if not len(out):
raise ValueError(
f"No valid jsonl or parquet files found in {input_paths}. "
"Please provide a path to a directory containing jsonl or parquet files."
)

return out
patterns: list[str] = []
for path in input_paths:
assert path != "/"
if path.endswith("/") or fsspec_isdir(path):
logger.info(f"Getting all {_TOKENIZE_EXTENSIONS} files in {path}")
for ex in _TOKENIZE_EXTENSIONS:
patterns.append(os.path.join(path, f"**/*.{ex}"))
else:
patterns.append(path)
return patterns


def _bundle_files_by_size(file_infos, max_bytes: int):
Expand Down Expand Up @@ -312,10 +315,8 @@ def tokenize(config: TokenizeConfigBase):
"""

if isinstance(config, TokenizeConfig):
train_paths = _get_filepaths_to_tokenize(config.train_paths) if config.train_paths else []
validation_paths = _get_filepaths_to_tokenize(config.validation_paths) if config.validation_paths else []
# Validate expanded paths to catch validation/test files that were inside directories
_validate_train_urls(train_paths, warn=config.allow_test_in_train)
train_patterns = _expand_tokenize_paths(config.train_paths) if config.train_paths else []
validation_patterns = _expand_tokenize_paths(config.validation_paths) if config.validation_paths else []
elif isinstance(config, HfTokenizeConfig):
logger.info(f"Loading dataset metadata for {config.id}" + (f" (config: {config.name})" if config.name else ""))

Expand All @@ -328,45 +329,40 @@ def tokenize(config: TokenizeConfigBase):
"This might be a dataset that requires custom loading logic."
)

train_paths = data_files.get("train", [])
validation_paths = data_files.get("validation", data_files.get("test", []))

if train_paths:
logger.info(f"Found {len(train_paths)} training files in {config.id}")
if validation_paths:
logger.info(f"Found {len(validation_paths)} validation files in {config.id}")
train_patterns = list(data_files.get("train", []))
validation_patterns = list(data_files.get("validation", data_files.get("test", [])))
else:
raise ValueError(f"Unknown config type: {type(config)}")

if not train_paths and not validation_paths:
# Resolve patterns → concrete files with sizes (single list-objects call per pattern)
train_file_stats = _glob_with_sizes(train_patterns)
train_file_stats = [f for f in train_file_stats if "provenance.json" not in f["filename"]]
validation_file_stats = _glob_with_sizes(validation_patterns)
validation_file_stats = [f for f in validation_file_stats if "provenance.json" not in f["filename"]]

if isinstance(config, TokenizeConfig):
_validate_train_urls([f["filename"] for f in train_file_stats], warn=config.allow_test_in_train)

if train_file_stats:
logger.info(f"Found {len(train_file_stats)} training files")
if validation_file_stats:
logger.info(f"Found {len(validation_file_stats)} validation files")

if not train_file_stats and not validation_file_stats:
raise ValueError("No input files specified. Nothing to do.")

def local_preprocess_paths(paths: list[str]) -> list[list[str]]:
"""Scan file sizes locally and bundle into groups for distributed processing."""
filescan_start = time.monotonic()
# Sort for deterministic batching, then chunk into groups of 64.
paths = sorted(paths)
batched_paths = [paths[i : i + 64] for i in range(0, len(paths), 64)]
scan_ctx = ZephyrContext(
max_workers=32,
resources=ResourceConfig(cpu=1, ram="1g"),
name="tokenize-filescan",
)
file_stats = scan_ctx.execute(
Dataset.from_list(batched_paths).flat_map(
lambda batch: [{"filename": p, "size": fsspec_size(p)} for p in batch]
),
verbose=False,
).results
def local_preprocess_paths(file_stats: list[dict]) -> list[list[str]]:
"""Bundle files into size-balanced groups for distributed processing."""
file_stats = sorted(file_stats, key=lambda f: f["filename"])
total_input_bytes = sum(f["size"] for f in file_stats)
if config.num_shards is not None:
target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.num_shards)
else:
target_group_bytes = _compute_target_group_bytes(total_input_bytes, config.max_workers)
file_groups = list(_bundle_files_by_size(file_stats, target_group_bytes))
logger.info(
f"Grouped {len(paths):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
f"(target {target_group_bytes / 1e9:.2f} GB/group) in {time.monotonic() - filescan_start:.1f}s."
f"Grouped {len(file_stats):,} files ({total_input_bytes / 1e9:.2f} GB) into {len(file_groups):,} groups "
f"(target {target_group_bytes / 1e9:.2f} GB/group)."
)
return file_groups

Expand Down Expand Up @@ -451,17 +447,17 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s
)

# TODO (rav): both train and val could run at the same time
if train_paths and not split_already_done("train"):
train_groups = local_preprocess_paths(train_paths)
if train_file_stats and not split_already_done("train"):
train_groups = local_preprocess_paths(train_file_stats)
ctx = ZephyrContext(
resources=config.worker_resources,
max_workers=min(config.max_workers, len(train_groups)),
name="tokenize-train",
)
run_pipeline(ctx, train_groups, "train")

if validation_paths and not split_already_done("validation"):
validation_groups = local_preprocess_paths(validation_paths)
if validation_file_stats and not split_already_done("validation"):
validation_groups = local_preprocess_paths(validation_file_stats)
ctx = ZephyrContext(
resources=config.worker_resources,
max_workers=min(config.max_workers, len(validation_groups)),
Expand Down
65 changes: 46 additions & 19 deletions lib/zephyr/src/zephyr/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,51 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class GlobSource:
    """Lazy file source resolved at plan time via a bulk list-objects call.

    Stores the glob pattern and defers expansion to compute_plan(), where
    fsspec glob(detail=True) returns paths and sizes in a single RPC.
    """

    # Glob pattern, possibly with a protocol prefix (gs://, s3://, hf://, ...)
    # and brace sets ({a,b}) that resolve_glob() expands client-side.
    pattern: str
    # When True, a pattern matching zero files yields an empty dataset
    # instead of raising FileNotFoundError at plan time.
    empty_glob_ok: bool = False


@dataclass(frozen=True)
class FileEntry:
    """A file path with its size from the bulk listing."""

    # Fully-qualified path, including the protocol prefix when one was present
    # in the originating pattern.
    path: str
    # Size in bytes as reported by the filesystem listing; 0 when unknown.
    size: int


def resolve_glob(source: GlobSource) -> list[FileEntry]:
    """Expand a GlobSource into FileEntry objects with sizes.

    Uses fsspec glob(detail=True) which returns file metadata from the same
    list-objects API call — no extra per-file stat RPCs.

    Args:
        source: Lazy glob source holding the pattern and the empty-glob policy.

    Returns:
        FileEntry objects sorted by path, for deterministic shard planning.

    Raises:
        FileNotFoundError: If nothing matches and ``source.empty_glob_ok``
            is False.
    """
    # Collapse accidental double slashes while preserving the protocol
    # separator: the (?<!:) lookbehind exempts the "//" right after ":".
    pattern = re.sub(r"(?<!:)//+", "/", source.pattern)

    fs, _ = url_to_fs(pattern)
    protocol = fsspec.core.split_protocol(pattern)[0]

    entries: list[FileEntry] = []
    # fsspec's glob does not understand brace sets ({a,b}), so expand them
    # client-side and issue one listing per expansion.
    for expanded in braceexpand(pattern):
        detail = fs.glob(expanded, detail=True)
        for path, info in detail.items():
            # A magic-free pattern can match a directory entry; directories
            # are not loadable files, so skip them.
            if info.get("type") == "directory":
                continue
            full = f"{protocol}://{path}" if protocol else path
            # Some filesystems report size as None (unknown length); coerce
            # to 0 so downstream size-based bundling stays numeric.
            entries.append(FileEntry(path=full, size=info.get("size") or 0))
    entries.sort(key=lambda e: e.path)

    if not entries and not source.empty_glob_ok:
        raise FileNotFoundError(f"No files found matching pattern: {source.pattern}")

    return entries


@dataclass(frozen=True)
class ShardInfo:
"""Metadata about the current shard passed to map_shard functions.
Expand Down Expand Up @@ -362,25 +407,7 @@ def from_files(
... )
>>> output_files = ctx.execute(ds).results
"""
# Normalize double slashes while preserving protocol (e.g., gs://, s3://, http://)
pattern = re.sub(r"(?<!:)//+", "/", pattern)

fs, _ = url_to_fs(pattern)
protocol = fsspec.core.split_protocol(pattern)[0]

files = []
for expanded in braceexpand(pattern):
for f in fs.glob(expanded):
if protocol:
files.append(f"{protocol}://{f}")
else:
files.append(f)
files = sorted(files)

if len(files) == 0 and not empty_glob_ok:
raise FileNotFoundError(f"No files found matching pattern: {pattern}")

return Dataset.from_list(files)
return Dataset(GlobSource(pattern, empty_glob_ok))

def map(self, fn: Callable[[T], R]) -> Dataset[R]:
"""Map a function over the dataset.
Expand Down
35 changes: 27 additions & 8 deletions lib/zephyr/src/zephyr/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@

from zephyr.dataset import (
Dataset,
FileEntry,
FilterOp,
FlatMapOp,
GlobSource,
GroupByOp,
JoinOp,
LoadFileOp,
Expand All @@ -43,6 +45,7 @@
TakePerShardOp,
WindowOp,
WriteOp,
resolve_glob,
)
from zephyr.expr import Expr
from zephyr.readers import InputFileSpec
Expand Down Expand Up @@ -470,14 +473,14 @@ def _fuse_operations(operations: list) -> list[PhysicalStage]:


def _compute_file_pushdown(
paths: list[str],
files: list[FileEntry],
load_op: LoadFileOp,
operations: list,
) -> tuple[list[SourceItem], list]:
"""Create source items for file pipeline with pushdown optimizations applied.

Args:
paths: List of file paths to load
files: List of FileEntry objects (path + size from bulk listing)
load_op: The LoadFileOp specifying format and default columns
operations: Full operations list (first op is LoadFileOp)

Expand Down Expand Up @@ -510,13 +513,14 @@ def _compute_file_pushdown(
SourceItem(
shard_idx=i,
data=InputFileSpec(
path=path,
path=entry.path,
format=load_op.format,
size=entry.size,
columns=select_columns,
filter_expr=filter_expr,
),
)
for i, path in enumerate(paths)
for i, entry in enumerate(files)
]

# Build final operations list: LoadFileOp + remaining ops
Expand All @@ -528,15 +532,30 @@ def _compute_file_pushdown(
def compute_plan(dataset: Dataset) -> PhysicalPlan:
"""Compute physical execution plan from logical dataset."""
operations = list(dataset.operations)

if operations and isinstance(operations[0], LoadFileOp):
source = dataset.source

# Resolve lazy glob sources into concrete FileEntry objects (with sizes).
if isinstance(source, GlobSource):
file_entries = resolve_glob(source)
if operations and isinstance(operations[0], LoadFileOp):
source_items, operations = _compute_file_pushdown(
file_entries,
operations[0],
operations[1:],
)
else:
# from_files() without load_file() — source items are plain paths
source_items = [SourceItem(shard_idx=i, data=entry.path) for i, entry in enumerate(file_entries)]
elif operations and isinstance(operations[0], LoadFileOp):
# Non-glob source (e.g. from_list of paths) — wrap as FileEntry without sizes
entries = [FileEntry(path=p, size=0) for p in source]
source_items, operations = _compute_file_pushdown(
list(dataset.source),
entries,
operations[0],
operations[1:],
)
else:
source_list = list(dataset.source)
source_list = list(source)
source_items = [SourceItem(shard_idx=i, data=item) for i, item in enumerate(source_list)]

stages = _fuse_operations(operations)
Expand Down
1 change: 1 addition & 0 deletions lib/zephyr/src/zephyr/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class InputFileSpec:

path: str
format: Literal["parquet", "jsonl", "vortex", "auto"] = "auto"
size: int | None = None
columns: list[str] | None = None
row_start: int | None = None
row_end: int | None = None
Expand Down
Loading
Loading