Skip to content

Commit 5bae69b

Browse files
[marin] datakit/normalize: compact pathological whitespace runs (#4603)
## Summary - Add a `max_whitespace_run_chars` option (default 128) to the datakit normalization step. Consecutive whitespace runs exceeding the limit are truncated to that length — **preserving the surrounding content** rather than dropping the entire document. - Handles broken HTML→text extraction artifacts (e.g. multi-MB space runs, cf. #4588) that can OOM downstream tokenization, while keeping the actual useful text. - Affected records are counted via a new Zephyr counter, `datakit_normalize_compacted_whitespace`. Document `id` is recomputed after compaction. - Follow-up to #4600 (which caps homogeneous runs inside the tokenizer). Pass `max_whitespace_run_chars=None` to disable. ## Test plan - [x] `tests/datakit/test_normalize.py` — new cases: verifies compaction preserves content and recomputes id; `None` disables compaction. Existing 10 tests still pass. - [x] `./infra/pre-commit.py --fix` clean. Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
1 parent ba15a6d commit 5bae69b

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

lib/marin/src/marin/datakit/normalize.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818
import os
19+
import re
1920
from collections.abc import Callable, Iterator
2021
from concurrent.futures import ThreadPoolExecutor, as_completed
2122
from dataclasses import dataclass, field
@@ -32,6 +33,17 @@
3233

3334
logger = logging.getLogger(__name__)
3435

36+
# Default cap on the longest consecutive whitespace run in a document.
37+
# Runs exceeding this are compacted to this length at normalization time.
38+
# Pathologically long whitespace runs (e.g. multi-MB runs from broken
39+
# HTML→text extraction, cf. #4588) can OOM downstream tokenization.
40+
# 128 matches the longest whitespace run that Llama's tokenizer collapses
41+
# into a single token, so capping here is lossless for that tokenizer.
42+
DEFAULT_MAX_WHITESPACE_RUN_CHARS = 128
43+
44+
# Counter name for documents that had whitespace runs compacted.
45+
COMPACTED_WHITESPACE_COUNTER = "datakit_normalize_compacted_whitespace"
46+
3547

3648
class DedupMode(StrEnum):
3749
"""How aggressively to deduplicate records during normalization.
@@ -183,13 +195,35 @@ def _compute_total_bytes(file_paths: list[str]) -> int:
183195
return total
184196

185197

198+
def _make_whitespace_compactor(max_whitespace_run_chars: int) -> Callable[[dict[str, Any]], dict[str, Any]]:
199+
"""Return a map function that compacts consecutive whitespace runs exceeding the limit.
200+
201+
Any run of whitespace longer than *max_whitespace_run_chars* is truncated to
202+
that length (preserving the original whitespace characters). Affected records
203+
are counted via the ``COMPACTED_WHITESPACE_COUNTER`` Zephyr counter, and the
204+
``id`` is recomputed to reflect the new text.
205+
"""
206+
pattern = re.compile(r"\s{" + str(max_whitespace_run_chars + 1) + r",}")
207+
208+
def compact(record: dict[str, Any]) -> dict[str, Any]:
209+
text = record["text"]
210+
compacted = pattern.sub(lambda m: m.group(0)[:max_whitespace_run_chars], text)
211+
if len(compacted) != len(text):
212+
counters.increment(COMPACTED_WHITESPACE_COUNTER)
213+
record = {**record, "text": compacted, "id": generate_id(compacted)}
214+
return record
215+
216+
return compact
217+
218+
186219
def _build_pipeline(
187220
files: list[str],
188221
output_dir: str,
189222
num_shards: int,
190223
text_field: str,
191224
id_field: str | None,
192225
dedup_mode: DedupMode,
226+
max_whitespace_run_chars: int,
193227
) -> Dataset:
194228
"""Build a single Zephyr pipeline for one subdirectory."""
195229
normalize_record = _make_normalize_fn(text_field, id_field)
@@ -221,6 +255,7 @@ def has_text(record: dict[str, Any]) -> bool:
221255
.flat_map(load_file)
222256
.filter(has_text)
223257
.map(normalize_record)
258+
.map(_make_whitespace_compactor(max_whitespace_run_chars))
224259
.group_by(
225260
key=lambda r: int(r["id"], 16) % num_shards,
226261
reducer=reducers[dedup_mode],
@@ -241,6 +276,7 @@ def normalize_to_parquet(
241276
text_field: str = "text",
242277
id_field: str = "id",
243278
target_partition_bytes: int = 256 * 1024 * 1024,
279+
max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS,
244280
worker_resources: ResourceConfig | None = None,
245281
file_extensions: tuple[str, ...] | None = None,
246282
dedup_mode: DedupMode = DedupMode.EXACT,
@@ -262,6 +298,12 @@ def normalize_to_parquet(
262298
silently skipped.
263299
target_partition_bytes: Target size in bytes per output partition.
264300
Used to compute the number of output shards per subdirectory.
301+
max_whitespace_run_chars: Compact any consecutive whitespace run
302+
longer than this many characters down to this length.
303+
Pathologically long whitespace runs (e.g. multi-MB runs from
304+
broken HTML→text extraction, cf. #4588) can OOM downstream
305+
tokenization. Affected records are counted via the
306+
``datakit_normalize_compacted_whitespace`` Zephyr counter.
265307
worker_resources: Per-worker resource request for the Zephyr pipeline.
266308
Defaults to 2 CPU / 16GB RAM / 10GB disk, sized for
267309
``target_partition_bytes`` of 256MB. Scale up when increasing
@@ -300,7 +342,15 @@ def _run_subdir(subdir: str, files: list[str]) -> NormalizeSubdirResult:
300342
num_shards,
301343
)
302344

303-
pipeline = _build_pipeline(files, output_dir, num_shards, text_field, id_field, dedup_mode)
345+
pipeline = _build_pipeline(
346+
files,
347+
output_dir,
348+
num_shards,
349+
text_field,
350+
id_field,
351+
dedup_mode,
352+
max_whitespace_run_chars,
353+
)
304354
ctx = ZephyrContext(
305355
name=f"normalize-{subdir.replace('/', '-') if subdir else 'all'}",
306356
resources=resources,
@@ -343,6 +393,7 @@ def normalize_step(
343393
text_field: str = "text",
344394
id_field: str = "id",
345395
target_partition_bytes: int = 256 * 1024 * 1024,
396+
max_whitespace_run_chars: int = DEFAULT_MAX_WHITESPACE_RUN_CHARS,
346397
worker_resources: ResourceConfig | None = None,
347398
override_output_path: str | None = None,
348399
input_path: str | None = None,
@@ -378,6 +429,7 @@ def normalize_step(
378429
text_field=text_field,
379430
id_field=id_field,
380431
target_partition_bytes=target_partition_bytes,
432+
max_whitespace_run_chars=max_whitespace_run_chars,
381433
worker_resources=worker_resources,
382434
file_extensions=file_extensions,
383435
dedup_mode=dedup_mode,
@@ -387,6 +439,7 @@ def normalize_step(
387439
"text_field": text_field,
388440
"id_field": id_field,
389441
"target_partition_bytes": target_partition_bytes,
442+
"max_whitespace_run_chars": max_whitespace_run_chars,
390443
"input_path": resolved_input,
391444
"file_extensions": file_extensions,
392445
"dedup_mode": dedup_mode,

tests/datakit/test_normalize.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,35 @@ def test_skip_existing_idempotent(tmp_path: Path, write_jsonl_gz):
203203
assert parquet_files[0].stat().st_mtime == mtime_first
204204

205205

206+
def test_whitespace_compaction(tmp_path: Path, write_jsonl_gz):
207+
"""Long whitespace runs are compacted, not dropped. Content is preserved."""
208+
input_dir = tmp_path / "input"
209+
output_dir = tmp_path / "output"
210+
211+
records = [
212+
{"id": "normal", "text": "Hello world"},
213+
{"id": "pathological", "text": "before" + " " * 500 + "after"},
214+
{"id": "also_normal", "text": "short spaces are fine"},
215+
]
216+
write_jsonl_gz(input_dir / "data.jsonl.gz", records)
217+
218+
normalize_to_parquet(
219+
input_path=str(input_dir),
220+
output_path=str(output_dir),
221+
max_whitespace_run_chars=100,
222+
)
223+
224+
results = _read_all_parquet(output_dir)
225+
# All three records survive — the pathological one is compacted, not dropped
226+
assert len(results) == 3
227+
by_source = {r["source_id"]: r for r in results}
228+
assert by_source["pathological"]["text"] == "before" + " " * 100 + "after"
229+
# id is recomputed from the compacted text
230+
assert by_source["pathological"]["id"] == generate_id("before" + " " * 100 + "after")
231+
# Normal docs are untouched
232+
assert by_source["normal"]["text"] == "Hello world"
233+
234+
206235
def test_no_input_files_raises(tmp_path: Path):
207236
input_dir = tmp_path / "input"
208237
input_dir.mkdir()

0 commit comments

Comments (0)