diff --git a/experiments/defaults.py b/experiments/defaults.py index db439f5c3e..83483018c6 100644 --- a/experiments/defaults.py +++ b/experiments/defaults.py @@ -206,6 +206,9 @@ def default_tokenize( *, sample_count: int | VersionedValue[int] | None = None, is_validation: bool = False, + levanter_batch_size: int | None = None, + resources: ResourceConfig | None = None, + worker_resources: ResourceConfig | None = None, ) -> ExecutorStep: """ Tokenizes a dataset using the specified tokenizer and Levanter's tokenization infrastructure. @@ -228,6 +231,11 @@ def default_tokenize( An ExecutorStep that represents the tokenized dataset. """ + # Common kwargs for config constructors + extra_kwargs: dict = {} + if worker_resources is not None: + extra_kwargs["worker_resources"] = worker_resources + # sniff out if it's a HuggingFace dataset if isinstance(dataset, HfDatasetSpec): config = HfTokenizeConfig( @@ -237,6 +245,8 @@ def default_tokenize( tokenizer=ensure_versioned(tokenizer), format=format, sample_count=ensure_versioned(sample_count) if sample_count is not None else None, + levanter_batch_size=levanter_batch_size, + **extra_kwargs, ) elif ( isinstance(dataset, str) @@ -250,6 +260,8 @@ def default_tokenize( tokenizer=ensure_versioned(tokenizer), format=format, sample_count=ensure_versioned(sample_count) if sample_count is not None else None, + levanter_batch_size=levanter_batch_size, + **extra_kwargs, ) else: config = TokenizeConfig( @@ -259,6 +271,8 @@ def default_tokenize( tokenizer=ensure_versioned(tokenizer), format=format, sample_count=ensure_versioned(sample_count) if sample_count is not None else None, + levanter_batch_size=levanter_batch_size, + **extra_kwargs, ) return ExecutorStep( @@ -266,7 +280,7 @@ def default_tokenize( description=f"Tokenize raw text using the {tokenizer} tokenizer.", fn=remote( tokenize, - resources=ResourceConfig.with_cpu(cpu=4, ram="16g", disk="10g"), + resources=resources or ResourceConfig.with_cpu(cpu=4, ram="16g", disk="10g"), pip_dependency_groups=["cpu"], env_vars={ "TRANSFORMERS_NO_TORCH": "1", diff --git a/experiments/pretraining_datasets/starcoder2_extras.py b/experiments/pretraining_datasets/starcoder2_extras.py new file mode 100644 index 0000000000..70a82c8a1e --- /dev/null +++ b/experiments/pretraining_datasets/starcoder2_extras.py @@ -0,0 +1,47 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""StarCoder2 data extras: download and tokenize ir_cpp, ir_python, ir_rust, ir_low_resource, documentation.""" + +from experiments.defaults import default_tokenize +from experiments.marin_models import marin_tokenizer +from fray.v2 import ResourceConfig +from levanter.data.text.formats import TextLmDatasetFormat +from marin.datakit.download.starcoder2_extras import ( + SUBSETS, + download_starcoder2_extras_step, +) +from marin.datakit.normalize import normalize_step +from marin.execution.executor import executor_main +from marin.processing.tokenize.data_configs import TokenizerStep + + +def tokenize_starcoder2_extras(*, tokenizer: str = marin_tokenizer) -> list[TokenizerStep]: + """Download, normalize, and tokenize all selected starcoder2data-extras subsets.""" + steps = [] + for subset in SUBSETS: + download = download_starcoder2_extras_step(subset) + normalized = normalize_step( + name=f"normalized/starcoder2_extras/{subset}", + download=download, + text_field="content", + file_extensions=(".parquet",), + ) + # documentation contains a single 64MB OpenJDK record that peaks at ~9GB RSS + # during tokenization; bump memory to 32GB for that subset + doc_resources = ResourceConfig(ram="32g", disk="10g") if subset == "documentation" else None + steps.append( + default_tokenize( + name=f"starcoder2_extras/{subset}", + dataset=normalized.as_executor_step(), + tokenizer=tokenizer, + format=TextLmDatasetFormat(text_key="text"), + levanter_batch_size=128, + worker_resources=doc_resources, + ) + ) + return steps + + +if __name__ == "__main__": + executor_main(steps=tokenize_starcoder2_extras()) diff --git a/lib/marin/src/marin/datakit/download/starcoder2_extras.py b/lib/marin/src/marin/datakit/download/starcoder2_extras.py new file mode 100644 index 0000000000..d86db3c65d --- /dev/null +++ b/lib/marin/src/marin/datakit/download/starcoder2_extras.py @@ -0,0 +1,31 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Download subsets of the bigcode/starcoder2data-extras dataset from HuggingFace. + +Subsets: ir_cpp, ir_python, ir_rust, ir_low_resource, documentation, kaggle. +""" + +from marin.datakit.download.huggingface import download_hf_step +from marin.execution.step_spec import StepSpec + +HF_DATASET_ID = "bigcode/starcoder2data-extras" +HF_REVISION = "1ba0d4f" + +SUBSETS = ["ir_cpp", "ir_python", "ir_rust", "ir_low_resource", "documentation", "kaggle"] + + +def download_starcoder2_extras_step(subset: str) -> StepSpec: + """Download a single subset of the starcoder2data-extras dataset.""" + return download_hf_step( + f"raw/starcoder2_extras/{subset}", + hf_dataset_id=HF_DATASET_ID, + revision=HF_REVISION, + hf_urls_glob=[f"{subset}/*.parquet"], + override_output_path=f"raw/starcoder2_extras-{HF_REVISION}/{subset}", + ) + + +def download_all_starcoder2_extras_steps() -> list[StepSpec]: + """Download all selected subsets of starcoder2data-extras.""" + return [download_starcoder2_extras_step(subset) for subset in SUBSETS] diff --git a/lib/marin/src/marin/datakit/normalize.py b/lib/marin/src/marin/datakit/normalize.py index 2da1eb7e80..ea28431e3f 100644 --- a/lib/marin/src/marin/datakit/normalize.py +++ b/lib/marin/src/marin/datakit/normalize.py @@ -85,14 +85,21 @@ def normalize_record(record: dict[str, Any]) -> dict[str, Any]: def _discover_file_groups( input_path: str, + file_extensions: tuple[str, ...] | None = None, ) -> dict[str, list[str]]: """Walk *input_path* and group data files by their subdirectory. Returns a mapping from relative subdirectory (``""`` for root) to a sorted - list of file paths. Only files with extensions supported by - ``zephyr.readers.load_file`` are included; dotfiles and ``.metrics`` - directories are skipped. + list of file paths. Only files with matching extensions are included; + dotfiles and ``.metrics`` directories are skipped. + + Args: + input_path: Root directory to walk. + file_extensions: Tuple of file extensions to include (e.g. + ``(".parquet",)``). Defaults to all extensions supported by + ``zephyr.readers.load_file``. """ + extensions = file_extensions or SUPPORTED_EXTENSIONS fs, resolved = url_to_fs(input_path) protocol = input_path.split("://")[0] if "://" in input_path else "" @@ -113,7 +120,7 @@ def _full_path(p: str) -> str: for fname in sorted(files): if fname.startswith("."): continue - if not fname.endswith(SUPPORTED_EXTENSIONS): + if not fname.endswith(extensions): continue full = _full_path(os.path.join(root, fname)) groups.setdefault(rel_root, []).append(full) @@ -178,6 +185,7 @@ def normalize_to_parquet( id_field: str = "id", target_partition_bytes: int = 256 * 1024 * 1024, worker_resources: ResourceConfig | None = None, + file_extensions: tuple[str, ...] | None = None, ) -> None: """Normalize raw downloaded data to the datakit standard Parquet format. @@ -200,10 +208,13 @@ def normalize_to_parquet( Defaults to 2 CPU / 16GB RAM / 10GB disk, sized for ``target_partition_bytes`` of 256MB. Scale up when increasing partition size. + file_extensions: Tuple of file extensions to include (e.g. + ``(".parquet",)``). Defaults to all extensions supported by + ``zephyr.readers.load_file``. """ resources = worker_resources or ResourceConfig(cpu=2, ram="16g", disk="10g") - file_groups = _discover_file_groups(input_path) + file_groups = _discover_file_groups(input_path, file_extensions=file_extensions) if not file_groups: raise FileNotFoundError(f"No data files found under {input_path}") @@ -249,6 +260,7 @@ def normalize_step( worker_resources: ResourceConfig | None = None, override_output_path: str | None = None, input_path: str | None = None, + file_extensions: tuple[str, ...] | None = None, ) -> StepSpec: """Create a StepSpec that normalizes downloaded data to Parquet. @@ -263,6 +275,9 @@ def normalize_step( override_output_path: Override the computed output path. input_path: Override the input path. Defaults to ``download.output_path``. Useful when normalizing a subdirectory of the download output. + file_extensions: Tuple of file extensions to include (e.g. + ``(".parquet",)``). Defaults to all extensions supported by + ``zephyr.readers.load_file``. """ resolved_input = input_path or download.output_path @@ -275,6 +290,7 @@ def normalize_step( id_field=id_field, target_partition_bytes=target_partition_bytes, worker_resources=worker_resources, + file_extensions=file_extensions, ), deps=[download], hash_attrs={ @@ -282,6 +298,7 @@ def normalize_step( "id_field": id_field, "target_partition_bytes": target_partition_bytes, "input_path": resolved_input, + "file_extensions": file_extensions, }, override_output_path=override_output_path, ) diff --git a/lib/marin/src/marin/processing/tokenize/tokenize.py b/lib/marin/src/marin/processing/tokenize/tokenize.py index 2e44fc14c5..01d6699917 100644 --- a/lib/marin/src/marin/processing/tokenize/tokenize.py +++ b/lib/marin/src/marin/processing/tokenize/tokenize.py @@ -76,6 +76,10 @@ class TokenizeConfigBase(abc.ABC): this many shards instead of deriving the count from max_workers. This can be useful if you want more shards than max_workers, for example to mitigate the cost of retrying a single shard.""" + levanter_batch_size: int | None = None + """Number of tokenized records to accumulate before flushing to disk. Defaults to 16384. + Lower values reduce peak memory for datasets with large documents.""" + @abc.abstractmethod def as_lm_dataset_source_config( self, actual_output_path: str | InputName | None, *, include_raw_paths=True @@ -398,6 +402,7 @@ def run_pipeline(ctx: ZephyrContext, file_groups: list[list[str]], split_name: s f"{prefix}/part-{{shard:05d}}-of-{{total:05d}}", metadata={}, skip_existing=True, + batch_size=config.levanter_batch_size, ) ) diff --git a/lib/zephyr/src/zephyr/dataset.py b/lib/zephyr/src/zephyr/dataset.py index a936a3f64f..c8c1569788 100644 --- a/lib/zephyr/src/zephyr/dataset.py +++ b/lib/zephyr/src/zephyr/dataset.py @@ -161,6 +161,7 @@ class WriteOp: # Format-specific parameters (only used by relevant writer) levanter_metadata: dict[str, Any] | None = None + levanter_batch_size: int | None = None schema: object | None = None # For parquet (pyarrow.Schema) skip_existing: bool = False # Skip writing if output file already exists @@ -727,6 +728,7 @@ def write_levanter_cache( output_pattern: str | Callable[[int, int], str], metadata: dict[str, Any], skip_existing: bool = False, + batch_size: int | None = None, ) -> Dataset[str]: """Write tokenized records to Levanter cache format. @@ -734,6 +736,10 @@ def write_levanter_cache( in training. Each shard creates a separate cache directory. The output pattern supports substitutions: {shard:05d}, {total:05d}, {basename} or can be a callable that takes (shard_idx, total_shards) and returns the output path. + + Args: + batch_size: Number of records to accumulate before flushing to disk. + Defaults to 16384. Lower values reduce peak memory for large documents. """ return Dataset( self.source, @@ -743,6 +749,7 @@ def write_levanter_cache( _normalize_output_pattern(output_pattern), writer_type="levanter_cache", levanter_metadata=metadata, + levanter_batch_size=batch_size, skip_existing=skip_existing, ), ], diff --git a/lib/zephyr/src/zephyr/plan.py b/lib/zephyr/src/zephyr/plan.py index c6b2842926..24d02907ff 100644 --- a/lib/zephyr/src/zephyr/plan.py +++ b/lib/zephyr/src/zephyr/plan.py @@ -106,6 +106,7 @@ class Write: skip_existing: bool = False # Writer-specific parameters levanter_metadata: dict | None = None + levanter_batch_size: int | None = None schema: Any = None # For parquet @@ -421,6 +422,7 @@ def _fuse_operations(operations: list) -> list[PhysicalStage]: writer_type=op.writer_type, skip_existing=op.skip_existing, levanter_metadata=op.levanter_metadata, + levanter_batch_size=op.levanter_batch_size, schema=op.schema, ) ) @@ -782,7 +784,10 @@ def run_stage( result = write_parquet_file(stream, output_path, schema=op.schema)["path"] elif op.writer_type == "levanter_cache": metadata = op.levanter_metadata if op.levanter_metadata is not None else {} - result = write_levanter_cache(stream, output_path, metadata=metadata)["path"] + kwargs: dict[str, Any] = {"metadata": metadata} + if op.levanter_batch_size is not None: + kwargs["batch_size"] = op.levanter_batch_size + result = write_levanter_cache(stream, output_path, **kwargs)["path"] elif op.writer_type == "binary": result = write_binary_file(stream, output_path)["path"] elif op.writer_type == "vortex": diff --git a/lib/zephyr/src/zephyr/writers.py b/lib/zephyr/src/zephyr/writers.py index b0272e2429..971e6343d5 100644 --- a/lib/zephyr/src/zephyr/writers.py +++ b/lib/zephyr/src/zephyr/writers.py @@ -373,6 +373,7 @@ def write_levanter_cache( output_path: str, *, metadata: dict[str, Any], + batch_size: int = _LEVANTER_BATCH_SIZE, ) -> dict: """Write tokenized records to Levanter cache format. @@ -380,7 +381,11 @@ def write_levanter_cache( records: Tokenized records (iterable of dicts with array values) output_path: Path to output cache directory metadata: Metadata for the cache + batch_size: Number of records to accumulate before flushing to disk. """ + if batch_size < 1: + raise ValueError(f"batch_size must be >= 1, got {batch_size}") + from levanter.store.cache import CacheMetadata, SerialCacheWriter ensure_parent_dir(output_path) @@ -392,7 +397,7 @@ def write_levanter_cache( return {"path": output_path, "count": 0} count = 0 - logger.info("write_levanter_cache: starting write to %s (batch_size=%d)", output_path, _LEVANTER_BATCH_SIZE) + logger.info("write_levanter_cache: starting write to %s (batch_size=%d)", output_path, batch_size) with atomic_rename(output_path) as tmp_path: with SerialCacheWriter(tmp_path, exemplar, shard_name=output_path, metadata=CacheMetadata(metadata)) as writer: @@ -405,7 +410,7 @@ def _drain_batches(batches: Iterable) -> None: threaded.submit([exemplar]) count += 1 counters.increment("zephyr/records_out") - for batch in batchify(record_iter, n=_LEVANTER_BATCH_SIZE): + for batch in batchify(record_iter, n=batch_size): threaded.submit(batch) count += len(batch) counters.increment("zephyr/records_out", len(batch))