diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py
index e47a193cda..8ecd6b2588 100644
--- a/backend/ee/onyx/main.py
+++ b/backend/ee/onyx/main.py
@@ -17,7 +17,9 @@
     basic_router as enterprise_settings_router,
 )
 from ee.onyx.server.manage.standard_answer import router as standard_answer_router
-from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
+from ee.onyx.server.middleware.tenant_tracking import (
+    add_api_server_tenant_id_middleware,
+)
 from ee.onyx.server.oauth.api import router as ee_oauth_router
 from ee.onyx.server.query_and_chat.chat_backend import (
     router as chat_router,
@@ -79,7 +81,7 @@ def get_application() -> FastAPI:
     application = get_application_base(lifespan_override=lifespan)
 
     if MULTI_TENANT:
-        add_tenant_id_middleware(application, logger)
+        add_api_server_tenant_id_middleware(application, logger)
 
     if AUTH_TYPE == AuthType.CLOUD:
         # For Google OAuth, refresh tokens are requested by:
diff --git a/backend/ee/onyx/server/middleware/tenant_tracking.py b/backend/ee/onyx/server/middleware/tenant_tracking.py
index efae1fb3e6..390711f6f3 100644
--- a/backend/ee/onyx/server/middleware/tenant_tracking.py
+++ b/backend/ee/onyx/server/middleware/tenant_tracking.py
@@ -18,11 +18,18 @@
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 
 
-def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+def add_api_server_tenant_id_middleware(
+    app: FastAPI, logger: logging.LoggerAdapter
+) -> None:
     @app.middleware("http")
     async def set_tenant_id(
         request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
+        """Extracts the tenant id from multiple locations and sets the context var.
+
+        This is very specific to the api server and probably not something you'd want
+        to use elsewhere.
+        """
         try:
             if MULTI_TENANT:
                 tenant_id = await _get_tenant_id_from_request(request, logger)
diff --git a/backend/model_server/main.py b/backend/model_server/main.py
index 4921f9c503..bdd5d42656 100644
--- a/backend/model_server/main.py
+++ b/backend/model_server/main.py
@@ -24,6 +24,7 @@
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import setup_uvicorn_logger
 from onyx.utils.middleware import add_onyx_request_id_middleware
+from onyx.utils.middleware import add_onyx_tenant_id_middleware
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import MIN_THREADS_ML_MODELS
 from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -126,6 +127,7 @@ def get_model_app() -> FastAPI:
     if INDEXING_ONLY:
         request_id_prefix = "IDX"
 
+    add_onyx_tenant_id_middleware(application, logger)
     add_onyx_request_id_middleware(application, request_id_prefix, logger)
 
     # Initialize and instrument the app
diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py
index 850e611c6e..5799acb737 100644
--- a/backend/onyx/background/indexing/run_indexing.py
+++ b/backend/onyx/background/indexing/run_indexing.py
@@ -58,6 +58,7 @@
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
+from onyx.utils.middleware import make_randomized_onyx_request_id
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
@@ -379,6 +380,7 @@ def _run_indexing(
     memory_tracer.start()
 
     index_attempt_md = IndexAttemptMetadata(
+        attempt_id=index_attempt_id,
         connector_id=ctx.connector_id,
         credential_id=ctx.credential_id,
     )
@@ -481,6 +483,8 @@ def _run_indexing(
             batch_description = []
 
+            # Generate an ID that can be used to correlate activity between here
+            # and the embedding model server
            doc_batch_cleaned = strip_null_characters(document_batch)
             for doc in doc_batch_cleaned:
                 batch_description.append(doc.to_short_descriptor())
@@ -502,6 +506,10 @@ def _run_indexing(
 
             logger.debug(f"Indexing batch of documents: {batch_description}")
 
+            index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
+            index_attempt_md.structured_id = (
+                f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
+            )
             index_attempt_md.batch_num = batch_num + 1  # use 1-index for this
 
             # real work happens here!
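Note on the hunks above: the run_indexing changes mint a fresh request ID per batch (the "CIX" prefix) and the model_server changes install header-reading middleware, so a single indexing batch can be traced across both services; the sending side, which attaches X-Onyx-Tenant-ID and X-Onyx-Request-ID to the embedding call, appears further down in search_nlp_models.py. Below is a minimal sketch of that handshake from the caller's perspective, assuming only the requests library; make_request_id is a hypothetical stand-in for make_randomized_onyx_request_id, and the endpoint and payload shapes are illustrative rather than Onyx's actual API.

    # Hypothetical sketch -- not part of the diff.
    import secrets

    import requests


    def make_request_id(prefix: str) -> str:
        # e.g. "CIX:9f3a1c2b" -- the prefix identifies the calling subsystem,
        # the random suffix distinguishes individual calls
        return f"{prefix}:{secrets.token_hex(4)}"


    def embed_batch(
        endpoint: str, texts: list[str], tenant_id: str | None
    ) -> requests.Response:
        headers = {"X-Onyx-Request-ID": make_request_id("CIX")}
        if tenant_id:
            # lets the model server attribute its logs to the calling tenant
            headers["X-Onyx-Tenant-ID"] = tenant_id
        return requests.post(endpoint, json={"texts": texts}, headers=headers)

Minting the ID per call, as sketched here, would also be one way to address the TODO(rkuo) note further down about disambiguating retries.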
diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py
index 4fe586897b..1920e0d2f6 100644
--- a/backend/onyx/connectors/models.py
+++ b/backend/onyx/connectors/models.py
@@ -272,9 +272,14 @@ class SlimDocument(BaseModel):
 
 
 class IndexAttemptMetadata(BaseModel):
-    batch_num: int | None = None
     connector_id: int
     credential_id: int
+    batch_num: int | None = None
+    attempt_id: int | None = None
+    request_id: str | None = None
+
+    # Work in progress: currently encodes tenant / cc pair / index attempt / batch num
+    structured_id: str | None = None
 
 
 class ConnectorCheckpoint(BaseModel):
diff --git a/backend/onyx/indexing/chunker.py b/backend/onyx/indexing/chunker.py
index 3054378ff3..ffa3052748 100644
--- a/backend/onyx/indexing/chunker.py
+++ b/backend/onyx/indexing/chunker.py
@@ -135,6 +135,8 @@ def __init__(
         mini_chunk_size: int = MINI_CHUNK_SIZE,
         callback: IndexingHeartbeatInterface | None = None,
     ) -> None:
+        # NOTE: newer llama_index versions expose this as
+        # from llama_index.core.node_parser import SentenceSplitter
         from llama_index.text_splitter import SentenceSplitter
 
         self.include_metadata = include_metadata
diff --git a/backend/onyx/indexing/embedder.py b/backend/onyx/indexing/embedder.py
index 78ea96340d..3d5f663dd4 100644
--- a/backend/onyx/indexing/embedder.py
+++ b/backend/onyx/indexing/embedder.py
@@ -73,6 +73,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         raise NotImplementedError
 
@@ -110,6 +112,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         """Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
         if they exist. If there is no space for it, it would have been thrown out at the chunking step.
@@ -143,6 +147,8 @@ def embed_chunks(
             texts=flat_chunk_texts,
             text_type=EmbedTextType.PASSAGE,
             large_chunks_present=large_chunks_present,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
         chunk_titles = {
@@ -158,7 +164,10 @@ def embed_chunks(
         title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
-                chunk_titles_list, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list,
+                text_type=EmbedTextType.PASSAGE,
+                tenant_id=tenant_id,
+                request_id=request_id,
             )
             title_embed_dict.update(
                 {
@@ -190,7 +199,10 @@ def embed_chunks(
                     "Title had to be embedded separately, this should not happen!"
                 )
                 title_embedding = self.embedding_model.encode(
-                    [title], text_type=EmbedTextType.PASSAGE
+                    [title],
+                    text_type=EmbedTextType.PASSAGE,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
                 )[0]
                 title_embed_dict[title] = title_embedding
 
@@ -231,14 +243,24 @@ def from_db_search_settings(
 def embed_chunks_with_failure_handling(
     chunks: list[DocAwareChunk],
     embedder: IndexingEmbedder,
+    tenant_id: str | None = None,
+    request_id: str | None = None,
 ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
     """Tries to embed all chunks in one large batch. If that batch fails for any reason,
     goes document by document to isolate the failure(s).
     """
+    # TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
+    # Improve this if needed.
+
     # First try to embed all chunks in one batch
     try:
-        return embedder.embed_chunks(chunks=chunks), []
+        return (
+            embedder.embed_chunks(
+                chunks=chunks, tenant_id=tenant_id, request_id=request_id
+            ),
+            [],
+        )
     except Exception:
         logger.exception("Failed to embed chunk batch. Trying individual docs.")
 
         # wait a couple seconds to let any rate limits or temporary issues resolve
@@ -254,7 +276,9 @@ def embed_chunks_with_failure_handling(
 
     for doc_id, chunks_for_doc in chunks_by_doc.items():
         try:
-            doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
+            doc_embedded_chunks = embedder.embed_chunks(
+                chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
+            )
             embedded_chunks.extend(doc_embedded_chunks)
         except Exception as e:
             logger.exception(f"Failed to embed chunks for document '{doc_id}'")
diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py
index dbc6ad8c2f..fc94914a6f 100644
--- a/backend/onyx/indexing/indexing_pipeline.py
+++ b/backend/onyx/indexing/indexing_pipeline.py
@@ -791,6 +791,8 @@ def index_doc_batch(
         embed_chunks_with_failure_handling(
             chunks=chunks,
             embedder=embedder,
+            tenant_id=tenant_id,
+            request_id=index_attempt_metadata.request_id,
         )
         if chunks
         else ([], [])
diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py
index 8d2dc940da..64f2fc217c 100644
--- a/backend/onyx/natural_language_processing/search_nlp_models.py
+++ b/backend/onyx/natural_language_processing/search_nlp_models.py
@@ -3,6 +3,7 @@
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from functools import wraps
 from typing import Any
 
@@ -114,10 +115,24 @@ def __init__(
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
-    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+    def _make_model_server_request(
+        self,
+        embed_request: EmbedRequest,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
+    ) -> EmbedResponse:
         def _make_request() -> Response:
+            headers = {}
+            if tenant_id:
+                headers["X-Onyx-Tenant-ID"] = tenant_id
+
+            if request_id:
+                headers["X-Onyx-Request-ID"] = request_id
+
             response = requests.post(
-                self.embed_server_endpoint, json=embed_request.model_dump()
+                self.embed_server_endpoint,
+                headers=headers,
+                json=embed_request.model_dump(),
             )
             # signify that this is a rate limit error
             if response.status_code == 429:
@@ -165,6 +180,8 @@ def _batch_encode_texts(
         batch_size: int,
         max_seq_length: int,
         num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)
 
@@ -175,7 +192,11 @@
         embeddings: list[Embedding] = []
 
         def process_batch(
-            batch_idx: int, batch_len: int, text_batch: list[str]
+            batch_idx: int,
+            batch_len: int,
+            text_batch: list[str],
+            tenant_id: str | None = None,
+            request_id: str | None = None,
         ) -> tuple[int, list[Embedding]]:
             if self.callback:
                 if self.callback.should_stop():
@@ -198,7 +219,9 @@
             )
 
             start_time = time.time()
-            response = self._make_model_server_request(embed_request)
+            response = self._make_model_server_request(
+                embed_request, tenant_id=tenant_id, request_id=request_id
+            )
             end_time = time.time()
 
             processing_time = end_time - start_time
@@ -215,7 +238,16 @@
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
                 future_to_batch = {
-                    executor.submit(process_batch, idx, len(text_batches), batch): idx
+                    executor.submit(
+                        partial(
+                            process_batch,
+                            idx,
+                            len(text_batches),
+                            batch,
+                            tenant_id=tenant_id,
+                            request_id=request_id,
+                        )
+                    ): idx
                     for idx, batch in enumerate(text_batches, start=1)
                 }
 
@@ -238,7 +270,13 @@
         else:
             # Original sequential processing
             for idx, text_batch in enumerate(text_batches, start=1):
-                _, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
+                _, batch_embeddings = process_batch(
+                    idx,
+                    len(text_batches),
+                    text_batch,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
+                )
                 embeddings.extend(batch_embeddings)
                 if self.callback:
                     self.callback.progress("_batch_encode_texts", 1)
@@ -253,6 +291,8 @@ def encode(
         local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -284,6 +324,8 @@ def encode(
             text_type=text_type,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
     @classmethod
diff --git a/backend/onyx/utils/middleware.py b/backend/onyx/utils/middleware.py
index f2fffb5ac8..7db7241161 100644
--- a/backend/onyx/utils/middleware.py
+++ b/backend/onyx/utils/middleware.py
@@ -11,9 +11,23 @@
 from fastapi import Request
 from fastapi import Response
 
+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
 
 
+def add_onyx_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+    @app.middleware("http")
+    async def set_tenant_id(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        """Captures and sets the context var for the tenant."""
+
+        onyx_tenant_id = request.headers.get("X-Onyx-Tenant-ID")
+        if onyx_tenant_id:
+            CURRENT_TENANT_ID_CONTEXTVAR.set(onyx_tenant_id)
+        return await call_next(request)
+
+
 def add_onyx_request_id_middleware(
     app: FastAPI, prefix: str, logger: logging.LoggerAdapter
 ) -> None:
diff --git a/backend/tests/unit/onyx/indexing/test_embedder.py b/backend/tests/unit/onyx/indexing/test_embedder.py
index d49d28344e..f7c213e1bf 100644
--- a/backend/tests/unit/onyx/indexing/test_embedder.py
+++ b/backend/tests/unit/onyx/indexing/test_embedder.py
@@ -88,14 +88,18 @@ def test_default_indexing_embedder_embed_chunks(
     )
     assert result[0].title_embedding == [7.0, 8.0, 9.0]
 
-    # Verify the embedding model was called correctly
+    # Verify the embedding model was called with exactly these arguments
     mock_embedding_model.return_value.encode.assert_any_call(
         texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
         text_type=EmbedTextType.PASSAGE,
         large_chunks_present=False,
+        tenant_id=None,
+        request_id=None,
     )
 
-    # title only embedding call
+    # Same check for the title-only embedding call
     mock_embedding_model.return_value.encode.assert_any_call(
         ["Test Document"],
         text_type=EmbedTextType.PASSAGE,
+        tenant_id=None,
+        request_id=None,
     )
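End-to-end behavior of the new tenant middleware is easy to sanity-check outside Onyx. A minimal self-contained sketch, assuming FastAPI plus its test client, with a plain ContextVar standing in for CURRENT_TENANT_ID_CONTEXTVAR:

    # Hypothetical sketch -- not part of the diff.
    from collections.abc import Awaitable, Callable
    from contextvars import ContextVar

    from fastapi import FastAPI, Request, Response
    from fastapi.testclient import TestClient

    TENANT_ID: ContextVar[str | None] = ContextVar("tenant_id", default=None)

    app = FastAPI()


    @app.middleware("http")
    async def set_tenant_id(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        # copy the header into the context var before handling the request,
        # so anything downstream can read it without an explicit parameter
        tenant_id = request.headers.get("X-Onyx-Tenant-ID")
        if tenant_id:
            TENANT_ID.set(tenant_id)
        return await call_next(request)


    @app.get("/whoami")
    async def whoami() -> dict[str, str | None]:
        return {"tenant_id": TENANT_ID.get()}


    if __name__ == "__main__":
        client = TestClient(app)
        assert client.get("/whoami", headers={"X-Onyx-Tenant-ID": "t1"}).json() == {
            "tenant_id": "t1"
        }
        assert client.get("/whoami").json() == {"tenant_id": None}

Because context vars are task-scoped, a value set in the middleware is visible only for the remainder of that request, which is what makes the pattern safe under concurrency and lets the diff thread tenant identity through to the model server without touching every function signature.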