diff --git a/lib/marin/src/marin/datakit/normalize.py b/lib/marin/src/marin/datakit/normalize.py index 2da1eb7e80..7c726c2138 100644 --- a/lib/marin/src/marin/datakit/normalize.py +++ b/lib/marin/src/marin/datakit/normalize.py @@ -16,6 +16,7 @@ import logging import os +import re from collections.abc import Callable, Iterator from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -24,11 +25,20 @@ from rigging.filesystem import url_to_fs from marin.execution.step_spec import StepSpec from fray.v2 import ResourceConfig -from zephyr import Dataset, ZephyrContext +from zephyr import Dataset, ZephyrContext, counters from zephyr.readers import SUPPORTED_EXTENSIONS, load_file logger = logging.getLogger(__name__) +# Default cap on the longest consecutive whitespace run in a document. +# Runs exceeding this are compacted to this length at normalization time. +# Pathologically long whitespace runs (e.g. multi-MB runs from broken +# HTML→text extraction, cf. #4588) can OOM downstream tokenization. +DEFAULT_MAX_WHITESPACE_RUN_CHARS = 128 + +# Counter name for documents that had whitespace runs compacted. +COMPACTED_WHITESPACE_COUNTER = "datakit_normalize_compacted_whitespace" + def generate_id(text: str) -> str: """Generate a deterministic document ID from text content. @@ -134,12 +144,34 @@ def _compute_total_bytes(file_paths: list[str]) -> int: return total +def _make_whitespace_compactor(max_whitespace_run_chars: int) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Return a map function that compacts consecutive whitespace runs exceeding the limit. + + Any run of whitespace longer than *max_whitespace_run_chars* is truncated to + that length (preserving the original whitespace characters). Affected records + are counted via the ``COMPACTED_WHITESPACE_COUNTER`` Zephyr counter, and the + ``id`` is recomputed to reflect the new text. 
+ """ + pattern = re.compile(r"\s{" + str(max_whitespace_run_chars + 1) + r",}") + + def compact(record: dict[str, Any]) -> dict[str, Any]: + text = record["text"] + compacted = pattern.sub(lambda m: m.group(0)[:max_whitespace_run_chars], text) + if len(compacted) != len(text): + counters.increment(COMPACTED_WHITESPACE_COUNTER) + record = {**record, "text": compacted, "id": generate_id(compacted)} + return record + + return compact + + def _build_pipeline( files: list[str], output_dir: str, num_shards: int, text_field: str, id_field: str | None, + max_whitespace_run_chars: int, ) -> Dataset: """Build a single Zephyr pipeline for one subdirectory.""" normalize_record = _make_normalize_fn(text_field, id_field) @@ -153,20 +185,16 @@ def dedup_and_sort(_key: int, items: Iterator[dict[str, Any]]) -> Iterator[dict[ prev_id = rid yield record - return ( - Dataset.from_list(files) - .flat_map(load_file) - .map(normalize_record) - .group_by( - key=lambda r: int(r["id"], 16) % num_shards, - reducer=dedup_and_sort, - sort_by=lambda r: r["id"], - num_output_shards=num_shards, - ) - .write_parquet( - f"{output_dir}/part-{{shard:05d}}-of-{{total:05d}}.parquet", - skip_existing=True, - ) + pipeline = Dataset.from_list(files).flat_map(load_file).map(normalize_record) + pipeline = pipeline.map(_make_whitespace_compactor(max_whitespace_run_chars)) + return pipeline.group_by( + key=lambda r: int(r["id"], 16) % num_shards, + reducer=dedup_and_sort, + sort_by=lambda r: r["id"], + num_output_shards=num_shards, + ).write_parquet( + f"{output_dir}/part-{{shard:05d}}-of-{{total:05d}}.parquet", + skip_existing=True, ) @@ -177,6 +205,7 @@ def normalize_to_parquet( text_field: str = "text", id_field: str = "id", target_partition_bytes: int = 256 * 1024 * 1024, + max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS, worker_resources: ResourceConfig | None = None, ) -> None: """Normalize raw downloaded data to the datakit standard Parquet format. 
@@ -196,6 +225,12 @@ def normalize_to_parquet( silently skipped. target_partition_bytes: Target size in bytes per output partition. Used to compute the number of output shards per subdirectory. + max_whitespace_run_chars: Compact any consecutive whitespace run + longer than this many characters down to this length. + Pathologically long whitespace runs (e.g. multi-MB runs from + broken HTML→text extraction, cf. #4588) can OOM downstream + tokenization. Affected records are counted via the + ``datakit_normalize_compacted_whitespace`` Zephyr counter. worker_resources: Per-worker resource request for the Zephyr pipeline. Defaults to 2 CPU / 16GB RAM / 10GB disk, sized for ``target_partition_bytes`` of 256MB. Scale up when increasing @@ -223,7 +258,7 @@ def _run_subdir(subdir: str, files: list[str]) -> None: num_shards, ) - pipeline = _build_pipeline(files, output_dir, num_shards, text_field, id_field) + pipeline = _build_pipeline(files, output_dir, num_shards, text_field, id_field, max_whitespace_run_chars) ctx = ZephyrContext( name=f"normalize-{subdir.replace('/', '-') if subdir else 'all'}", resources=resources, @@ -246,6 +281,7 @@ def normalize_step( text_field: str = "text", id_field: str = "id", target_partition_bytes: int = 256 * 1024 * 1024, + max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS, worker_resources: ResourceConfig | None = None, override_output_path: str | None = None, input_path: str | None = None, @@ -274,6 +310,7 @@ def normalize_step( text_field=text_field, id_field=id_field, target_partition_bytes=target_partition_bytes, + max_whitespace_run_chars=max_whitespace_run_chars, worker_resources=worker_resources, ), deps=[download], @@ -281,6 +318,7 @@ def normalize_step( "text_field": text_field, "id_field": id_field, "target_partition_bytes": target_partition_bytes, + "max_whitespace_run_chars": max_whitespace_run_chars, "input_path": resolved_input, }, override_output_path=override_output_path, diff --git 
def test_whitespace_compaction(tmp_path: Path, write_jsonl_gz):
    """Over-long whitespace runs are shortened to the cap; no record is dropped."""
    source_dir = tmp_path / "input"
    dest_dir = tmp_path / "output"

    # One document carries a 500-space run; the cap used below is 100.
    long_run_text = "before" + " " * 500 + "after"
    expected_text = "before" + " " * 100 + "after"

    write_jsonl_gz(
        source_dir / "data.jsonl.gz",
        [
            {"id": "normal", "text": "Hello world"},
            {"id": "pathological", "text": long_run_text},
            {"id": "also_normal", "text": "short spaces are fine"},
        ],
    )

    normalize_to_parquet(
        input_path=str(source_dir),
        output_path=str(dest_dir),
        max_whitespace_run_chars=100,
    )

    rows = _read_all_parquet(dest_dir)
    # Every input record must come out the other side, including the compacted one.
    assert len(rows) == 3

    rows_by_source = {row["source_id"]: row for row in rows}
    compacted_row = rows_by_source["pathological"]
    assert compacted_row["text"] == expected_text
    # The id must be derived from the compacted text, not the raw input.
    assert compacted_row["id"] == generate_id(expected_text)
    # A document without long runs keeps its text verbatim.
    assert rows_by_source["normal"]["text"] == "Hello world"