@@ -3,6 +3,7 @@
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from functools import wraps
 from typing import Any

@@ -114,10 +115,24 @@ def __init__(
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

-    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+    def _make_model_server_request(
+        self,
+        embed_request: EmbedRequest,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
+    ) -> EmbedResponse:
         def _make_request() -> Response:
+            headers = {}
+            if tenant_id:
+                headers["X-Onyx-Tenant-ID"] = tenant_id
+
+            if request_id:
+                headers["X-Onyx-Request-ID"] = request_id
+
             response = requests.post(
-                self.embed_server_endpoint, json=embed_request.model_dump()
+                self.embed_server_endpoint,
+                headers=headers,
+                json=embed_request.model_dump(),
             )
             # signify that this is a rate limit error
             if response.status_code == 429:
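
For reference, a minimal standalone sketch of the header-forwarding pattern this hunk introduces: the tenant and request identifiers are attached as X-Onyx-Tenant-ID and X-Onyx-Request-ID headers only when they are set, so calls without that context stay unchanged on the wire. The endpoint and payload below are placeholders rather than the real Onyx request types.

import requests


def post_with_context_headers(
    endpoint: str,
    payload: dict,
    tenant_id: str | None = None,
    request_id: str | None = None,
) -> requests.Response:
    # Attach only the headers that are actually set; requests made without
    # tenant/request context keep their original shape.
    headers: dict[str, str] = {}
    if tenant_id:
        headers["X-Onyx-Tenant-ID"] = tenant_id
    if request_id:
        headers["X-Onyx-Request-ID"] = request_id
    return requests.post(endpoint, headers=headers, json=payload)
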
@@ -165,6 +180,8 @@ def _batch_encode_texts(
         batch_size: int,
         max_seq_length: int,
         num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)

@@ -175,7 +192,11 @@ def _batch_encode_texts(
         embeddings: list[Embedding] = []

         def process_batch(
-            batch_idx: int, batch_len: int, text_batch: list[str]
+            batch_idx: int,
+            batch_len: int,
+            text_batch: list[str],
+            tenant_id: str | None = None,
+            request_id: str | None = None,
         ) -> tuple[int, list[Embedding]]:
             if self.callback:
                 if self.callback.should_stop():
@@ -198,7 +219,9 @@ def process_batch(
             )

             start_time = time.time()
-            response = self._make_model_server_request(embed_request)
+            response = self._make_model_server_request(
+                embed_request, tenant_id=tenant_id, request_id=request_id
+            )
             end_time = time.time()

             processing_time = end_time - start_time
@@ -215,7 +238,16 @@ def process_batch(
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
                 future_to_batch = {
-                    executor.submit(process_batch, idx, len(text_batches), batch): idx
+                    executor.submit(
+                        partial(
+                            process_batch,
+                            idx,
+                            len(text_batches),
+                            batch,
+                            tenant_id=tenant_id,
+                            request_id=request_id,
+                        )
+                    ): idx
                     for idx, batch in enumerate(text_batches, start=1)
                 }

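
The threaded path now binds the extra keyword arguments with functools.partial before handing the callable to executor.submit, while as_completed still maps each future back to its batch index. A standalone sketch of that pattern, with a stand-in worker in place of process_batch:

from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial


def fake_process_batch(
    batch_idx: int, batch_len: int, batch: list[str], tenant_id: str | None = None
) -> tuple[int, list[str]]:
    # Stand-in for process_batch: return the index so results can be reordered.
    return batch_idx, [f"{tenant_id}:{text}" for text in batch]


batches = [["a", "b"], ["c"], ["d", "e"]]
results: dict[int, list[str]] = {}

with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_batch = {
        executor.submit(
            partial(fake_process_batch, idx, len(batches), batch, tenant_id="acme")
        ): idx
        for idx, batch in enumerate(batches, start=1)
    }
    for future in as_completed(future_to_batch):
        batch_idx, batch_result = future.result()
        results[batch_idx] = batch_result

ordered = [results[i] for i in sorted(results)]  # back in submission order

Passing the keyword arguments directly to executor.submit(process_batch, idx, len(text_batches), batch, tenant_id=tenant_id, request_id=request_id) would forward them just as well; partial simply keeps the dictionary key (: idx) on its own readable line.
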
@@ -238,7 +270,13 @@ def process_batch(
         else:
             # Original sequential processing
             for idx, text_batch in enumerate(text_batches, start=1):
-                _, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
+                _, batch_embeddings = process_batch(
+                    idx,
+                    len(text_batches),
+                    text_batch,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
+                )
                 embeddings.extend(batch_embeddings)
                 if self.callback:
                     self.callback.progress("_batch_encode_texts", 1)
@@ -253,6 +291,8 @@ def encode(
         local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -284,6 +324,8 @@ def encode(
             text_type=text_type,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )

     @classmethod
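
A hypothetical call site for the updated encode signature: apart from the new tenant_id and request_id keyword arguments, the parameter names mirror what is visible in this diff, while the model instance, text type, and ID values are made-up placeholders, so this is an illustration rather than runnable code.

# Illustrative only: assumes an already-constructed embedding model instance.
embeddings = embedding_model.encode(
    texts=["first chunk of the document", "second chunk of the document"],
    text_type=EmbedTextType.PASSAGE,
    tenant_id="tenant-1234",
    request_id="req-5678",
)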