11"""Module for embedding text using different models"""
22
3- from typing import Optional , Annotated , Iterator , Any , List
3+ from typing import Optional , Annotated , Iterator , Any , List , Literal
44import logging
55
6+ import numpy as np
7+
68from talkpipe .pipe .core import AbstractFieldSegment , is_metadata
79from talkpipe .chatterlang .registry import register_segment
810from talkpipe .util .data_manipulation import extract_property , assign_property
911from .config import getEmbeddingAdapter , getEmbeddingSources
12+ from .embedding_errors import is_token_overflow_error
1013from talkpipe .util .config import get_config
1114from talkpipe .util .constants import TALKPIPE_EMBEDDING_MODEL_NAME , TALKPIPE_EMBEDDING_MODEL_SOURCE
1215
1316logger = logging .getLogger (__name__ )
1417
18+ # on_token_overflow mode strings (compare via _OVERFLOW_* constants to avoid Bandit B105/B107)
19+ _ON_TOKEN_OVERFLOW_CHOICES = ("error" , "truncate" , "chunk_pool" )
20+ _OVERFLOW_ERROR , _OVERFLOW_TRUNCATE , _OVERFLOW_CHUNK_POOL = _ON_TOKEN_OVERFLOW_CHOICES
21+ _TRUNCATE_SIDE_CHOICES = ("head" , "tail" , "middle" )
22+
23+ # Truncate retry tuning (not exposed on the segment in v1).
24+ _SHRINK_RATIO = 0.2
25+ _MIN_TRUNCATE_CHARS = 1
26+ _MAX_TRUNCATE_ATTEMPTS = 8
27+
28+
class EmbeddingTokenOverflowError(RuntimeError):
    """Signal that an embedding request failed because the input was too long.

    Raised when the provider rejects text as too long and the configured
    ``on_token_overflow`` policy is ``error``, and also when the ``truncate``
    or ``chunk_pool`` recovery paths still cannot produce an embedding.
    """
31+
1532
1633@register_segment ("llmEmbed" )
1734class LLMEmbed (AbstractFieldSegment ):
@@ -22,6 +39,11 @@ class LLMEmbed(AbstractFieldSegment):
2239 :class:`~talkpipe.pipe.core.AbstractFieldSegment`. Batching is internal only
2340 (``batch_size``); use ``makeLists`` upstream only if another segment needs grouped
2441 items—not as direct input to ``llmEmbed``.
42+
43+ When the provider rejects text as too long, ``on_token_overflow`` controls recovery:
44+ ``error`` (default), ``truncate`` (shrink and retry), or ``chunk_pool`` (split into
45+ ``num_chunks`` segments, embed, and mean-pool). Size text before this segment with
46+ upstream chunking when possible.
2547 """
2648
2749 def __init__ (
@@ -32,6 +54,18 @@ def __init__(
3254 set_as : Annotated [Optional [str ], "If provided, append embeddings to input items under this field name" ] = None ,
3355 fail_on_error : Annotated [bool , "Whether to raise an error on failure or to silently ignore it" ] = True ,
3456 batch_size : Annotated [int , "Number of stream items to embed per provider API call" ] = 1 ,
57+ on_token_overflow : Annotated [
58+ Literal ["error" , "truncate" , "chunk_pool" ],
59+ "When embed fails as too long: error, truncate (shrink and retry), or chunk_pool" ,
60+ ] = _OVERFLOW_ERROR ,
61+ truncate_side : Annotated [
62+ Literal ["head" , "tail" , "middle" ],
63+ "For truncate: which portion of the string to keep when shortening" ,
64+ ] = "tail" ,
65+ num_chunks : Annotated [
66+ int ,
67+ "For chunk_pool: number of contiguous segments to split overflow text into" ,
68+ ] = 2 ,
3569 ):
3670 """Initialize the embedding segment with the specified parameters.
3771
@@ -51,13 +85,29 @@ def __init__(
5185 )
5286 if batch_size < 1 :
5387 raise ValueError ("batch_size must be a positive integer" )
88+ if on_token_overflow not in _ON_TOKEN_OVERFLOW_CHOICES :
89+ raise ValueError (
90+ f"on_token_overflow must be one of { _ON_TOKEN_OVERFLOW_CHOICES } , "
91+ f"got { on_token_overflow !r} "
92+ )
93+ if truncate_side not in _TRUNCATE_SIDE_CHOICES :
94+ raise ValueError (
95+ f"truncate_side must be one of { _TRUNCATE_SIDE_CHOICES } , got { truncate_side !r} "
96+ )
97+ if num_chunks < 2 :
98+ raise ValueError ("num_chunks must be at least 2" )
5499 self .embedder = getEmbeddingAdapter (source )(model = model )
55100 self .fail_on_error = fail_on_error
56101 self .batch_size = batch_size
102+ self .on_token_overflow = on_token_overflow
103+ self .truncate_side = truncate_side
104+ self .num_chunks = num_chunks
105+ self ._embedding_source = source
106+ self ._embedding_model = model
57107
def process_value(self, value: Any) -> List[float]:
    """Embed a single extracted field value (AbstractFieldSegment hook).

    The value is coerced to ``str`` and embedded through the overflow-aware
    path. No originating item is available here, so ``None`` is passed as the
    item used for error context.
    """
    text = str(value)
    return self._embed_one_with_overflow_policy(None, text)
