Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 54 additions & 16 deletions lib/marin/src/marin/datakit/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import os
import re
from collections.abc import Callable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
Expand All @@ -24,11 +25,20 @@
from rigging.filesystem import url_to_fs
from marin.execution.step_spec import StepSpec
from fray.v2 import ResourceConfig
from zephyr import Dataset, ZephyrContext
from zephyr import Dataset, ZephyrContext, counters
from zephyr.readers import SUPPORTED_EXTENSIONS, load_file

logger = logging.getLogger(__name__)

# Default cap on the longest consecutive whitespace run in a document.
# Runs exceeding this are compacted to this length at normalization time.
# Pathologically long whitespace runs (e.g. multi-MB runs from broken
# HTML→text extraction, cf. #4588) can OOM downstream tokenization.
DEFAULT_MAX_WHITESPACE_RUN_CHARS = 128

# Counter name for documents that had whitespace runs compacted.
# Incremented once per affected record (not once per run), so it reports
# a document count.
COMPACTED_WHITESPACE_COUNTER = "datakit_normalize_compacted_whitespace"


def generate_id(text: str) -> str:
"""Generate a deterministic document ID from text content.
Expand Down Expand Up @@ -134,12 +144,34 @@ def _compute_total_bytes(file_paths: list[str]) -> int:
return total


def _make_whitespace_compactor(max_whitespace_run_chars: int) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Return a map function that compacts consecutive whitespace runs exceeding the limit.

Any run of whitespace longer than *max_whitespace_run_chars* is truncated to
that length (preserving the original whitespace characters). Affected records
are counted via the ``COMPACTED_WHITESPACE_COUNTER`` Zephyr counter, and the
``id`` is recomputed to reflect the new text.
"""
pattern = re.compile(r"\s{" + str(max_whitespace_run_chars + 1) + r",}")

def compact(record: dict[str, Any]) -> dict[str, Any]:
text = record["text"]
compacted = pattern.sub(lambda m: m.group(0)[:max_whitespace_run_chars], text)
if len(compacted) != len(text):
counters.increment(COMPACTED_WHITESPACE_COUNTER)
record = {**record, "text": compacted, "id": generate_id(compacted)}
return record

return compact


