Skip to content

Commit 2a9a7f1

Browse files
committed
Added estimated token length controls
1 parent 9467fae commit 2a9a7f1

4 files changed

Lines changed: 116 additions & 5 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
`error`): `truncate` shortens by 20% per retry (`truncate_side`), or `chunk_pool` splits into
1010
`num_chunks` segments and mean-pools. Applied reactively after embed failure; size text upstream
1111
with `splitText` / `processDocuments` when possible. Batch failures fall back per item so the
12-
offending chunk can be identified.
12+
offending chunk can be identified. `max_estimated_tokens` can now pre-truncate inputs with a
13+
lightweight estimate before embedding; `truncate_side` controls both estimated pre-truncation and
14+
reactive `on_token_overflow="truncate"` retries.
1315
- Added model2vec embedding support via `Model2VecEmbeddingAdapter` and `llmEmbed`
1416
source `model2vec`, using in-process static embeddings with offline-ready HF cache
1517
precaching. Install with `pip install talkpipe[model2vec]` or `talkpipe[all]`. Added

docs/guides/model-and-source-configuration.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ still too long for the model—not a substitute for upstream chunking.
290290
contiguous parts and mean-pools to one vector per stream item. If a batch embed fails, TalkPipe
291291
retries **per item** so you can see which chunk failed.
292292

