diff --git a/lib/levanter/src/levanter/data/text/_batch_tokenizer.py b/lib/levanter/src/levanter/data/text/_batch_tokenizer.py
index 0363c484c3..d4918f5810 100644
--- a/lib/levanter/src/levanter/data/text/_batch_tokenizer.py
+++ b/lib/levanter/src/levanter/data/text/_batch_tokenizer.py
@@ -1,7 +1,6 @@
 # Copyright The Levanter Authors
 # SPDX-License-Identifier: Apache-2.0
 
-from itertools import chain
 from typing import Sequence, Any
 
 import regex
@@ -13,6 +12,12 @@ LONG_STRING_WORKAROUND = 10_000
 
 ws = regex.compile(r"\s")
 
+# When the long-string workaround triggers, encode each over-long text in
+# sub-batches of this many pieces. This caps in-flight memory at one
+# sub-batch of input strings plus their tokenized output, instead of
+# holding all pieces from all records simultaneously.
+_LONG_STRING_BATCH_SIZE = 256
+
 
 class BatchTokenizer(BatchProcessor[dict, dict]):
     """
@@ -46,30 +51,34 @@ def __init__(
         self._workaround_len = _workaround_len
 
     def __call__(self, batch: Sequence[dict]) -> list[dict]:
-        batch_text = [example[self.text_field] for example in batch]
-
-        if self._append_eos:
-            eos = self.tokenizer.eos_token
-            assert eos is not None
-            batch_text = [d + " " + eos for d in batch_text]
-
-        if self._long_string_workaround:
-            batch_text, needs_merge = self._break_for_long_sequences(batch_text)
-        else:
-            needs_merge = []
-
-        encoded = self.tokenizer.encode_batch(batch_text, add_special_tokens=False)
-
+        bos_id = self.tokenizer.bos_token_id if self._append_bos else None
+        eos_str = self.tokenizer.eos_token if self._append_eos else None
         if self._append_bos:
-            bos_id = self.tokenizer.bos_token_id
             assert bos_id is not None
-            if needs_merge:
-                # Prepend BOS only to first chunks so the merged doc has a single BOS.
-                encoded = [[bos_id, *enc] if not merge else enc for enc, merge in zip(encoded, needs_merge)]
+        if self._append_eos:
+            assert eos_str is not None
+
+        # Encode per-record so an outlier's pieces never coexist with the
+        # rest of the batch's encodings in memory. Short records take the
+        # one-shot path; long records stream through ``_encode_long_string``,
+        # which sub-batches the split pieces and accumulates ids in place.
+        encoded: list[list[int]] = []
+        for example in batch:
+            text = example[self.text_field]
+            if eos_str is not None:
+                text = text + " " + eos_str
+
+            if self._long_string_workaround and len(text) > self._workaround_len:
+                ids = self._encode_long_string(text)
             else:
-                encoded = [[bos_id, *enc] for enc in encoded]
+                ids = self.tokenizer.encode(text, add_special_tokens=False)
+
+            if bos_id is not None:
+                # In-place prepend: O(n) shift, but no extra full-list allocation,
+                # unlike ``[bos_id, *ids]``, which doubles peak memory for huge ids.
+                ids.insert(0, bos_id)
+            encoded.append(ids)
 
-        # Build a dict-of-lists structure analogous to the old BatchEncoding.
         encoding: dict[str, list] = {"input_ids": encoded}
         if self.return_attention_mask:
@@ -80,29 +89,41 @@ def __call__(self, batch: Sequence[dict]) -> list[dict]:
                 encoding, self.max_length, self.padding, pad_token_id=self.tokenizer.pad_token_id or 0
             )
 
-        if needs_merge:
-            encoding = self._merge_split_encodings(batch_text, encoding, needs_merge)
-
         unbatched = [dict(zip(encoding, t)) for t in zip(*[encoding[k] for k in encoding])]
         return unbatched
 
-    def _break_for_long_sequences(self, batch: Sequence[str]):
-        orig_lengths = [len(d) for d in batch]
-        orig_batch = batch
-        batch_out: list[str] = []
-        needs_merge: list[bool] = []
-        for i, d in enumerate(orig_batch):
-            needs_merge.append(False)
-            orig_len = orig_lengths[i]
-            while len(d) > self._workaround_len:
-                match = ws.search(d, self._workaround_len)
-                split = match.start() if match is not None else len(d)
-                batch_out.append(d[:split])
-                needs_merge.append(True)
-                d = d[split:]
-                orig_len -= split
-            batch_out.append(d)
-        return batch_out, needs_merge
+    def _encode_long_string(self, text: str) -> list[int]:
+        """Encode one over-long text by splitting at safe whitespace boundaries
+        and concatenating ids in-place.
+
+        Splits are buffered in groups of ``_LONG_STRING_BATCH_SIZE`` pieces;
+        each group is passed through ``encode_batch`` and the resulting ids
+        are extended into the running ``ids`` list before the next group is
+        produced. Peak in-flight memory is one sub-batch's input strings +
+        tokens, regardless of how long the original text is.
+        """
+        ids: list[int] = []
+        pieces: list[str] = []
+        remaining = text
+        while True:
+            if len(remaining) > self._workaround_len:
+                match = ws.search(remaining, self._workaround_len)
+                split = match.start() if match is not None else len(remaining)
+                pieces.append(remaining[:split])
+                remaining = remaining[split:]
+            else:
+                pieces.append(remaining)
+                remaining = ""
+
+            if len(pieces) >= _LONG_STRING_BATCH_SIZE or not remaining:
+                for encoded_piece in self.tokenizer.encode_batch(pieces, add_special_tokens=False):
+                    ids.extend(encoded_piece)
+                pieces.clear()
+
+            if not remaining:
+                break
+
+        return ids
 
     @property
     def metadata(self) -> dict[str, Any]:
@@ -146,25 +167,6 @@ def num_gpus(self) -> int:
             return self.override_resources.get("num_gpus", 0)
         return 0
 
-    @staticmethod
-    def _merge_split_encodings(batch, encoding, needs_merge):
-        new_encoding = {}
-        for k, v in encoding.items():
-            if len(v) == 0:
-                continue
-            v_out = []
-            vs_to_merge: list[list[int]] = []
-            for i in range(len(batch)):
-                if not needs_merge[i]:
-                    if len(vs_to_merge) > 0:
-                        v_out.append(list(chain(*vs_to_merge)))
-                        vs_to_merge = []
-                    vs_to_merge.append(v[i])
-            if len(vs_to_merge) > 0:
-                v_out.append(list(chain(*vs_to_merge)))
-            new_encoding[k] = v_out
-        return new_encoding
-
 
 def _apply_padding_and_truncation(
     encoding: dict[str, list[list[int]]], max_length: int, padding, pad_token_id: int = 0
diff --git a/lib/marin/src/marin/processing/tokenize/tokenize.py b/lib/marin/src/marin/processing/tokenize/tokenize.py
index 7680c4a3f9..57bdf87120 100644
--- a/lib/marin/src/marin/processing/tokenize/tokenize.py
+++ b/lib/marin/src/marin/processing/tokenize/tokenize.py
@@ -305,6 +305,13 @@ def _tokenize_batches(*, config: TokenizeConfig | HfTokenizeConfig, batches: Ite
     # load_tokenizer is @lru_cache, so this only loads once per worker process.
     tokenizer: MarinTokenizer = load_tokenizer(name, backend=backend)
     batch_processor = preprocessor_for_format(config.format, tokenizer)
+    # Levanter's BatchTokenizer leaves ``long_string_workaround`` opt-in, but we
+    # want it unconditionally: per-record texts above ``_workaround_len`` (10K
+    # chars) are split at safe whitespace boundaries and encoded in sub-batches,
+    # with the ids concatenated back into a single document. It is a no-op for
+    # short records. Without it, a single multi-MB outlier passes one giant
+    # string to the Rust tokenizer and OOMs the worker.
+    batch_processor._long_string_workaround = True
     batch_count = 0
     record_count = 0
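
Appendix (not part of the patch): a minimal, self-contained sketch of the sub-batched splitting loop that ``_encode_long_string`` implements, runnable with a stub in place of the real Rust-backed tokenizer. ``StubTokenizer``, ``encode_long_string``, and the scaled-down sizes are illustrative stand-ins, not Levanter or Marin API; stdlib ``re`` stands in for the third-party ``regex`` module the patch uses.

import re

ws = re.compile(r"\s")
WORKAROUND_LEN = 10_000  # mirrors _workaround_len
SUB_BATCH_SIZE = 256     # mirrors _LONG_STRING_BATCH_SIZE


class StubTokenizer:
    """Hypothetical stand-in: pretends each whitespace-separated word is one id."""

    def encode_batch(self, texts, add_special_tokens=False):
        return [[len(word) for word in text.split()] for text in texts]


def encode_long_string(tokenizer, text: str) -> list[int]:
    ids: list[int] = []
    pieces: list[str] = []
    remaining = text
    while True:
        if len(remaining) > WORKAROUND_LEN:
            # Split at the first whitespace at or after the cap, so no word
            # straddles a piece boundary.
            match = ws.search(remaining, WORKAROUND_LEN)
            split = match.start() if match is not None else len(remaining)
            pieces.append(remaining[:split])
            remaining = remaining[split:]
        else:
            pieces.append(remaining)
            remaining = ""

        if len(pieces) >= SUB_BATCH_SIZE or not remaining:
            # Flush one sub-batch: encode the buffered pieces, append their ids
            # to the running output, and drop the piece strings before the next
            # group is produced, bounding in-flight strings to one sub-batch.
            for piece_ids in tokenizer.encode_batch(pieces, add_special_tokens=False):
                ids.extend(piece_ids)
            pieces.clear()

        if not remaining:
            break
    return ids


text = "lorem ipsum " * 20_000  # ~240 KB outlier record, scaled down for the demo
ids = encode_long_string(StubTokenizer(), text)
assert len(ids) == 40_000  # one id per word under the stub; no word lost at splits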