def _build_pipeline(
files: list[str],
output_dir: str,
num_shards: int,
text_field: str,
id_field: str | None,
max_whitespace_run_chars: int,
) -> Dataset:
"""Build a single Zephyr pipeline for one subdirectory."""
normalize_record = _make_normalize_fn(text_field, id_field)
Expand All @@ -153,20 +185,16 @@ def dedup_and_sort(_key: int, items: Iterator[dict[str, Any]]) -> Iterator[dict[
prev_id = rid
yield record

return (
Dataset.from_list(files)
.flat_map(load_file)
.map(normalize_record)
.group_by(
key=lambda r: int(r["id"], 16) % num_shards,
reducer=dedup_and_sort,
sort_by=lambda r: r["id"],
num_output_shards=num_shards,
)
.write_parquet(
f"{output_dir}/part-{{shard:05d}}-of-{{total:05d}}.parquet",
skip_existing=True,
)
pipeline = Dataset.from_list(files).flat_map(load_file).map(normalize_record)
pipeline = pipeline.map(_make_whitespace_compactor(max_whitespace_run_chars))
return pipeline.group_by(
key=lambda r: int(r["id"], 16) % num_shards,
reducer=dedup_and_sort,
sort_by=lambda r: r["id"],
num_output_shards=num_shards,
).write_parquet(
f"{output_dir}/part-{{shard:05d}}-of-{{total:05d}}.parquet",
skip_existing=True,
)


Expand All @@ -177,6 +205,7 @@ def normalize_to_parquet(
text_field: str = "text",
id_field: str = "id",
target_partition_bytes: int = 256 * 1024 * 1024,
max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS,
worker_resources: ResourceConfig | None = None,
) -> None:
"""Normalize raw downloaded data to the datakit standard Parquet format.
Expand All @@ -196,6 +225,12 @@ def normalize_to_parquet(
silently skipped.
target_partition_bytes: Target size in bytes per output partition.
Used to compute the number of output shards per subdirectory.
max_whitespace_run_chars: Compact any consecutive whitespace run
longer than this many characters down to this length.
Pathologically long whitespace runs (e.g. multi-MB runs from
broken HTML→text extraction, cf. #4588) can OOM downstream
tokenization. Affected records are counted via the
``datakit_normalize_compacted_whitespace`` Zephyr counter.
worker_resources: Per-worker resource request for the Zephyr pipeline.
Defaults to 2 CPU / 16GB RAM / 10GB disk, sized for
``target_partition_bytes`` of 256MB. Scale up when increasing
Expand Down Expand Up @@ -223,7 +258,7 @@ def _run_subdir(subdir: str, files: list[str]) -> None:
num_shards,
)

pipeline = _build_pipeline(files, output_dir, num_shards, text_field, id_field)
pipeline = _build_pipeline(files, output_dir, num_shards, text_field, id_field, max_whitespace_run_chars)
ctx = ZephyrContext(
name=f"normalize-{subdir.replace('/', '-') if subdir else 'all'}",
resources=resources,
Expand All @@ -246,6 +281,7 @@ def normalize_step(
text_field: str = "text",
id_field: str = "id",
target_partition_bytes: int = 256 * 1024 * 1024,
max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS,
worker_resources: ResourceConfig | None = None,
override_output_path: str | None = None,
input_path: str | None = None,
Expand Down Expand Up @@ -274,13 +310,15 @@ def normalize_step(
text_field=text_field,
id_field=id_field,
target_partition_bytes=target_partition_bytes,
max_whitespace_run_chars=max_whitespace_run_chars,
worker_resources=worker_resources,
),
deps=[download],
hash_attrs={
"text_field": text_field,
"id_field": id_field,
"target_partition_bytes": target_partition_bytes,
"max_whitespace_run_chars": max_whitespace_run_chars,
"input_path": resolved_input,
},
override_output_path=override_output_path,
Expand Down
29 changes: 29 additions & 0 deletions tests/datakit/test_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,35 @@ def test_skip_existing_idempotent(tmp_path: Path, write_jsonl_gz):
assert parquet_files[0].stat().st_mtime == mtime_first


def test_whitespace_compaction(tmp_path: Path, write_jsonl_gz):
    """Long whitespace runs are compacted, not dropped. Content is preserved."""
    input_dir = tmp_path / "input"
    output_dir = tmp_path / "output"

    records = [
        {"id": "normal", "text": "Hello world"},
        {"id": "pathological", "text": "before" + " " * 500 + "after"},
        {"id": "also_normal", "text": "short spaces are fine"},
    ]
    write_jsonl_gz(input_dir / "data.jsonl.gz", records)

    normalize_to_parquet(
        input_path=str(input_dir),
        output_path=str(output_dir),
        max_whitespace_run_chars=100,
    )

    results = _read_all_parquet(output_dir)
    # All three records survive — the pathological one is compacted, not dropped
    assert len(results) == 3
    by_source = {r["source_id"]: r for r in results}
    # The 500-char run is truncated to exactly the configured limit.
    compacted_text = "before" + " " * 100 + "after"
    assert by_source["pathological"]["text"] == compacted_text
    # id is recomputed from the compacted text
    assert by_source["pathological"]["id"] == generate_id(compacted_text)
    # Documents whose longest run is within the limit are byte-identical
    assert by_source["normal"]["text"] == "Hello world"
    assert by_source["also_normal"]["text"] == "short spaces are fine"


def test_no_input_files_raises(tmp_path: Path):
input_dir = tmp_path / "input"
input_dir.mkdir()
Expand Down
Loading