11"""Module for embedding text using different models"""
22
3- from typing import Optional , Annotated , Iterator , Any , List , Literal
43import logging
4+ import re
5+ from typing import Optional , Annotated , Iterator , Any , List , Literal
56
67import numpy as np
78
2627_MAX_TRUNCATE_ATTEMPTS = 8
2728
2829
def estimate_tokens(text: str) -> int:
    """Approximate the token count of ``text`` without a provider tokenizer.

    Two rough estimates are computed and the larger one is used:
    1.3 tokens per whitespace-delimited word, and one token per four
    characters. The winning value is truncated to an integer.
    """
    word_count = sum(1 for _ in re.finditer(r"\S+", text))
    by_words = word_count * 1.3
    by_chars = len(text) / 4
    larger = by_words if by_words >= by_chars else by_chars
    return int(larger)
35+
36+
class EmbeddingTokenOverflowError(RuntimeError):
    """Signals an embedding failure caused by input length.

    Raised when embedding fails due to input length and ``on_token_overflow``
    is ``error``.
    """
3139
@@ -44,6 +52,10 @@ class LLMEmbed(AbstractFieldSegment):
4452 ``error`` (default), ``truncate`` (shrink and retry), or ``chunk_pool`` (split into
4553 ``num_chunks`` segments, embed, and mean-pool). Size text before this segment with
4654 upstream chunking when possible.
55+
56+ ``max_estimated_tokens`` optionally truncates text before the provider call using
57+ a lightweight estimate, not a tokenizer. ``truncate_side`` controls both that
58+ proactive truncation and reactive ``on_token_overflow="truncate"`` retry behavior.
4759 """
4860
4961 def __init__ (
@@ -66,6 +78,10 @@ def __init__(
6678 int ,
6779 "For chunk_pool: number of contiguous segments to split overflow text into" ,
6880 ] = 2 ,
81+ max_estimated_tokens : Annotated [
82+ Optional [int ],
83+ "If set, pre-truncate text to this estimated token budget before embedding" ,
84+ ] = None ,
6985 ):
7086 """Initialize the embedding segment with the specified parameters.
7187
@@ -96,18 +112,22 @@ def __init__(
96112 )
97113 if num_chunks < 2 :
98114 raise ValueError ("num_chunks must be at least 2" )
115+ if max_estimated_tokens is not None and max_estimated_tokens < 1 :
116+ raise ValueError ("max_estimated_tokens must be a positive integer" )
99117 self .embedder = getEmbeddingAdapter (source )(model = model )
100118 self .fail_on_error = fail_on_error
101119 self .batch_size = batch_size
102120 self .on_token_overflow = on_token_overflow
103121 self .truncate_side = truncate_side
104122 self .num_chunks = num_chunks
123+ self .max_estimated_tokens = max_estimated_tokens
105124 self ._embedding_source = source
106125 self ._embedding_model = model
107126
def process_value(self, value: Any) -> List[float]:
    """Embed a single extracted field value (AbstractFieldSegment hook).

    The value is converted to a string, passed through the estimated-token
    pre-truncation step, and then embedded under the overflow policy.
    """
    as_text = str(value)
    budgeted = self._truncate_to_estimated_token_budget(as_text)
    return self._embed_one_with_overflow_policy(None, budgeted)
111131
112132 def _input_value (self , item : Any ) -> Any :
113133 """Extract the value to embed (same rule as AbstractFieldSegment)."""
@@ -137,6 +157,23 @@ def _slice_text(text: str, length: int, side: str) -> str:
137157 return text [start : start + length ]
138158 raise ValueError (f"Unknown truncate_side: { side !r} " )
139159
def _truncate_to_estimated_token_budget(self, text: str) -> str:
    """Shrink ``text`` so its estimated token count fits the configured budget.

    Returns ``text`` unchanged when ``max_estimated_tokens`` is ``None`` or
    the estimate for the full text is already within budget. Otherwise a
    binary search over slice lengths picks the slice to keep, cutting on
    ``truncate_side`` via ``_slice_text``.
    """
    budget = self.max_estimated_tokens
    if budget is None or estimate_tokens(text) <= budget:
        return text

    side = self.truncate_side

    def fits(length: int) -> bool:
        # True when a slice of this many characters is within the budget.
        return estimate_tokens(self._slice_text(text, length, side)) <= budget

    # `best` only ever advances to a length whose slice was seen to fit.
    best, ceiling = 0, len(text)
    while best < ceiling:
        # Round the midpoint up so the loop makes progress when
        # ceiling == best + 1.
        probe = (best + ceiling + 1) // 2
        if fits(probe):
            best = probe
        else:
            ceiling = probe - 1
    return self._slice_text(text, best, side)
176+
140177 @staticmethod
141178 def _split_num_chunks (text : str , num_chunks : int ) -> List [str ]:
142179 n = len (text )
@@ -302,7 +339,7 @@ def flush_buffer() -> Iterator[Any]:
302339
303340 self ._ensure_scalar_item (item )
304341 logging .debug (f"Processing input item: { item } " )
305- text = str (self ._input_value (item ))
342+ text = self . _truncate_to_estimated_token_budget ( str (self ._input_value (item ) ))
306343 logging .debug (f"Embedding text: { text } " )
307344
308345 if self .batch_size <= 1 :
0 commit comments