61111
62112 def _input_value (self , item : Any ) -> Any :
63113 """Extract the value to embed (same rule as AbstractFieldSegment)."""
@@ -72,6 +122,127 @@ def _ensure_scalar_item(item: Any) -> None:
72122 "before this segment."
73123 )
74124
125+ @staticmethod
126+ def _slice_text (text : str , length : int , side : str ) -> str :
127+ if length <= 0 :
128+ return ""
129+ if side == "head" :
130+ return text [:length ]
131+ if side == "tail" :
132+ return text [- length :]
133+ if side == "middle" :
134+ if length >= len (text ):
135+ return text
136+ start = (len (text ) - length ) // 2
137+ return text [start : start + length ]
138+ raise ValueError (f"Unknown truncate_side: { side !r} " )
139+
140+ @staticmethod
141+ def _split_num_chunks (text : str , num_chunks : int ) -> List [str ]:
142+ n = len (text )
143+ if num_chunks < 2 or n == 0 :
144+ return [text ] if text else []
145+ return [text [i * n // num_chunks : (i + 1 ) * n // num_chunks ] for i in range (num_chunks )]
146+
147+ @staticmethod
148+ def _mean_pool (vectors : List [List [float ]]) -> List [float ]:
149+ if not vectors :
150+ raise ValueError ("Cannot mean-pool an empty list of vectors" )
151+ arr = np .asarray (vectors , dtype = float )
152+ pooled = arr .mean (axis = 0 )
153+ norm = float (np .linalg .norm (pooled ))
154+ if norm > 0 :
155+ pooled = pooled / norm
156+ return pooled .tolist ()
157+
def _wrap_token_overflow(
    self,
    exc: BaseException,
    *,
    item: Any,
    text: str,
    detail: Optional[str] = None,
) -> EmbeddingTokenOverflowError:
    """Build (but do not raise) an ``EmbeddingTokenOverflowError`` for ``exc``.

    The message names the embedding source and model, the offending item and
    field when available, the text length, a remediation hint, and the
    provider's own error text.

    Args:
        exc: The provider exception that signalled the overflow.
        item: The stream item being embedded, or ``None`` when unknown.
        text: The text that was submitted; only its length is reported.
        detail: Optional extra sentence appended after the hint.

    Returns:
        A new ``EmbeddingTokenOverflowError``; the caller is responsible for
        raising it (and for chaining with ``from``).
    """
    # The item usually contains the oversized text itself, so its repr is
    # capped to keep the exception message (and any log line) readable.
    max_item_repr = 200
    field_part = f"field={self.field!r}, " if self.field else ""
    if item is None:
        item_part = ""
    else:
        item_repr = repr(item)
        if len(item_repr) > max_item_repr:
            item_repr = item_repr[:max_item_repr] + "..."
        item_part = f"item={item_repr}, "
    text_len = len(text)
    hint = (
        "Use smaller upstream chunks (e.g. splitText), "
        f"on_token_overflow='truncate', or on_token_overflow='chunk_pool' "
        f"(num_chunks={self.num_chunks})."
    )
    extra = f" {detail}" if detail else ""
    message = (
        f"Embedding input too long for {self._embedding_source}/{self._embedding_model}: "
        f"{item_part}{field_part}text_length={text_len}. {hint}{extra} "
        f"Provider error: {exc}"
    )
    return EmbeddingTokenOverflowError(message)
181+
182+ def _execute_one_raw (self , text : str ) -> List [float ]:
183+ return self .embedder .execute_one (text )
184+
def _embed_truncate(self, item: Any, text: str) -> List[float]:
    """Embed ``text``, shrinking it after each token-overflow failure.

    Starting from the full text, each overflow failure shortens the current
    string by ``_SHRINK_RATIO`` (never below ``_MIN_TRUNCATE_CHARS``), keeping
    the portion selected by ``self.truncate_side``. At most
    ``_MAX_TRUNCATE_ATTEMPTS`` embed calls are made.

    Args:
        item: The originating stream item, used only for error context.
        text: The text to embed.

    Returns:
        The embedding of the first (possibly shortened) text the embedder
        accepts.

    Raises:
        EmbeddingTokenOverflowError: If every attempt overflowed or the text
            cannot be shortened further. Chained to the last provider error.
        Exception: Any non-overflow embedder error is re-raised unchanged.
    """
    current = text
    last_overflow: Optional[BaseException] = None
    for _ in range(_MAX_TRUNCATE_ATTEMPTS):
        try:
            return self._execute_one_raw(current)
        except Exception as exc:
            if not is_token_overflow_error(exc):
                raise
            last_overflow = exc
        n = len(current)
        n_next = max(_MIN_TRUNCATE_CHARS, int(n * (1 - _SHRINK_RATIO)))
        if n_next >= n:
            # Already at the minimum length; shrinking further is impossible.
            break
        current = self._slice_text(current, n_next, self.truncate_side)
    # Chain to the last provider error so the traceback shows the root cause,
    # matching the other overflow raises in this class.
    raise self._wrap_token_overflow(
        last_overflow or RuntimeError("truncate exhausted"),
        item=item,
        text=text,
        detail="Truncate retries exhausted.",
    ) from last_overflow
206+
207+ def _embed_chunk_pool (self , item : Any , text : str ) -> List [float ]:
208+ segments = self ._split_num_chunks (text , self .num_chunks )
209+ if not segments :
210+ raise self ._wrap_token_overflow (
211+ RuntimeError ("empty text" ),
212+ item = item ,
213+ text = text ,
214+ )
215+ try :
216+ if len (segments ) == 1 :
217+ vectors = [self ._execute_one_raw (segments [0 ])]
218+ else :
219+ vectors = self .embedder .execute_batch (segments )
220+ except Exception as exc :
221+ if is_token_overflow_error (exc ):
222+ raise self ._wrap_token_overflow (
223+ exc ,
224+ item = item ,
225+ text = text ,
226+ detail = (
227+ f"chunk_pool with num_chunks={ self .num_chunks } still exceeded the limit; "
228+ "try a larger num_chunks or smaller upstream chunks."
229+ ),
230+ ) from exc
231+ raise
232+ return self ._mean_pool (vectors )
233+
234+ def _embed_one_with_overflow_policy (self , item : Any , text : str ) -> List [float ]:
235+ try :
236+ return self ._execute_one_raw (text )
237+ except Exception as exc :
238+ if not is_token_overflow_error (exc ):
239+ raise
240+ if self .on_token_overflow == _OVERFLOW_ERROR :
241+ raise self ._wrap_token_overflow (exc , item = item , text = text ) from exc
242+ if self .on_token_overflow == _OVERFLOW_TRUNCATE :
243+ return self ._embed_truncate (item , text )
244+ return self ._embed_chunk_pool (item , text )
245+
75246 def _yield_results (self , item : Any , results : List [Any ]) -> Iterator [Any ]:
76247 """Emit results using AbstractFieldSegment assign/yield semantics."""
77248 for result in results :
@@ -81,37 +252,36 @@ def _yield_results(self, item: Any, results: List[Any]) -> Iterator[Any]:
81252 else :
82253 yield result
83254
84- def _vectors_for_texts (self , texts : List [str ]) -> List [List [float ]]:
85- if not texts :
86- return []
87- if len (texts ) == 1 :
88- return [self .process_value (texts [0 ])]
89- return self .embedder .execute_batch (texts )
255+ def _embed_items_pair (self , items : List [Any ], texts : List [str ]) -> Iterator [Any ]:
256+ for item , text in zip (items , texts ):
257+ try :
258+ vector = self ._embed_one_with_overflow_policy (item , text )
259+ except EmbeddingTokenOverflowError :
260+ raise
261+ except Exception as exc :
262+ logger .error (f"Error during embedding: { exc } " )
263+ if self .fail_on_error :
264+ raise
265+ continue
266+ yield from self ._yield_results (item , [vector ])
90267
def _embed_buffered(self, items: List[Any], texts: List[str]) -> Iterator[Any]:
    """Embed a buffered group of items, preferring one batch provider call.

    A single text is routed through ``_embed_items_pair`` so the overflow
    policy applies. Larger groups are sent to ``self.embedder.execute_batch``;
    if that call fails for any reason, every item is retried individually via
    ``_embed_items_pair``, which applies the overflow policy and
    ``fail_on_error`` per item.

    Args:
        items: Stream items, parallel to ``texts``.
        texts: The text to embed for each item.

    Yields:
        Whatever ``_yield_results`` emits for each embedded item.
    """
    if not items or not texts:
        return
    logger.debug(f"Embedding batch of {len(texts)} texts")
    if len(texts) == 1:
        yield from self._embed_items_pair(items, texts)
        return
    try:
        # Only the provider call is guarded. Materializing here keeps lazily
        # raised provider errors inside the fallback path as well.
        vectors = list(self.embedder.execute_batch(texts))
    except Exception as e:
        logger.error(f"Error during batch embedding: {e}")
        logger.warning(
            "Batch embedding failed; falling back to per-item embedding"
        )
        yield from self._embed_items_pair(items, texts)
        return
    # Emitting happens outside the try block: a failure while yielding must
    # not trigger the fallback, which would re-embed and re-yield items that
    # were already emitted.
    for item, vector in zip(items, vectors):
        yield from self._yield_results(item, [vector])
115285
116286 def transform (self , input_iter ):
117287 """Transform one stream item at a time; batching is internal only."""
0 commit comments