122 changes: 62 additions & 60 deletions lib/levanter/src/levanter/data/text/_batch_tokenizer.py
@@ -1,7 +1,6 @@
 # Copyright The Levanter Authors
 # SPDX-License-Identifier: Apache-2.0
 
-from itertools import chain
 from typing import Sequence, Any
 
 import regex
@@ -13,6 +12,12 @@
 LONG_STRING_WORKAROUND = 10_000
 ws = regex.compile(r"\s")
 
+# When the long-string workaround triggers, encode each over-long text in
+# sub-batches of this many pieces. Caps in-flight memory at one sub-batch
+# of input strings + their tokenized output, instead of holding all pieces
+# from all records simultaneously.
+_LONG_STRING_BATCH_SIZE = 256
+
 
 class BatchTokenizer(BatchProcessor[dict, dict]):
     """
@@ -46,30 +51,34 @@ def __init__(
         self._workaround_len = _workaround_len
 
     def __call__(self, batch: Sequence[dict]) -> list[dict]:
-        batch_text = [example[self.text_field] for example in batch]
-
-        if self._append_eos:
-            eos = self.tokenizer.eos_token
-            assert eos is not None
-            batch_text = [d + " " + eos for d in batch_text]
-
-        if self._long_string_workaround:
-            batch_text, needs_merge = self._break_for_long_sequences(batch_text)
-        else:
-            needs_merge = []
-
-        encoded = self.tokenizer.encode_batch(batch_text, add_special_tokens=False)
-
+        bos_id = self.tokenizer.bos_token_id if self._append_bos else None
+        eos_str = self.tokenizer.eos_token if self._append_eos else None
         if self._append_bos:
-            bos_id = self.tokenizer.bos_token_id
             assert bos_id is not None
-            if needs_merge:
-                # Prepend BOS only to first chunks so the merged doc has a single BOS.
-                encoded = [[bos_id, *enc] if not merge else enc for enc, merge in zip(encoded, needs_merge)]
+        if self._append_eos:
+            assert eos_str is not None
+
+        # Encode per-record so an outlier's pieces never coexist with the
+        # rest of the batch's encodings in memory. Short records take the
+        # one-shot path; long records stream through ``_encode_long_string``,
+        # which sub-batches splits and accumulates ids in-place.
+        encoded: list[list[int]] = []
+        for example in batch:
+            text = example[self.text_field]
+            if eos_str is not None:
+                text = text + " " + eos_str
+
+            if self._long_string_workaround and len(text) > self._workaround_len:
+                ids = self._encode_long_string(text)
             else:
-                encoded = [[bos_id, *enc] for enc in encoded]
+                ids = self.tokenizer.encode(text, add_special_tokens=False)
+
+            if bos_id is not None:
+                # In-place prepend: O(n) shift but no extra full-list allocation,
+                # unlike ``[bos_id, *ids]`` which doubles peak for huge ids.
+                ids.insert(0, bos_id)
+            encoded.append(ids)
 
         # Build a dict-of-lists structure analogous to the old BatchEncoding.
         encoding: dict[str, list] = {"input_ids": encoded}
 
         if self.return_attention_mask:
@@ -80,29 +89,41 @@ def __call__(self, batch: Sequence[dict]) -> list[dict]:
             encoding, self.max_length, self.padding, pad_token_id=self.tokenizer.pad_token_id or 0
         )
 
-        if needs_merge:
-            encoding = self._merge_split_encodings(batch_text, encoding, needs_merge)
-
         unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])]
         return unbatched
 
-    def _break_for_long_sequences(self, batch: Sequence[str]):
-        orig_lengths = [len(d) for d in batch]
-        orig_batch = batch
-        batch_out: list[str] = []
-        needs_merge: list[bool] = []
-        for i, d in enumerate(orig_batch):
-            needs_merge.append(False)
-            orig_len = orig_lengths[i]
-            while len(d) > self._workaround_len:
-                match = ws.search(d, self._workaround_len)
-                split = match.start() if match is not None else len(d)
-                batch_out.append(d[:split])
-                needs_merge.append(True)
-                d = d[split:]
-                orig_len -= split
-            batch_out.append(d)
-        return batch_out, needs_merge
+    def _encode_long_string(self, text: str) -> list[int]:
+        """Encode one over-long text by splitting at safe whitespace boundaries
+        and concatenating ids in-place.
+
+        Splits are buffered in groups of ``_LONG_STRING_BATCH_SIZE`` pieces;
+        each group is passed through ``encode_batch`` and the resulting ids
+        are extended into the running ``ids`` list before the next group is
+        produced. Peak in-flight memory is one sub-batch's input strings +
+        tokens, regardless of how long the original text is.
+        """
+        ids: list[int] = []
+        pieces: list[str] = []
+        remaining = text
+        while True:
+            if len(remaining) > self._workaround_len:
+                match = ws.search(remaining, self._workaround_len)
+                split = match.start() if match is not None else len(remaining)
+                pieces.append(remaining[:split])
+                remaining = remaining[split:]
+            else:
+                pieces.append(remaining)
+                remaining = ""
+
+            if len(pieces) >= _LONG_STRING_BATCH_SIZE or not remaining:
+                for encoded_piece in self.tokenizer.encode_batch(pieces, add_special_tokens=False):
+                    ids.extend(encoded_piece)
+                pieces.clear()
+
+            if not remaining:
+                break
+
+        return ids
 
     @property
     def metadata(self) -> dict[str, Any]:
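
To make the sub-batching contract concrete, here is a self-contained sketch of the same pattern with a toy whitespace tokenizer standing in for the Rust tokenizers backend. The ToyTokenizer class, the thresholds, and the free function name are illustrative assumptions, not the real module; the invariant it demonstrates is that streamed sub-batch encoding yields exactly the ids of a one-shot encode:

import regex

ws = regex.compile(r"\s")
WORKAROUND_LEN = 10  # 10_000 in the real module
SUB_BATCH = 4        # _LONG_STRING_BATCH_SIZE is 256 in the real module


class ToyTokenizer:
    """Stand-in for the tokenizers backend: one id per whitespace-separated word."""

    def encode_batch(self, texts, add_special_tokens=False):
        return [[hash(word) % 1000 for word in text.split()] for text in texts]


def encode_long_string(tokenizer, text):
    ids: list[int] = []
    pieces: list[str] = []
    remaining = text
    while True:
        if len(remaining) > WORKAROUND_LEN:
            match = ws.search(remaining, WORKAROUND_LEN)
            split = match.start() if match is not None else len(remaining)
            pieces.append(remaining[:split])
            remaining = remaining[split:]
        else:
            pieces.append(remaining)
            remaining = ""
        # Flush a full sub-batch (or the final partial one), then drop the strings,
        # so at most SUB_BATCH pieces are ever held alongside their token ids.
        if len(pieces) >= SUB_BATCH or not remaining:
            for piece_ids in tokenizer.encode_batch(pieces, add_special_tokens=False):
                ids.extend(piece_ids)
            pieces.clear()
        if not remaining:
            break
    return ids


tok = ToyTokenizer()
long_text = "lorem ipsum dolor sit amet " * 50
assert encode_long_string(tok, long_text) == tok.encode_batch([long_text])[0]

Because splits only ever happen at a whitespace character, the toy word boundaries survive the chunking, which is exactly why the real workaround is safe for BPE-style tokenizers that never merge across whitespace.
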
@@ -146,25 +167,6 @@ def num_gpus(self) -> int:
             return self.override_resources.get("num_gpus", 0)
         return 0
 
-    @staticmethod
-    def _merge_split_encodings(batch, encoding, needs_merge):
-        new_encoding = {}
-        for k, v in encoding.items():
-            if len(v) == 0:
-                continue
-            v_out = []
-            vs_to_merge: list[list[int]] = []
-            for i in range(len(batch)):
-                if not needs_merge[i]:
-                    if len(vs_to_merge) > 0:
-                        v_out.append(list(chain(*vs_to_merge)))
-                        vs_to_merge = []
-                vs_to_merge.append(v[i])
-            if len(vs_to_merge) > 0:
-                v_out.append(list(chain(*vs_to_merge)))
-            new_encoding[k] = v_out
-        return new_encoding
-
 
 def _apply_padding_and_truncation(
     encoding: dict[str, list[list[int]]], max_length: int, padding, pad_token_id: int = 0
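
One note on the unbatched one-liner kept in __call__ above: it transposes a dict-of-lists into a list-of-dicts, one per record. A quick illustration with made-up values:

encoding = {
    "input_ids": [[1, 2], [3]],
    "attention_mask": [[1, 1], [1]],
}
# zip(*columns) walks the records in parallel; dict(zip(keys, row)) rebuilds one record.
unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])]
assert unbatched == [
    {"input_ids": [1, 2], "attention_mask": [1, 1]},
    {"input_ids": [3], "attention_mask": [1]},
]
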
7 changes: 7 additions & 0 deletions lib/marin/src/marin/processing/tokenize/tokenize.py
@@ -305,6 +305,13 @@ def _tokenize_batches(*, config: TokenizeConfig | HfTokenizeConfig, batches: Ite
     # load_tokenizer is @lru_cache, so this only loads once per worker process.
     tokenizer: MarinTokenizer = load_tokenizer(name, backend=backend)
     batch_processor = preprocessor_for_format(config.format, tokenizer)
+    # Levanter's BatchTokenizer ships ``long_string_workaround`` as opt-in, but the
+    # behavior is always desirable: per-record texts above ``_workaround_len``
+    # (10K chars) are split at safe whitespace boundaries before the underlying
+    # ``encode_batch`` is called, then merged back. No-op for short records.
+    # Without this, a single multi-MB outlier passes one giant string to the
+    # Rust tokenizer and OOMs the worker.
+    batch_processor._long_string_workaround = True
 
     batch_count = 0
     record_count = 0
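
Since this reaches into a private attribute, a defensive variant (a sketch, not part of the PR) would fail loudly if a future Levanter release renames the flag, rather than silently losing the OOM protection:

# Hypothetical hardening of the override above; not in the PR.
if not hasattr(batch_processor, "_long_string_workaround"):
    raise AttributeError(
        "BatchTokenizer no longer exposes _long_string_workaround; "
        "update this override to match the current Levanter API"
    )
batch_processor._long_string_workaround = True
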