Commit 9467fae

Added support for the situation where an embedder fails because too much text is provided.
1 parent 983be76 commit 9467fae

6 files changed

Lines changed: 447 additions & 24 deletions

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -5,6 +5,11 @@
 `execute` remains as a deprecated alias for `execute_one` (removed in 1.0). `llmEmbed` extends
 `AbstractFieldSegment`, uses `batch_size` for internal provider batching only (one stream item
 in and one out per document), and rejects list-shaped stream items with `TypeError`.
+When the provider rejects input as too long, `on_token_overflow` controls recovery (default
+`error`): `truncate` shortens by 20% per retry (`truncate_side`), or `chunk_pool` splits into
+`num_chunks` segments and mean-pools. Applied reactively after embed failure; size text upstream
+with `splitText` / `processDocuments` when possible. Batch failures fall back per item so the
+offending chunk can be identified.
 - Added model2vec embedding support via `Model2VecEmbeddingAdapter` and `llmEmbed`
 source `model2vec`, using in-process static embeddings with offline-ready HF cache
 precaching. Install with `pip install talkpipe[model2vec]` or `talkpipe[all]`. Added
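
Reviewer note: the `chunk_pool` behaviour described in this changelog entry (split into `num_chunks` contiguous segments, embed each, mean-pool) can be sketched in pure Python as below. This is a minimal illustration, not the committed implementation: the commit uses NumPy, and `toy_embed` is a hypothetical stand-in for a real provider call. The slicing formula and the re-normalization to unit length follow the helper methods added to `embedding.py` in this commit.

```python
import math
from typing import List


def split_num_chunks(text: str, num_chunks: int) -> List[str]:
    """Split text into num_chunks contiguous slices of near-equal length."""
    n = len(text)
    return [text[i * n // num_chunks:(i + 1) * n // num_chunks] for i in range(num_chunks)]


def mean_pool(vectors: List[List[float]]) -> List[float]:
    """Average the vectors component-wise, then rescale to unit L2 norm."""
    if not vectors:
        raise ValueError("Cannot mean-pool an empty list of vectors")
    dim = len(vectors[0])
    mean = [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]
    norm = math.sqrt(sum(x * x for x in mean))
    return [x / norm for x in mean] if norm > 0 else mean


def toy_embed(text: str) -> List[float]:
    """Hypothetical embedder (assumption): counts of 'a' and 'b' as a 2-d vector."""
    return [float(text.count("a")), float(text.count("b"))]


chunks = split_num_chunks("aaaabbbb", 2)          # ["aaaa", "bbbb"]
pooled = mean_pool([toy_embed(c) for c in chunks])  # one unit-length vector
```

With this slicing formula, a text shorter than `num_chunks` characters produces some empty slices, so it may be worth checking how a given provider handles empty input.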

docs/guides/model-and-source-configuration.md

Lines changed: 20 additions & 3 deletions
@@ -276,21 +276,38 @@ segment = LLMVisionPrompt(
 | `field` | No | Text field to embed on structured items |
 | `set_as` | No | Field on the item where the vector is stored |
 | `batch_size` | No | Scalar items per provider call (default `1`) |
-| `fail_on_error` | No | Default `true` |
+| `fail_on_error` | No | Default `true`; applies to non-length failures (network, auth, etc.) |
+| `on_token_overflow` | No | Default `error` — when embed fails as too long: `error`, `truncate`, or `chunk_pool` |
+| `truncate_side` | No | For `truncate`: `head`, `tail` (default), or `middle` |
+| `num_chunks` | No | For `chunk_pool`: segments to split into (default `2`, minimum `2`) |
+
+**Sizing text:** Chunk or split documents **before** `llmEmbed` (e.g. `splitText`, `processDocuments`,
+`makevectordatabase --chunk_size`). `on_token_overflow` is **failure recovery** when a chunk is
+still too long for the model—not a substitute for upstream chunking.
+
+**Token overflow:** TalkPipe classifies provider “too long” errors and applies `on_token_overflow`.
+`truncate` retries with 20% shorter character slices per attempt; `chunk_pool` embeds `num_chunks`
+contiguous parts and mean-pools to one vector per stream item. If a batch embed fails, TalkPipe
+retries **per item** so you can see which chunk failed.

 **Batching:** set `batch_size` greater than `1` on `llmEmbed` to call the provider with multiple
 texts per request. The stream still has **one input item and one output item per document**;
 batching is internal only. `llmEmbed` does **not** accept list-shaped stream items (flatten or
 emit items individually upstream). `field` and `set_as` follow
-`AbstractFieldSegment` on each scalar item. With `fail_on_error=False`, failed items are skipped
-when a buffered batch falls back to per-item embedding.
+`AbstractFieldSegment` on each scalar item. With `fail_on_error=False`, non-length failures skip
+items when per-item fallback runs after a batch failure.

 ```chatterlang
 INPUT FROM echo[data="Hello world"]
 | llmEmbed[model="mxbai-embed-large", source="ollama", set_as="vector"]
 | print
 ```

+```chatterlang
+| llmEmbed[on_token_overflow="truncate", truncate_side="tail"]
+| llmEmbed[on_token_overflow="chunk_pool", num_chunks=4]
+```
+
 ### RAG and vector pipelines

 Higher-level segments forward model settings to inner LLM segments:
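
Reviewer note: the "20% shorter character slices per attempt" rule in the guide can be made concrete with a small sketch. The constants below mirror the module-level values added to `embedding.py` in this commit (`_SHRINK_RATIO`, `_MIN_TRUNCATE_CHARS`, `_MAX_TRUNCATE_ATTEMPTS`); the function names `slice_text` and `attempted_lengths` are illustrative and not part of TalkPipe's API. No provider is called here; the sketch only lists which lengths would be tried if every attempt overflowed.

```python
from typing import List

# Values copied from this commit's embedding.py; names are local to this sketch.
SHRINK_RATIO = 0.2
MIN_TRUNCATE_CHARS = 1
MAX_TRUNCATE_ATTEMPTS = 8


def slice_text(text: str, length: int, side: str) -> str:
    """Keep `length` characters from the head, tail, or middle of text."""
    if length <= 0:
        return ""
    if side == "head":
        return text[:length]
    if side == "tail":
        return text[-length:]
    if side == "middle":
        if length >= len(text):
            return text
        start = (len(text) - length) // 2
        return text[start:start + length]
    raise ValueError(f"Unknown truncate_side: {side!r}")


def attempted_lengths(n: int) -> List[int]:
    """Character lengths tried in order, assuming every attempt overflows."""
    lengths = []
    for _ in range(MAX_TRUNCATE_ATTEMPTS):
        lengths.append(n)
        n_next = max(MIN_TRUNCATE_CHARS, int(n * (1 - SHRINK_RATIO)))
        if n_next >= n:
            break  # cannot shrink further
        n = n_next
    return lengths


print(attempted_lengths(1000))  # [1000, 800, 640, 512, 409, 327, 261, 208]
print(slice_text("abcdefghij", 4, "middle"))  # defg
```

As the retry loop in the diff appears to be written, the first attempt resends the unshortened text, so at most seven shortened variants are tried, and the shortest is roughly 21% of the original length. Inputs far over the model limit may therefore still exhaust the retries.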

src/talkpipe/app/chatterlang_workbench.py

Lines changed: 5 additions & 1 deletion
@@ -1133,7 +1133,11 @@ def main():

     # Expose the workbench's own logo URL so example scripts can fetch it from
     # the running server via $workbench_logo_url.
-    logo_host = "localhost" if args.host in ("0.0.0.0", "::") else args.host
+    logo_host = (
+        "localhost"
+        if args.host in ("0.0.0.0", "::")  # nosec B104 - compare bind host, not binding here
+        else args.host
+    )
     add_config_values(
         {"workbench_logo_url": f"http://{logo_host}:{args.port}/static/talkpipe_logo.png"},
         override=True,

src/talkpipe/llm/embedding.py

Lines changed: 190 additions & 20 deletions
@@ -1,17 +1,34 @@
 """Module for embedding text using different models"""

-from typing import Optional, Annotated, Iterator, Any, List
+from typing import Optional, Annotated, Iterator, Any, List, Literal
 import logging

+import numpy as np
+
 from talkpipe.pipe.core import AbstractFieldSegment, is_metadata
 from talkpipe.chatterlang.registry import register_segment
 from talkpipe.util.data_manipulation import extract_property, assign_property
 from .config import getEmbeddingAdapter, getEmbeddingSources
+from .embedding_errors import is_token_overflow_error
 from talkpipe.util.config import get_config
 from talkpipe.util.constants import TALKPIPE_EMBEDDING_MODEL_NAME, TALKPIPE_EMBEDDING_MODEL_SOURCE

 logger = logging.getLogger(__name__)

+# on_token_overflow mode strings (compare via _OVERFLOW_* constants to avoid Bandit B105/B107)
+_ON_TOKEN_OVERFLOW_CHOICES = ("error", "truncate", "chunk_pool")
+_OVERFLOW_ERROR, _OVERFLOW_TRUNCATE, _OVERFLOW_CHUNK_POOL = _ON_TOKEN_OVERFLOW_CHOICES
+_TRUNCATE_SIDE_CHOICES = ("head", "tail", "middle")
+
+# Truncate retry tuning (not exposed on the segment in v1).
+_SHRINK_RATIO = 0.2
+_MIN_TRUNCATE_CHARS = 1
+_MAX_TRUNCATE_ATTEMPTS = 8
+
+
+class EmbeddingTokenOverflowError(RuntimeError):
+    """Raised when embedding fails due to input length and on_token_overflow is error."""
+

 @register_segment("llmEmbed")
 class LLMEmbed(AbstractFieldSegment):
@@ -22,6 +39,11 @@ class LLMEmbed(AbstractFieldSegment):
     :class:`~talkpipe.pipe.core.AbstractFieldSegment`. Batching is internal only
     (``batch_size``); use ``makeLists`` upstream only if another segment needs grouped
     items—not as direct input to ``llmEmbed``.
+
+    When the provider rejects text as too long, ``on_token_overflow`` controls recovery:
+    ``error`` (default), ``truncate`` (shrink and retry), or ``chunk_pool`` (split into
+    ``num_chunks`` segments, embed, and mean-pool). Size text before this segment with
+    upstream chunking when possible.
     """

     def __init__(
@@ -32,6 +54,18 @@ def __init__(
         set_as: Annotated[Optional[str], "If provided, append embeddings to input items under this field name"] = None,
         fail_on_error: Annotated[bool, "Whether to raise an error on failure or to silently ignore it"] = True,
         batch_size: Annotated[int, "Number of stream items to embed per provider API call"] = 1,
+        on_token_overflow: Annotated[
+            Literal["error", "truncate", "chunk_pool"],
+            "When embed fails as too long: error, truncate (shrink and retry), or chunk_pool",
+        ] = _OVERFLOW_ERROR,
+        truncate_side: Annotated[
+            Literal["head", "tail", "middle"],
+            "For truncate: which portion of the string to keep when shortening",
+        ] = "tail",
+        num_chunks: Annotated[
+            int,
+            "For chunk_pool: number of contiguous segments to split overflow text into",
+        ] = 2,
     ):
         """Initialize the embedding segment with the specified parameters.

@@ -51,13 +85,29 @@ def __init__(
             )
         if batch_size < 1:
             raise ValueError("batch_size must be a positive integer")
+        if on_token_overflow not in _ON_TOKEN_OVERFLOW_CHOICES:
+            raise ValueError(
+                f"on_token_overflow must be one of {_ON_TOKEN_OVERFLOW_CHOICES}, "
+                f"got {on_token_overflow!r}"
+            )
+        if truncate_side not in _TRUNCATE_SIDE_CHOICES:
+            raise ValueError(
+                f"truncate_side must be one of {_TRUNCATE_SIDE_CHOICES}, got {truncate_side!r}"
+            )
+        if num_chunks < 2:
+            raise ValueError("num_chunks must be at least 2")
         self.embedder = getEmbeddingAdapter(source)(model=model)
         self.fail_on_error = fail_on_error
         self.batch_size = batch_size
+        self.on_token_overflow = on_token_overflow
+        self.truncate_side = truncate_side
+        self.num_chunks = num_chunks
+        self._embedding_source = source
+        self._embedding_model = model

     def process_value(self, value: Any) -> List[float]:
         """Embed one extracted field value (AbstractFieldSegment hook)."""
-        return self.embedder.execute_one(str(value))
+        return self._embed_one_with_overflow_policy(None, str(value))

     def _input_value(self, item: Any) -> Any:
         """Extract the value to embed (same rule as AbstractFieldSegment)."""
@@ -72,6 +122,127 @@ def _ensure_scalar_item(item: Any) -> None:
                 "before this segment."
             )

+    @staticmethod
+    def _slice_text(text: str, length: int, side: str) -> str:
+        if length <= 0:
+            return ""
+        if side == "head":
+            return text[:length]
+        if side == "tail":
+            return text[-length:]
+        if side == "middle":
+            if length >= len(text):
+                return text
+            start = (len(text) - length) // 2
+            return text[start : start + length]
+        raise ValueError(f"Unknown truncate_side: {side!r}")
+
+    @staticmethod
+    def _split_num_chunks(text: str, num_chunks: int) -> List[str]:
+        n = len(text)
+        if num_chunks < 2 or n == 0:
+            return [text] if text else []
+        return [text[i * n // num_chunks : (i + 1) * n // num_chunks] for i in range(num_chunks)]
+
+    @staticmethod
+    def _mean_pool(vectors: List[List[float]]) -> List[float]:
+        if not vectors:
+            raise ValueError("Cannot mean-pool an empty list of vectors")
+        arr = np.asarray(vectors, dtype=float)
+        pooled = arr.mean(axis=0)
+        norm = float(np.linalg.norm(pooled))
+        if norm > 0:
+            pooled = pooled / norm
+        return pooled.tolist()
+
+    def _wrap_token_overflow(
+        self,
+        exc: BaseException,
+        *,
+        item: Any,
+        text: str,
+        detail: Optional[str] = None,
+    ) -> EmbeddingTokenOverflowError:
+        field_part = f"field={self.field!r}, " if self.field else ""
+        item_part = f"item={item!r}, " if item is not None else ""
+        text_len = len(text)
+        hint = (
+            "Use smaller upstream chunks (e.g. splitText), "
+            f"on_token_overflow='truncate', or on_token_overflow='chunk_pool' "
+            f"(num_chunks={self.num_chunks})."
+        )
+        extra = f" {detail}" if detail else ""
+        message = (
+            f"Embedding input too long for {self._embedding_source}/{self._embedding_model}: "
+            f"{item_part}{field_part}text_length={text_len}. {hint}{extra} "
+            f"Provider error: {exc}"
+        )
+        return EmbeddingTokenOverflowError(message)
+
+    def _execute_one_raw(self, text: str) -> List[float]:
+        return self.embedder.execute_one(text)
+
+    def _embed_truncate(self, item: Any, text: str) -> List[float]:
+        current = text
+        last_overflow: Optional[BaseException] = None
+        for _ in range(_MAX_TRUNCATE_ATTEMPTS):
+            try:
+                return self._execute_one_raw(current)
+            except Exception as exc:
+                if not is_token_overflow_error(exc):
+                    raise
+                last_overflow = exc
+                n = len(current)
+                n_next = max(_MIN_TRUNCATE_CHARS, int(n * (1 - _SHRINK_RATIO)))
+                if n_next >= n:
+                    break
+                current = self._slice_text(current, n_next, self.truncate_side)
+        raise self._wrap_token_overflow(
+            last_overflow or RuntimeError("truncate exhausted"),
+            item=item,
+            text=text,
+            detail="Truncate retries exhausted.",
+        )
+
+    def _embed_chunk_pool(self, item: Any, text: str) -> List[float]:
+        segments = self._split_num_chunks(text, self.num_chunks)
+        if not segments:
+            raise self._wrap_token_overflow(
+                RuntimeError("empty text"),
+                item=item,
+                text=text,
+            )
+        try:
+            if len(segments) == 1:
+                vectors = [self._execute_one_raw(segments[0])]
+            else:
+                vectors = self.embedder.execute_batch(segments)
+        except Exception as exc:
+            if is_token_overflow_error(exc):
+                raise self._wrap_token_overflow(
+                    exc,
+                    item=item,
+                    text=text,
+                    detail=(
+                        f"chunk_pool with num_chunks={self.num_chunks} still exceeded the limit; "
+                        "try a larger num_chunks or smaller upstream chunks."
+                    ),
+                ) from exc
+            raise
+        return self._mean_pool(vectors)
+
+    def _embed_one_with_overflow_policy(self, item: Any, text: str) -> List[float]:
+        try:
+            return self._execute_one_raw(text)
+        except Exception as exc:
+            if not is_token_overflow_error(exc):
+                raise
+            if self.on_token_overflow == _OVERFLOW_ERROR:
+                raise self._wrap_token_overflow(exc, item=item, text=text) from exc
+            if self.on_token_overflow == _OVERFLOW_TRUNCATE:
+                return self._embed_truncate(item, text)
+            return self._embed_chunk_pool(item, text)
+
     def _yield_results(self, item: Any, results: List[Any]) -> Iterator[Any]:
         """Emit results using AbstractFieldSegment assign/yield semantics."""
         for result in results:
@@ -81,37 +252,36 @@ def _yield_results(self, item: Any, results: List[Any]) -> Iterator[Any]:
             else:
                 yield result

-    def _vectors_for_texts(self, texts: List[str]) -> List[List[float]]:
-        if not texts:
-            return []
-        if len(texts) == 1:
-            return [self.process_value(texts[0])]
-        return self.embedder.execute_batch(texts)
+    def _embed_items_pair(self, items: List[Any], texts: List[str]) -> Iterator[Any]:
+        for item, text in zip(items, texts):
+            try:
+                vector = self._embed_one_with_overflow_policy(item, text)
+            except EmbeddingTokenOverflowError:
+                raise
+            except Exception as exc:
+                logger.error(f"Error during embedding: {exc}")
+                if self.fail_on_error:
+                    raise
+                continue
+            yield from self._yield_results(item, [vector])

     def _embed_buffered(self, items: List[Any], texts: List[str]) -> Iterator[Any]:
         if not items or not texts:
             return
         logger.debug(f"Embedding batch of {len(texts)} texts")
+        if len(texts) == 1:
+            yield from self._embed_items_pair(items, texts)
+            return
         try:
-            vectors = self._vectors_for_texts(texts)
+            vectors = self.embedder.execute_batch(texts)
             for item, vector in zip(items, vectors):
                 yield from self._yield_results(item, [vector])
         except Exception as e:
             logger.error(f"Error during batch embedding: {e}")
-            if self.fail_on_error:
-                raise
-            if len(texts) == 1:
-                return
             logger.warning(
                 "Batch embedding failed; falling back to per-item embedding"
             )
-            for item, text in zip(items, texts):
-                try:
-                    vector = self.process_value(text)
-                except Exception as item_error:
-                    logger.error(f"Error during embedding: {item_error}")
-                    continue
-                yield from self._yield_results(item, [vector])
+            yield from self._embed_items_pair(items, texts)

     def transform(self, input_iter):
         """Transform one stream item at a time; batching is internal only."""
src/talkpipe/llm/embedding_errors.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+"""Classify embedding provider errors related to input length / token limits."""
+
+from __future__ import annotations
+
+import re
+
+# Substrings commonly seen when an embedding input exceeds model limits.
+_TOKEN_OVERFLOW_PATTERNS = tuple(
+    re.compile(p, re.IGNORECASE)
+    for p in (
+        r"maximum context length",
+        r"context length",
+        r"token limit",
+        r"too many tokens",
+        r"input.*too long",
+        r"reduce (the )?(length|size|your input)",
+        r"exceeds the maximum",
+        r"max.*tokens",
+    )
+)
+
+
+def is_token_overflow_error(exc: BaseException) -> bool:
+    """Return True if the exception likely indicates input text was too long to embed."""
+    message = str(exc)
+    if not message:
+        message = repr(exc)
+    return any(pattern.search(message) for pattern in _TOKEN_OVERFLOW_PATTERNS)
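
Reviewer note: a small sketch of how message-based classification interacts with per-item fallback. It uses a subset of the patterns added in this commit and roughly mirrors the `on_token_overflow="error"` path, where length errors always propagate and other errors honor `fail_on_error`. The names `looks_like_overflow`, `embed_per_item`, and `toy_embed` are hypothetical, and the real code additionally wraps the error in `EmbeddingTokenOverflowError`.

```python
import re
from typing import Callable, Iterator, List, Sequence, Tuple

# Subset of the patterns from this commit; names here are local to the sketch.
_PATTERNS = tuple(
    re.compile(p, re.IGNORECASE)
    for p in (r"maximum context length", r"too many tokens", r"input.*too long", r"max.*tokens")
)


def looks_like_overflow(exc: BaseException) -> bool:
    """Classify an exception by its message text, falling back to repr()."""
    message = str(exc) or repr(exc)
    return any(p.search(message) for p in _PATTERNS)


def embed_per_item(
    texts: Sequence[str],
    embed_one: Callable[[str], List[float]],
    fail_on_error: bool = True,
) -> Iterator[Tuple[int, List[float]]]:
    """Embed items one by one; length errors always raise, others may be skipped."""
    for index, text in enumerate(texts):
        try:
            vector = embed_one(text)
        except Exception as exc:
            if looks_like_overflow(exc) or fail_on_error:
                raise
            continue  # non-length failure with fail_on_error=False: skip the item
        yield index, vector


def toy_embed(text: str) -> List[float]:
    """Hypothetical provider (assumption): rejects empty and over-long input."""
    if not text:
        raise ValueError("empty input")
    if len(text) > 8:
        raise RuntimeError("input is too long for this model")
    return [float(len(text))]


print(list(embed_per_item(["ab", "", "abcd"], toy_embed, fail_on_error=False)))
# [(0, [2.0]), (2, [4.0])]
```

Because classification is by substring, a broad pattern such as `max.*tokens` also matches messages like "invalid max_tokens parameter", so some non-length errors may be routed through overflow handling.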
