Skip to content

Commit 97127d9

Browse files
ravwojdyla-agentravwojdylaclaude
authored
tokenize: split into datakit attribute and store stages (#5438)
* part of #2355 * split `processing/tokenize` into datakit-style stage A + stage B with a shared core * `attributes.py` — `tokenize_attributes(NormalizedData) → TokenizedData` (parquet of `{id, input_ids}` per doc, co-partitioned with the source via basename mirroring) * `store_builder.py` — `build_from_datasets` modular core + `build_levanter_store(BuildLevanterStoreConfig) → LevanterStoreData` reading one or more `TokenizedData` sources * `_core.py` — extracted helpers (`attach_id`, `IdPreservingPreprocessor`[^1], `tokenize_pipeline`) * legacy `tokenize(TokenizeConfig)` external signature unchanged; internally composes the same core + `build_from_datasets` in one zephyr context (no parquet round-trip) * `StepSpec` factories `tokenize_attributes_step` and `build_levanter_store_step` mirror the `compute_minhash_attrs_step` pattern * ids reuse from `NormalizedData` when present; fallback computes via `marin.datakit.normalize.generate_id` (xxh3_128, 32 hex) [^1]: levanter's `BatchProcessor` interface explicitly allows non-1:1 input→output; today's processors are 1:1 but the wrapper asserts it loudly so a future packing/SFT processor fails fast instead of silently misaligning ids --------- Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 15711a1 commit 97127d9

5 files changed

Lines changed: 1459 additions & 294 deletions

File tree

Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Shared internals for tokenization.
5+
6+
Used by:
7+
* :func:`marin.processing.tokenize.tokenize.tokenize` — legacy raw → Levanter store path.
8+
* :func:`marin.processing.tokenize.attributes.tokenize_attributes` — datakit Stage A
9+
(NormalizedData → attribute parquet).
10+
* :func:`marin.processing.tokenize.store_builder.build_from_datasets` — datakit Stage B
11+
(tokenized records → Levanter store).
12+
13+
Public API lives in those modules; helpers here are package-private.
14+
"""
15+
from __future__ import annotations
16+
17+
import json
18+
import logging
19+
import os
20+
import re
21+
import time
22+
from collections.abc import Iterator, Mapping, Sequence
23+
24+
import braceexpand
25+
import fsspec
26+
import pyarrow.parquet as pq
27+
from levanter.data._preprocessor import BatchProcessor
28+
from levanter.data.text import LmDatasetFormatBase, preprocessor_for_format
29+
from levanter.tokenizers import MarinTokenizer, load_tokenizer
30+
from rigging.filesystem import url_to_fs
31+
from zephyr import Dataset, zephyr_worker_ctx
32+
from zephyr.dataset import FileEntry
33+
from zephyr.readers import InputFileSpec
34+
35+
from marin.datakit.normalize import generate_id
36+
from marin.utils import fsspec_isdir
37+
38+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lower bound that ``compute_target_group_bytes`` applies to the per-group byte target.
MIN_GROUP_BYTES = 100_000_000  # 100 MB floor to avoid degenerate tiny shards
# Empirical upper bound on the zephyr window size (see
# https://github.com/marin-community/marin/issues/2829#issuecomment-3963661943).
# ``resolve_window_and_batch`` uses it both as the default window and as the cap.
_MAX_WINDOW_SIZE = 64

# Suffix patterns that ``expand_tokenize_paths`` appends (as ``**/*.<suffix>``) to
# directory inputs. The ``{a,b,c}`` groups are brace alternatives; presumably they are
# expanded by ``braceexpand`` inside ``glob_with_sizes`` — confirm against callers.
_TOKENIZE_EXTENSIONS = ["json.{gz,zst,zstd}", "jsonl.{gz,zst,zstd}", "parquet"]

# NOTE(chris): Marin's `default_download` writes a `provenance.json` sidecar next to
# downloaded HF data. Downstream tokenize jobs glob those directories and must
# exclude sidecars so we don't train on provenance records.
# Matched against file basenames by ``drop_sidecars``.
_MARIN_SIDECAR_NAMES = frozenset({"provenance.json"})
51+
52+
53+
def avg_parquet_row_group_rows(path: str) -> int | None:
    """Compute the average number of rows per row group of a parquet file.

    The file is opened through the filesystem that ``url_to_fs`` resolves for
    ``path`` and only its parquet metadata object is consulted.

    Returns the integer quotient ``num_rows // num_row_groups`` clamped to a
    minimum of 1, or ``None`` when the metadata reports zero row groups.
    """
    filesystem, fs_path = url_to_fs(path)
    with filesystem.open(fs_path, "rb") as handle:
        footer = pq.ParquetFile(handle).metadata
        group_count = footer.num_row_groups
        if group_count == 0:
            return None
        mean_rows = footer.num_rows // group_count
        # Never report 0 rows per group, even when groups outnumber rows.
        return mean_rows if mean_rows >= 1 else 1
64+
65+
66+
def compute_target_group_bytes(total_input_bytes: int, max_workers: int) -> int:
    """Compute target group size to produce approximately ``max_workers`` groups.

    The input bytes are divided evenly (integer division) across the workers and
    the result is floored at ``MIN_GROUP_BYTES`` to avoid degenerate tiny shards.

    A ``max_workers`` below 1 is treated as a single worker, so a zero worker
    count yields one group covering the whole input instead of raising
    ``ZeroDivisionError``.
    """
    # Clamp the divisor: 0 would raise, and a negative count has no meaning here.
    effective_workers = max(max_workers, 1)
    return max(total_input_bytes // effective_workers, MIN_GROUP_BYTES)
72+
73+
74+
def drop_sidecars(files: list[FileEntry]) -> list[FileEntry]:
    """Return ``files`` without Marin sidecar entries, preserving input order.

    An entry is dropped when the basename of its ``path`` is a member of
    ``_MARIN_SIDECAR_NAMES``; everything else is passed through untouched.
    """
    kept: list[FileEntry] = []
    for entry in files:
        if os.path.basename(entry.path) in _MARIN_SIDECAR_NAMES:
            continue
        kept.append(entry)
    return kept
76+
77+
78+
def glob_with_sizes(patterns: list[str]) -> list[FileEntry]:
    """Glob patterns and return FileEntry objects (spec + size).

    Uses fsspec ``glob(detail=True)`` which returns file metadata from the same
    list-objects API call — no per-file stat RPCs needed. Works for gs://, hf://, s3://, local.

    Each pattern is first normalized by collapsing runs of duplicate slashes in
    the path part, then brace-expanded; every expansion is globbed on the
    filesystem resolved for the pattern. Returned paths carry the pattern's
    protocol prefix when it has one. Entries whose metadata has no usable
    ``size`` are recorded with size 0.
    """
    results: list[FileEntry] = []
    for pattern in patterns:
        # Collapse accidental "//" inside the path but leave the scheme separator
        # alone. Rejecting runs preceded by ':' *or* '/' keeps both "gs://bucket"
        # and the three-slash form "file:///abs/path" intact; guarding only on ':'
        # would rewrite the latter to "file://abs/path".
        pattern = re.sub(r"(?<![:/])//+", "/", pattern)
        fs, _ = url_to_fs(pattern)
        protocol = fsspec.core.split_protocol(pattern)[0]
        for expanded in braceexpand.braceexpand(pattern):
            for path, info in fs.glob(expanded, detail=True).items():
                full = f"{protocol}://{path}" if protocol else path
                # Treat a missing *or* null size as 0 so that downstream size
                # arithmetic (e.g. bundle_files_by_size) never sees None.
                size = info.get("size") or 0
                results.append(FileEntry(spec=InputFileSpec(path=full), size=size))
    return results
95+
96+
97+
def expand_tokenize_paths(input_paths: list[str]) -> list[str]:
    """Turn input paths into glob patterns covering tokenizable files.

    A path counts as a directory when it ends with ``/`` or ``fsspec_isdir``
    says so; it is replaced by one recursive ``**/*.<ext>`` pattern per entry of
    ``_TOKENIZE_EXTENSIONS``. Any other path or pattern is forwarded as-is.
    The bare root ``"/"`` is rejected by an assertion.
    """
    expanded: list[str] = []
    for candidate in input_paths:
        assert candidate != "/"
        # The trailing-slash check short-circuits the filesystem probe.
        is_directory = candidate.endswith("/") or fsspec_isdir(candidate)
        if not is_directory:
            expanded.append(candidate)
            continue
        logger.info(f"Getting all {_TOKENIZE_EXTENSIONS} files in {candidate}")
        expanded.extend(os.path.join(candidate, f"**/*.{ext}") for ext in _TOKENIZE_EXTENSIONS)
    return expanded
113+
114+
115+
def bundle_files_by_size(files: list[FileEntry], max_bytes: int) -> Iterator[list[str]]:
    """Greedily pack file paths into consecutive groups bounded by ``max_bytes``.

    Files are consumed in order. The running group is emitted right before a
    file that would bring its total to ``max_bytes`` or more, so a file is only
    ever added to a non-empty group while the total stays below ``max_bytes``.
    A group is never emitted empty: a file that by itself reaches the limit
    still starts (and forms) a group.
    """
    pending: list[str] = []
    pending_bytes = 0

    for entry in files:
        would_overflow = pending_bytes + entry.size >= max_bytes
        if pending and would_overflow:
            yield pending
            pending, pending_bytes = [], 0
        pending.append(entry.path)
        pending_bytes += entry.size

    # Flush whatever is left after the last file.
    if pending:
        yield pending
130+
131+
132+
def attach_id(record: dict, text_field: str = "text") -> dict:
    """Return a record that is guaranteed to carry a non-null ``id``.

    A record whose ``id`` is already set (not ``None``) is returned as the very
    same object. Otherwise a copy with a generated ``id`` is returned; the id is
    produced by :func:`marin.datakit.normalize.generate_id` from
    ``str(record[text_field])`` when that field is present and non-null, and
    from a key-sorted JSON dump of the whole record (non-serializable values
    stringified) when it is not.
    """
    if record.get("id") is not None:
        return record

    text = record.get(text_field)
    if text is not None:
        seed = str(text)
    else:
        # No usable text: hash a canonical serialization of the record instead.
        seed = json.dumps(record, sort_keys=True, default=str)
    return {**record, "id": generate_id(seed)}
148+
149+
150+
class IdPreservingPreprocessor:
151+
"""Wrap a Levanter ``BatchProcessor`` to thread input ``id`` onto each output.
152+
153+
Levanter's ``BatchProcessor`` interface explicitly allows non-1:1 input→output
154+
(see :class:`levanter.data._preprocessor.BatchProcessor`). All currently used
155+
processors (``BatchTokenizer``, ``ChatProcessor``, ``PrebuiltCacheProcessor``,
156+
``PreferenceChatProcessor``) are 1:1, but a future packing/SFT-splitting
157+
processor would silently misalign ids if we naively zipped. This wrapper
158+
asserts the 1:1 invariant so misalignment fails loudly.
159+
"""
160+
161+
def __init__(self, inner: BatchProcessor):
162+
self.inner = inner
163+
164+
def __call__(self, batch: Sequence[dict]) -> list[dict]:
165+
outputs = self.inner(batch)
166+
# BatchResult is Sequence[U] | Mapping[str, Sequence] (struct-of-arrays)
167+
if isinstance(outputs, Mapping):
168+
keys = list(outputs.keys())
169+
n_out = len(outputs[keys[0]]) if keys else 0
170+
outputs_list: list[dict] = [{k: outputs[k][i] for k in keys} for i in range(n_out)]
171+
else:
172+
outputs_list = list(outputs)
173+
n_out = len(outputs_list)
174+
175+
if n_out != len(batch):
176+
raise RuntimeError(
177+
f"IdPreservingPreprocessor: 1:1 input→output expected, got "
178+
f"{len(batch)} input → {n_out} output records from "
179+
f"{type(self.inner).__name__}. id alignment cannot be preserved; "
180+
"if this processor packs or splits records, route ids via a custom path."
181+
)
182+
183+
return [{**out, "id": rec["id"]} for rec, out in zip(batch, outputs_list, strict=True)]
184+
185+
186+
def tokenize_batches_with_id(
    *,
    data_format: LmDatasetFormatBase,
    batches: Iterator[Sequence[dict]],
) -> Iterator[dict]:
    """Run the format's preprocessor over ``batches`` and yield id-tagged records.

    Every input record must already have an ``id`` (see :func:`attach_id`); the
    id is copied onto the matching output by ``IdPreservingPreprocessor``.

    The tokenizer name and backend are read from zephyr's shared worker context
    under the keys ``tokenizer_name`` and ``tokenizer_backend``, so the caller
    has to publish both before the pipeline executes. Throughput statistics are
    logged after every tenth batch and once more when the input is exhausted.
    """
    worker_ctx = zephyr_worker_ctx()
    tokenizer_name = worker_ctx.get_shared("tokenizer_name")
    tokenizer_backend = worker_ctx.get_shared("tokenizer_backend")
    # load_tokenizer is @lru_cache, so this only loads once per worker process.
    tokenizer: MarinTokenizer = load_tokenizer(tokenizer_name, backend=tokenizer_backend)
    inner = preprocessor_for_format(data_format, tokenizer)
    # Force-enable Levanter's opt-in ``long_string_workaround`` whenever the
    # processor exposes it: very long texts are then split at whitespace before
    # encoding and merged back afterwards, which keeps a single multi-MB outlier
    # from being handed to the Rust tokenizer as one string and OOMing the worker.
    # Short records are unaffected.
    if hasattr(inner, "_long_string_workaround"):
        inner._long_string_workaround = True
    processor = IdPreservingPreprocessor(inner)

    n_batches = 0
    n_docs = 0
    n_tokens = 0
    started = time.monotonic()

    def _stats() -> str:
        # Shared tail of both log messages, computed from the current counters.
        elapsed = time.monotonic() - started
        tok_per_sec = n_tokens / elapsed if elapsed > 0 else 0
        doc_per_sec = n_docs / elapsed if elapsed > 0 else 0
        avg_tok_per_doc = n_tokens / n_docs if n_docs > 0 else 0
        return (
            f"{n_batches:,} batches, {n_docs:,} docs, {n_tokens:,} tokens "
            f"in {elapsed:.1f}s ({tok_per_sec:,.0f} tokens/s, {doc_per_sec:,.1f} docs/s, "
            f"{avg_tok_per_doc:,.0f} avg tokens/doc)"
        )

    for batch in batches:
        n_batches += 1
        for record in processor(batch):
            n_docs += 1
            n_tokens += len(record.get("input_ids", []))
            yield record
        if n_batches % 10 == 0:
            logger.info(f"Tokenized {_stats()}")

    logger.info(f"Tokenization done: {_stats()}")
245+
246+
247+
def parquet_window_hint(file_groups: list[list[str]]) -> str | None:
    """Find the first path ending in ``.parquet`` across ``file_groups``.

    Groups are scanned in order, and paths in order within each group. The
    match is intended as the sample file for window/batch sizing on parquet
    inputs; ``None`` is returned when no path has the ``.parquet`` suffix.
    """
    for group in file_groups:
        for candidate in group:
            if candidate.endswith(".parquet"):
                return candidate
    return None
254+
255+
256+
def resolve_window_and_batch(
    sample_parquet_path: str | None,
    requested_batch_size: int | None,
) -> tuple[int, int | None]:
    """Choose the zephyr window size and the Levanter batch size.

    Without a sample parquet path, or when the sample has no row groups, the
    window is ``_MAX_WINDOW_SIZE`` and the batch size is whatever was requested
    (possibly ``None``).

    With a usable sample, both sizes are derived from half of the file's average
    rows-per-row-group (at least 1): the window is that value capped at
    ``_MAX_WINDOW_SIZE``, and the batch size is that value unless the caller
    requested one explicitly, in which case the request wins.
    """
    rows_per_group = (
        avg_parquet_row_group_rows(sample_parquet_path) if sample_parquet_path is not None else None
    )
    if rows_per_group is None:
        return _MAX_WINDOW_SIZE, requested_batch_size

    half_group = max(rows_per_group // 2, 1)
    window = min(half_group, _MAX_WINDOW_SIZE)
    batch = half_group if requested_batch_size is None else requested_batch_size
    logger.info(
        "Parquet source: avg rows/row-group=%d (from %s) → window=%d, levanter batch_size=%s",
        rows_per_group,
        sample_parquet_path,
        window,
        batch,
    )
    return window, batch
285+
286+
287+
def tokenize_pipeline(
    ds: Dataset,
    *,
    data_format: LmDatasetFormatBase,
    text_field: str = "text",
    sample_count: int | None,
    sample_parquet_path: str | None,
    levanter_batch_size: int | None,
) -> tuple[Dataset, int | None]:
    """Append the tokenization stages to ``ds``.

    Stages, in order: give every record an ``id`` via :func:`attach_id`; if
    ``sample_count`` is set, keep that many records per shard; window the
    records; tokenize each shard's windows with
    :func:`tokenize_batches_with_id`.

    Returns the resulting dataset together with the Levanter cache batch size
    picked by :func:`resolve_window_and_batch` (``None`` means no explicit size
    was chosen).
    """
    window_size, batch_size = resolve_window_and_batch(sample_parquet_path, levanter_batch_size)

    # Values are bound as lambda defaults rather than captured from the enclosing scope.
    with_ids = ds.map(lambda r, tf=text_field: attach_id(r, text_field=tf))
    sampled = with_ids if sample_count is None else with_ids.take_per_shard(sample_count)
    tokenized = sampled.window(window_size).map_shard(
        lambda batches, _, fmt=data_format: tokenize_batches_with_id(data_format=fmt, batches=batches)
    )
    return tokenized, batch_size

0 commit comments

Comments
 (0)