293+
**Estimated pre-truncation:** set `max_estimated_tokens` on `llmEmbed` to truncate input before
294+
calling the embedding provider. This uses a lightweight estimate, not the provider tokenizer, so
295+
`on_token_overflow` remains the fallback if the estimate is optimistic. `truncate_side` is shared
296+
by both truncation paths: estimated pre-truncation before the provider call, and
297+
`on_token_overflow="truncate"` after the provider reports a token overflow.
298+
293299
**Batching:** set `batch_size` greater than `1` on `llmEmbed` to call the provider with multiple
294300
texts per request. The stream still has **one input item and one output item per document**;
295301
batching is internal only. `llmEmbed` does **not** accept list-shaped stream items (flatten or
@@ -306,6 +312,7 @@ INPUT FROM echo[data="Hello world"]
306312
```chatterlang
307313
| llmEmbed[on_token_overflow="truncate", truncate_side="tail"]
308314
| llmEmbed[on_token_overflow="chunk_pool", num_chunks=4]
315+
| llmEmbed[max_estimated_tokens=8192, truncate_side="tail"]
309316
```
310317

311318
### RAG and vector pipelines

src/talkpipe/llm/embedding.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Module for embedding text using different models"""
22

3-
from typing import Optional, Annotated, Iterator, Any, List, Literal
43
import logging
4+
import re
5+
from typing import Optional, Annotated, Iterator, Any, List, Literal
56

67
import numpy as np
78

@@ -26,6 +27,13 @@
2627
_MAX_TRUNCATE_ATTEMPTS = 8
2728

2829

30+
def estimate_tokens(text: str) -> int:
31+
"""Estimate token count without using a provider-specific tokenizer."""
32+
chars = len(text)
33+
words = len(re.findall(r"\S+", text))
34+
return int(max(words * 1.3, chars / 4))
35+
36+
2937
class EmbeddingTokenOverflowError(RuntimeError):
3038
"""Raised when embedding fails due to input length and on_token_overflow is error."""
3139

@@ -44,6 +52,10 @@ class LLMEmbed(AbstractFieldSegment):
4452
``error`` (default), ``truncate`` (shrink and retry), or ``chunk_pool`` (split into
4553
``num_chunks`` segments, embed, and mean-pool). Size text before this segment with
4654
upstream chunking when possible.
55+
56+
``max_estimated_tokens`` optionally truncates text before the provider call using
57+
a lightweight estimate, not a tokenizer. ``truncate_side`` controls both that
58+
proactive truncation and reactive ``on_token_overflow="truncate"`` retry behavior.
4759
"""
4860

4961
def __init__(
@@ -66,6 +78,10 @@ def __init__(
6678
int,
6779
"For chunk_pool: number of contiguous segments to split overflow text into",
6880
] = 2,
81+
max_estimated_tokens: Annotated[
82+
Optional[int],
83+
"If set, pre-truncate text to this estimated token budget before embedding",
84+
] = None,
6985
):
7086
"""Initialize the embedding segment with the specified parameters.
7187
@@ -96,18 +112,22 @@ def __init__(
96112
)
97113
if num_chunks < 2:
98114
raise ValueError("num_chunks must be at least 2")
115+
if max_estimated_tokens is not None and max_estimated_tokens < 1:
116+
raise ValueError("max_estimated_tokens must be a positive integer")
99117
self.embedder = getEmbeddingAdapter(source)(model=model)
100118
self.fail_on_error = fail_on_error
101119
self.batch_size = batch_size
102120
self.on_token_overflow = on_token_overflow
103121
self.truncate_side = truncate_side
104122
self.num_chunks = num_chunks
123+
self.max_estimated_tokens = max_estimated_tokens
105124
self._embedding_source = source
106125
self._embedding_model = model
107126

108127
def process_value(self, value: Any) -> List[float]:
109128
"""Embed one extracted field value (AbstractFieldSegment hook)."""
110-
return self._embed_one_with_overflow_policy(None, str(value))
129+
text = self._truncate_to_estimated_token_budget(str(value))
130+
return self._embed_one_with_overflow_policy(None, text)
111131

112132
def _input_value(self, item: Any) -> Any:
113133
"""Extract the value to embed (same rule as AbstractFieldSegment)."""
@@ -137,6 +157,23 @@ def _slice_text(text: str, length: int, side: str) -> str:
137157
return text[start : start + length]
138158
raise ValueError(f"Unknown truncate_side: {side!r}")
139159

160+
def _truncate_to_estimated_token_budget(self, text: str) -> str:
161+
if self.max_estimated_tokens is None:
162+
return text
163+
if estimate_tokens(text) <= self.max_estimated_tokens:
164+
return text
165+
166+
low = 0
167+
high = len(text)
168+
while low < high:
169+
mid = (low + high + 1) // 2
170+
candidate = self._slice_text(text, mid, self.truncate_side)
171+
if estimate_tokens(candidate) <= self.max_estimated_tokens:
172+
low = mid
173+
else:
174+
high = mid - 1
175+
return self._slice_text(text, low, self.truncate_side)
176+
140177
@staticmethod
141178
def _split_num_chunks(text: str, num_chunks: int) -> List[str]:
142179
n = len(text)
@@ -302,7 +339,7 @@ def flush_buffer() -> Iterator[Any]:
302339

303340
self._ensure_scalar_item(item)
304341
logging.debug(f"Processing input item: {item}")
305-
text = str(self._input_value(item))
342+
text = self._truncate_to_estimated_token_budget(str(self._input_value(item)))
306343
logging.debug(f"Embedding text: {text}")
307344

308345
if self.batch_size <= 1:

tests/talkpipe/llm/test_embedding_token_overflow.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from unittest.mock import Mock
55

6-
from talkpipe.llm.embedding import LLMEmbed, EmbeddingTokenOverflowError
6+
from talkpipe.llm.embedding import LLMEmbed, EmbeddingTokenOverflowError, estimate_tokens
77
from talkpipe.llm.embedding_errors import is_token_overflow_error
88

99
TOKEN_OVERFLOW = RuntimeError("maximum context length exceeded")
@@ -15,6 +15,12 @@ def test_is_token_overflow_error_recognizes_common_messages():
1515
assert not is_token_overflow_error(RuntimeError("connection reset"))
1616

1717

18+
def test_estimate_tokens_uses_word_and_character_heuristics():
19+
assert estimate_tokens("") == 0
20+
assert estimate_tokens("one two three four") == 5
21+
assert estimate_tokens("x" * 100) == 25
22+
23+
1824
def _overflow_if_long(max_len: int = 10):
1925
"""Embed succeeds when len(text) <= max_len; otherwise token overflow."""
2026

@@ -88,6 +94,29 @@ def test_on_token_overflow_truncate_single_item():
8894
assert not last_call.startswith("START")
8995

9096

97+
def test_max_estimated_tokens_preemptively_truncates_single_item():
98+
mock = Mock()
99+
mock.execute_one = Mock(side_effect=lambda text: [float(len(text))])
100+
mock.execute_batch = Mock()
101+
embedder = LLMEmbed(
102+
model="test-model",
103+
source="ollama",
104+
max_estimated_tokens=3,
105+
truncate_side="tail",
106+
)
107+
embedder.embedder = mock
108+
long_text = "one two three four five"
109+
110+
result = list(embedder([long_text]))
111+
112+
assert result == [[float(len(mock.execute_one.call_args[0][0]))]]
113+
sent_text = mock.execute_one.call_args[0][0]
114+
assert sent_text != long_text
115+
assert long_text.endswith(sent_text)
116+
assert estimate_tokens(sent_text) <= 3
117+
mock.execute_one.assert_called_once()
118+
119+
91120
def test_on_token_overflow_chunk_pool_single_item():
92121
batch_calls = []
93122

@@ -158,6 +187,37 @@ def test_on_token_overflow_truncate_batch_recovers_middle_item():
158187
mock.execute_batch.assert_called_once()
159188

160189

190+
def test_max_estimated_tokens_preemptively_truncates_batch_items():
191+
batch_calls = []
192+
193+
def execute_batch(texts):
194+
texts = list(texts)
195+
batch_calls.append(texts)
196+
return [[float(len(t))] for t in texts]
197+
198+
mock = Mock()
199+
mock.execute_one = Mock()
200+
mock.execute_batch = Mock(side_effect=execute_batch)
201+
embedder = LLMEmbed(
202+
model="test-model",
203+
source="ollama",
204+
batch_size=2,
205+
max_estimated_tokens=2,
206+
truncate_side="tail",
207+
)
208+
embedder.embedder = mock
209+
long_text = "alpha beta gamma delta"
210+
211+
result = list(embedder([long_text, "ok"]))
212+
213+
assert len(result) == 2
214+
assert batch_calls[0][0] != long_text
215+
assert long_text.endswith(batch_calls[0][0])
216+
assert estimate_tokens(batch_calls[0][0]) <= 2
217+
assert batch_calls[0][1] == "ok"
218+
mock.execute_one.assert_not_called()
219+
220+
161221
def test_on_token_overflow_chunk_pool_batch_recovers_middle_item():
162222
batch_calls = []
163223

@@ -197,3 +257,8 @@ def execute_batch(texts):
197257
def test_num_chunks_must_be_at_least_two():
198258
with pytest.raises(ValueError, match="num_chunks"):
199259
LLMEmbed(model="test-model", source="ollama", num_chunks=1)
260+
261+
262+
def test_max_estimated_tokens_must_be_positive():
263+
with pytest.raises(ValueError, match="max_estimated_tokens"):
264+
LLMEmbed(model="test-model", source="ollama", max_estimated_tokens=0)

0 commit comments

Comments
 (0)