Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion experiments/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ def default_tokenize(
*,
sample_count: int | VersionedValue[int] | None = None,
is_validation: bool = False,
levanter_batch_size: int | None = None,
resources: ResourceConfig | None = None,
worker_resources: ResourceConfig | None = None,
) -> ExecutorStep:
"""
Tokenizes a dataset using the specified tokenizer and Levanter's tokenization infrastructure.
Expand All @@ -228,6 +231,11 @@ def default_tokenize(
An ExecutorStep that represents the tokenized dataset.
"""

# Common kwargs for config constructors
extra_kwargs: dict = {}
if worker_resources is not None:
extra_kwargs["worker_resources"] = worker_resources

# sniff out if it's a HuggingFace dataset
if isinstance(dataset, HfDatasetSpec):
config = HfTokenizeConfig(
Expand All @@ -237,6 +245,8 @@ def default_tokenize(
tokenizer=ensure_versioned(tokenizer),
format=format,
sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
levanter_batch_size=levanter_batch_size,
**extra_kwargs,
)
elif (
isinstance(dataset, str)
Expand All @@ -250,6 +260,8 @@ def default_tokenize(
tokenizer=ensure_versioned(tokenizer),
format=format,
sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
levanter_batch_size=levanter_batch_size,
**extra_kwargs,
)
else:
config = TokenizeConfig(
Expand All @@ -259,14 +271,16 @@ def default_tokenize(
tokenizer=ensure_versioned(tokenizer),
format=format,
sample_count=ensure_versioned(sample_count) if sample_count is not None else None,
levanter_batch_size=levanter_batch_size,
**extra_kwargs,
)

return ExecutorStep(
name=os.path.join("tokenized", name),
description=f"Tokenize raw text using the {tokenizer} tokenizer.",
fn=remote(
tokenize,
resources=ResourceConfig.with_cpu(cpu=4, ram="16g", disk="10g"),
resources=resources or ResourceConfig.with_cpu(cpu=4, ram="16g", disk="10g"),
pip_dependency_groups=["cpu"],
env_vars={
"TRANSFORMERS_NO_TORCH": "1",
Expand Down
47 changes: 47 additions & 0 deletions experiments/pretraining_datasets/starcoder2_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""StarCoder2 data extras: download and tokenize ir_cpp, ir_python, ir_rust, ir_low_resource, documentation."""

from experiments.defaults import default_tokenize
from experiments.marin_models import marin_tokenizer
from fray.v2 import ResourceConfig
from levanter.data.text.formats import TextLmDatasetFormat
from marin.datakit.download.starcoder2_extras import (
SUBSETS,
download_starcoder2_extras_step,
)
from marin.datakit.normalize import normalize_step
from marin.execution.executor import executor_main
from marin.processing.tokenize.data_configs import TokenizerStep


def tokenize_starcoder2_extras(*, tokenizer: str = marin_tokenizer) -> list[TokenizerStep]:
    """Build tokenization steps for every selected starcoder2data-extras subset.

    For each subset in ``SUBSETS``, chains together a download step, a parquet
    normalization step (reading the ``content`` field), and a tokenization
    step, and returns the resulting tokenize steps.

    Args:
        tokenizer: Tokenizer identifier passed to ``default_tokenize``.
            Defaults to ``marin_tokenizer``.

    Returns:
        One tokenize step per subset, in ``SUBSETS`` order.
    """

    def _tokenize_subset(subset: str) -> TokenizerStep:
        # One pipeline per subset: download -> normalize to parquet -> tokenize.
        raw = download_starcoder2_extras_step(subset)
        parquet = normalize_step(
            name=f"normalized/starcoder2_extras/{subset}",
            download=raw,
            text_field="content",
            file_extensions=(".parquet",),
        )
        # documentation contains a single 64MB OpenJDK record that peaks at ~9GB RSS
        # during tokenization; bump memory to 32GB for that subset
        if subset == "documentation":
            extra_memory = ResourceConfig(ram="32g", disk="10g")
        else:
            extra_memory = None
        return default_tokenize(
            name=f"starcoder2_extras/{subset}",
            dataset=parquet.as_executor_step(),
            tokenizer=tokenizer,
            format=TextLmDatasetFormat(text_key="text"),
            levanter_batch_size=128,
            worker_resources=extra_memory,
        )

    return [_tokenize_subset(subset) for subset in SUBSETS]


# Allow running this module directly to build and execute all subset pipelines.
if __name__ == "__main__":
    executor_main(steps=tokenize_starcoder2_extras())
31 changes: 31 additions & 0 deletions lib/marin/src/marin/datakit/download/starcoder2_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Download subsets of the bigcode/starcoder2data-extras dataset from HuggingFace.

Subsets: ir_cpp, ir_python, ir_rust, ir_low_resource, documentation, kaggle.
"""

from marin.datakit.download.huggingface import download_hf_step
from marin.execution.step_spec import StepSpec

# HuggingFace dataset repository hosting the starcoder2data-extras subsets.
HF_DATASET_ID = "bigcode/starcoder2data-extras"
# Pinned dataset revision so downloads are reproducible across runs.
HF_REVISION = "1ba0d4f"

# Subsets selected for download (matches the list in the module docstring).
SUBSETS = ["ir_cpp", "ir_python", "ir_rust", "ir_low_resource", "documentation", "kaggle"]


def download_starcoder2_extras_step(subset: str) -> StepSpec:
    """Create a step that downloads one starcoder2data-extras subset.

    Args:
        subset: Name of the subset directory in the HF repo (e.g. ``"ir_cpp"``).

    Returns:
        A StepSpec that fetches the subset's parquet files from
        ``HF_DATASET_ID`` at the pinned ``HF_REVISION``.
    """
    step_name = f"raw/starcoder2_extras/{subset}"
    parquet_glob = f"{subset}/*.parquet"
    output_path = f"raw/starcoder2_extras-{HF_REVISION}/{subset}"
    return download_hf_step(
        step_name,
        hf_dataset_id=HF_DATASET_ID,
        revision=HF_REVISION,
        hf_urls_glob=[parquet_glob],
        override_output_path=output_path,
    )


def download_all_starcoder2_extras_steps() -> list[StepSpec]:
    """Create download steps for every subset listed in ``SUBSETS``."""
    return list(map(download_starcoder2_extras_step, SUBSETS))
27 changes: 22 additions & 5 deletions lib/marin/src/marin/datakit/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,21 @@ def normalize_record(record: dict[str, Any]) -> dict[str, Any]:

def _discover_file_groups(
input_path: str,
file_extensions: tuple[str, ...] | None = None,
) -> dict[str, list[str]]:
"""Walk *input_path* and group data files by their subdirectory.

Returns a mapping from relative subdirectory (``""`` for root) to a sorted
list of file paths. Only files with extensions supported by
``zephyr.readers.load_file`` are included; dotfiles and ``.metrics``
directories are skipped.
list of file paths. Only files with matching extensions are included;
dotfiles and ``.metrics`` directories are skipped.

Args:
input_path: Root directory to walk.
file_extensions: Tuple of file extensions to include (e.g.
``(".parquet",)``). Defaults to all extensions supported by
``zephyr.readers.load_file``.
"""
extensions = file_extensions or SUPPORTED_EXTENSIONS
fs, resolved = url_to_fs(input_path)
protocol = input_path.split("://")[0] if "://" in input_path else ""

Expand All @@ -113,7 +120,7 @@ def _full_path(p: str) -> str:
for fname in sorted(files):
if fname.startswith("."):
continue
if not fname.endswith(SUPPORTED_EXTENSIONS):
if not fname.endswith(extensions):
continue
full = _full_path(os.path.join(root, fname))
groups.setdefault(rel_root, []).append(full)
Expand Down Expand Up @@ -178,6 +185,7 @@ def normalize_to_parquet(
id_field: str = "id",
target_partition_bytes: int = 256 * 1024 * 1024,
worker_resources: ResourceConfig | None = None,
file_extensions: tuple[str, ...] | None = None,
) -> None:
"""Normalize raw downloaded data to the datakit standard Parquet format.

Expand All @@ -200,10 +208,13 @@ def normalize_to_parquet(
Defaults to 2 CPU / 16GB RAM / 10GB disk, sized for
``target_partition_bytes`` of 256MB. Scale up when increasing
partition size.
file_extensions: Tuple of file extensions to include (e.g.
``(".parquet",)``). Defaults to all extensions supported by
``zephyr.readers.load_file``.
"""
resources = worker_resources or ResourceConfig(cpu=2, ram="16g", disk="10g")

file_groups = _discover_file_groups(input_path)
file_groups = _discover_file_groups(input_path, file_extensions=file_extensions)
if not file_groups:
raise FileNotFoundError(f"No data files found under {input_path}")

Expand Down Expand Up @@ -249,6 +260,7 @@ def normalize_step(
worker_resources: ResourceConfig | None = None,
override_output_path: str | None = None,
input_path: str | None = None,
file_extensions: tuple[str, ...] | None = None,
) -> StepSpec:
"""Create a StepSpec that normalizes downloaded data to Parquet.

Expand All @@ -263,6 +275,9 @@ def normalize_step(
override_output_path: Override the computed output path.
input_path: Override the input path. Defaults to ``download.output_path``.
Useful when normalizing a subdirectory of the download output.
file_extensions: Tuple of file extensions to include (e.g.
``(".parquet",)``). Defaults to all extensions supported by
``zephyr.readers.load_file``.
"""
resolved_input = input_path or download.output_path

Expand All @@ -275,13 +290,15 @@ def normalize_step(
id_field=id_field,
target_partition_bytes=target_partition_bytes,
worker_resources=worker_resources,
file_extensions=file_extensions,
),
deps=[download],
hash_attrs={
"text_field": text_field,
"id_field": id_field,
"target_partition_bytes": target_partition_bytes,
"input_path": resolved_input,
"file_extensions": file_extensions,
},
override_output_path=override_output_path,
)
5 changes: 5 additions & 0 deletions lib/marin/src/marin/processing/tokenize/tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class TokenizeConfigBase(abc.ABC):
this many shards instead of deriving the count from max_workers. This can be useful if you want
more shards than max_workers, for example to mitigate the cost of retrying a single shard."""

levanter_batch_size: int | None = None
"""Number of tokenized records to accumulate before flushing to disk. Defaults to 16384.
Lower values reduce peak memory for datasets with large documents."""

@abc.abstractmethod
def as_lm_dataset_source_config(
self, actual_output_path: str | InputName | None, *, include_raw_paths=True
Expand Down Expand Up @@ -398,6 +402,7 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s
f"{prefix}/part-{{shard:05d}}-of-{{total:05d}}",
metadata={},
skip_existing=True,
batch_size=config.levanter_batch_size,
)
)

Expand Down
7 changes: 7 additions & 0 deletions lib/zephyr/src/zephyr/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class WriteOp:

# Format-specific parameters (only used by relevant writer)
levanter_metadata: dict[str, Any] | None = None
levanter_batch_size: int | None = None
schema: object | None = None # For parquet (pyarrow.Schema)
skip_existing: bool = False # Skip writing if output file already exists

Expand Down Expand Up @@ -727,13 +728,18 @@ def write_levanter_cache(
output_pattern: str | Callable[[int, int], str],
metadata: dict[str, Any],
skip_existing: bool = False,
batch_size: int | None = None,
) -> Dataset[str]:
"""Write tokenized records to Levanter cache format.

Writes records to Levanter's TreeStore/JaggedArrayStore format for use
in training. Each shard creates a separate cache directory.
The output pattern supports substitutions: {shard:05d}, {total:05d}, {basename}
or can be a callable that takes (shard_idx, total_shards) and returns the output path.

Args:
batch_size: Number of records to accumulate before flushing to disk.
Defaults to 16384. Lower values reduce peak memory for large documents.
"""
return Dataset(
self.source,
Expand All @@ -743,6 +749,7 @@ def write_levanter_cache(
_normalize_output_pattern(output_pattern),
writer_type="levanter_cache",
levanter_metadata=metadata,
levanter_batch_size=batch_size,
skip_existing=skip_existing,
),
],
Expand Down
7 changes: 6 additions & 1 deletion lib/zephyr/src/zephyr/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class Write:
skip_existing: bool = False
# Writer-specific parameters
levanter_metadata: dict | None = None
levanter_batch_size: int | None = None
schema: Any = None # For parquet


Expand Down Expand Up @@ -421,6 +422,7 @@ def _fuse_operations(operations: list) -> list[PhysicalStage]:
writer_type=op.writer_type,
skip_existing=op.skip_existing,
levanter_metadata=op.levanter_metadata,
levanter_batch_size=op.levanter_batch_size,
schema=op.schema,
)
)
Expand Down Expand Up @@ -782,7 +784,10 @@ def run_stage(
result = write_parquet_file(stream, output_path, schema=op.schema)["path"]
elif op.writer_type == "levanter_cache":
metadata = op.levanter_metadata if op.levanter_metadata is not None else {}
result = write_levanter_cache(stream, output_path, metadata=metadata)["path"]
kwargs: dict[str, Any] = {"metadata": metadata}
if op.levanter_batch_size is not None:
kwargs["batch_size"] = op.levanter_batch_size
result = write_levanter_cache(stream, output_path, **kwargs)["path"]
elif op.writer_type == "binary":
result = write_binary_file(stream, output_path)["path"]
elif op.writer_type == "vortex":
Expand Down
9 changes: 7 additions & 2 deletions lib/zephyr/src/zephyr/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,19 @@ def write_levanter_cache(
output_path: str,
*,
metadata: dict[str, Any],
batch_size: int = _LEVANTER_BATCH_SIZE,
) -> dict:
"""Write tokenized records to Levanter cache format.

Args:
records: Tokenized records (iterable of dicts with array values)
output_path: Path to output cache directory
metadata: Metadata for the cache
batch_size: Number of records to accumulate before flushing to disk.
"""
if batch_size < 1:
raise ValueError(f"batch_size must be >= 1, got {batch_size}")

from levanter.store.cache import CacheMetadata, SerialCacheWriter

ensure_parent_dir(output_path)
Expand All @@ -392,7 +397,7 @@ def write_levanter_cache(
return {"path": output_path, "count": 0}

count = 0
logger.info("write_levanter_cache: starting write to %s (batch_size=%d)", output_path, _LEVANTER_BATCH_SIZE)
logger.info("write_levanter_cache: starting write to %s (batch_size=%d)", output_path, batch_size)

with atomic_rename(output_path) as tmp_path:
with SerialCacheWriter(tmp_path, exemplar, shard_name=output_path, metadata=CacheMetadata(metadata)) as writer:
Expand All @@ -405,7 +410,7 @@ def _drain_batches(batches: Iterable) -> None:
threaded.submit([exemplar])
count += 1
counters.increment("zephyr/records_out")
for batch in batchify(record_iter, n=_LEVANTER_BATCH_SIZE):
for batch in batchify(record_iter, n=batch_size):
Comment thread
ravwojdyla marked this conversation as resolved.
threaded.submit(batch)
count += len(batch)
counters.increment("zephyr/records_out", len(batch))
Expand Down
Loading