pass through various IDs and log them in the model server for better… #4485


Merged 2 commits on Apr 10, 2025
6 changes: 4 additions & 2 deletions backend/ee/onyx/main.py
@@ -17,7 +17,9 @@
     basic_router as enterprise_settings_router,
 )
 from ee.onyx.server.manage.standard_answer import router as standard_answer_router
-from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
+from ee.onyx.server.middleware.tenant_tracking import (
+    add_api_server_tenant_id_middleware,
+)
 from ee.onyx.server.oauth.api import router as ee_oauth_router
 from ee.onyx.server.query_and_chat.chat_backend import (
     router as chat_router,
@@ -79,7 +81,7 @@ def get_application() -> FastAPI:
     application = get_application_base(lifespan_override=lifespan)
 
     if MULTI_TENANT:
-        add_tenant_id_middleware(application, logger)
+        add_api_server_tenant_id_middleware(application, logger)
 
     if AUTH_TYPE == AuthType.CLOUD:
         # For Google OAuth, refresh tokens are requested by:
9 changes: 8 additions & 1 deletion backend/ee/onyx/server/middleware/tenant_tracking.py
@@ -18,11 +18,18 @@
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 
 
-def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+def add_api_server_tenant_id_middleware(
+    app: FastAPI, logger: logging.LoggerAdapter
+) -> None:
     @app.middleware("http")
     async def set_tenant_id(
         request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
+        """Extracts the tenant id from multiple locations and sets the context var.
+
+        This is very specific to the api server and probably not something you'd want
+        to use elsewhere.
+        """
         try:
             if MULTI_TENANT:
                 tenant_id = await _get_tenant_id_from_request(request, logger)
2 changes: 2 additions & 0 deletions backend/model_server/main.py
@@ -24,6 +24,7 @@
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import setup_uvicorn_logger
 from onyx.utils.middleware import add_onyx_request_id_middleware
+from onyx.utils.middleware import add_onyx_tenant_id_middleware
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import MIN_THREADS_ML_MODELS
 from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -126,6 +127,7 @@ def get_model_app() -> FastAPI:
     if INDEXING_ONLY:
         request_id_prefix = "IDX"
 
+    add_onyx_tenant_id_middleware(application, logger)
     add_onyx_request_id_middleware(application, request_id_prefix, logger)
 
     # Initialize and instrument the app
8 changes: 8 additions & 0 deletions backend/onyx/background/indexing/run_indexing.py
@@ -58,6 +58,7 @@
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
+from onyx.utils.middleware import make_randomized_onyx_request_id
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
@@ -379,6 +380,7 @@ def _run_indexing(
     memory_tracer.start()
 
     index_attempt_md = IndexAttemptMetadata(
+        attempt_id=index_attempt_id,
         connector_id=ctx.connector_id,
         credential_id=ctx.credential_id,
     )
@@ -481,6 +483,8 @@
 
             batch_description = []
 
+            # Generate an ID that can be used to correlate activity between here
+            # and the embedding model server
             doc_batch_cleaned = strip_null_characters(document_batch)
             for doc in doc_batch_cleaned:
                 batch_description.append(doc.to_short_descriptor())
@@ -502,6 +506,10 @@
 
             logger.debug(f"Indexing batch of documents: {batch_description}")
 
+            index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
+            index_attempt_md.structured_id = (
+                f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
+            )
            index_attempt_md.batch_num = batch_num + 1  # use 1-index for this
 
             # real work happens here!
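For illustration, a minimal sketch (not part of the PR) of what the batch-level structured_id resolves to. Only the format string comes from the hunk above; every value below is invented.

# Invented values; the format mirrors the f-string in the diff above.
tenant_id = "tenant_abc"
cc_pair_id = 42          # stands in for ctx.cc_pair_id
index_attempt_id = 7
batch_num = 3            # 0-indexed loop counter; batch_num + 1 is stored on the metadata

structured_id = f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}"
print(structured_id)     # -> tenant_abc:42:7:3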
7 changes: 6 additions & 1 deletion backend/onyx/connectors/models.py
@@ -272,9 +272,14 @@ class SlimDocument(BaseModel):
 
 
 class IndexAttemptMetadata(BaseModel):
-    batch_num: int | None = None
     connector_id: int
     credential_id: int
+    batch_num: int | None = None
+    attempt_id: int | None = None
+    request_id: str | None = None
+
+    # Work in progress: will likely contain metadata about cc pair / index attempt
+    structured_id: str | None = None
 
 
 class ConnectorCheckpoint(BaseModel):
1 change: 1 addition & 0 deletions backend/onyx/indexing/chunker.py
@@ -135,6 +135,7 @@ def __init__(
         mini_chunk_size: int = MINI_CHUNK_SIZE,
         callback: IndexingHeartbeatInterface | None = None,
     ) -> None:
+        # from llama_index.core.node_parser import SentenceSplitter
        from llama_index.text_splitter import SentenceSplitter
 
         self.include_metadata = include_metadata
32 changes: 28 additions & 4 deletions backend/onyx/indexing/embedder.py
@@ -73,6 +73,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         raise NotImplementedError
 
@@ -110,6 +112,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         """Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
         if they exist. If there is no space for it, it would have been thrown out at the chunking step.
@@ -143,6 +147,8 @@ def embed_chunks(
             texts=flat_chunk_texts,
             text_type=EmbedTextType.PASSAGE,
             large_chunks_present=large_chunks_present,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
         chunk_titles = {
@@ -158,7 +164,10 @@
         title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
-                chunk_titles_list, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list,
+                text_type=EmbedTextType.PASSAGE,
+                tenant_id=tenant_id,
+                request_id=request_id,
             )
             title_embed_dict.update(
                 {
@@ -190,7 +199,10 @@ def embed_chunks(
                     "Title had to be embedded separately, this should not happen!"
                 )
                 title_embedding = self.embedding_model.encode(
-                    [title], text_type=EmbedTextType.PASSAGE
+                    [title],
+                    text_type=EmbedTextType.PASSAGE,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
                 )[0]
                 title_embed_dict[title] = title_embedding
 
@@ -231,14 +243,24 @@ def from_db_search_settings(
 def embed_chunks_with_failure_handling(
     chunks: list[DocAwareChunk],
     embedder: IndexingEmbedder,
+    tenant_id: str | None = None,
+    request_id: str | None = None,
 ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
     """Tries to embed all chunks in one large batch. If that batch fails for any reason,
     goes document by document to isolate the failure(s).
     """
 
+    # TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
+    # Improve this if needed.
+
     # First try to embed all chunks in one batch
     try:
-        return embedder.embed_chunks(chunks=chunks), []
+        return (
+            embedder.embed_chunks(
+                chunks=chunks, tenant_id=tenant_id, request_id=request_id
+            ),
+            [],
+        )
     except Exception:
         logger.exception("Failed to embed chunk batch. Trying individual docs.")
         # wait a couple seconds to let any rate limits or temporary issues resolve
@@ -254,7 +276,9 @@
 
     for doc_id, chunks_for_doc in chunks_by_doc.items():
         try:
-            doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
+            doc_embedded_chunks = embedder.embed_chunks(
+                chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
+            )
             embedded_chunks.extend(doc_embedded_chunks)
         except Exception as e:
             logger.exception(f"Failed to embed chunks for document '{doc_id}'")
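The wrapper above is a plain batch-then-per-document fallback: embed everything at once, and only on failure regroup by document so one bad document does not sink the whole batch. A generic, self-contained sketch of that pattern follows; the names are illustrative only and are not Onyx APIs.

from collections import defaultdict
from typing import Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def embed_with_fallback(
    items: list[T],
    key_fn: Callable[[T], str],
    embed_fn: Callable[[list[T]], list[R]],
) -> tuple[list[R], list[str]]:
    # Fast path: one large batch.
    try:
        return embed_fn(items), []
    except Exception:
        pass

    # Slow path: group by key (e.g. document id) and isolate the failures.
    grouped: dict[str, list[T]] = defaultdict(list)
    for item in items:
        grouped[key_fn(item)].append(item)

    results: list[R] = []
    failed_keys: list[str] = []
    for key, group in grouped.items():
        try:
            results.extend(embed_fn(group))
        except Exception:
            failed_keys.append(key)
    return results, failed_keys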
2 changes: 2 additions & 0 deletions backend/onyx/indexing/indexing_pipeline.py
@@ -791,6 +791,8 @@ def index_doc_batch(
         embed_chunks_with_failure_handling(
             chunks=chunks,
             embedder=embedder,
+            tenant_id=tenant_id,
+            request_id=index_attempt_metadata.request_id,
         )
         if chunks
         else ([], [])
54 changes: 48 additions & 6 deletions backend/onyx/natural_language_processing/search_nlp_models.py
@@ -3,6 +3,7 @@
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from functools import wraps
 from typing import Any
 
@@ -114,10 +115,24 @@ def __init__(
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
-    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+    def _make_model_server_request(
+        self,
+        embed_request: EmbedRequest,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
+    ) -> EmbedResponse:
         def _make_request() -> Response:
+            headers = {}
+            if tenant_id:
+                headers["X-Onyx-Tenant-ID"] = tenant_id
+
+            if request_id:
+                headers["X-Onyx-Request-ID"] = request_id
+
             response = requests.post(
-                self.embed_server_endpoint, json=embed_request.model_dump()
+                self.embed_server_endpoint,
+                headers=headers,
+                json=embed_request.model_dump(),
             )
             # signify that this is a rate limit error
             if response.status_code == 429:
@@ -165,6 +180,8 @@ def _batch_encode_texts(
         batch_size: int,
         max_seq_length: int,
         num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)
 
@@ -175,7 +192,11 @@
         embeddings: list[Embedding] = []
 
         def process_batch(
-            batch_idx: int, batch_len: int, text_batch: list[str]
+            batch_idx: int,
+            batch_len: int,
+            text_batch: list[str],
+            tenant_id: str | None = None,
+            request_id: str | None = None,
         ) -> tuple[int, list[Embedding]]:
             if self.callback:
                 if self.callback.should_stop():
@@ -198,7 +219,9 @@ def process_batch(
             )
 
             start_time = time.time()
-            response = self._make_model_server_request(embed_request)
+            response = self._make_model_server_request(
+                embed_request, tenant_id=tenant_id, request_id=request_id
+            )
             end_time = time.time()
 
             processing_time = end_time - start_time
@@ -215,7 +238,16 @@ def process_batch(
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
                 future_to_batch = {
-                    executor.submit(process_batch, idx, len(text_batches), batch): idx
+                    executor.submit(
+                        partial(
+                            process_batch,
+                            idx,
+                            len(text_batches),
+                            batch,
+                            tenant_id=tenant_id,
+                            request_id=request_id,
+                        )
+                    ): idx
                     for idx, batch in enumerate(text_batches, start=1)
                 }
 
@@ -238,7 +270,13 @@ def process_batch(
         else:
             # Original sequential processing
             for idx, text_batch in enumerate(text_batches, start=1):
-                _, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
+                _, batch_embeddings = process_batch(
+                    idx,
+                    len(text_batches),
+                    text_batch,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
+                )
                 embeddings.extend(batch_embeddings)
                 if self.callback:
                     self.callback.progress("_batch_encode_texts", 1)
@@ -253,6 +291,8 @@ def encode(
         local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -284,6 +324,8 @@ def encode(
             text_type=text_type,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
     @classmethod
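On the wire, the change amounts to two extra headers on the existing embed call. A hedged sketch of that request, assuming a locally running model server; the header names and endpoint path come from the diff above, while the host, port, ID values, and payload are simplified placeholders (the real EmbedRequest carries more fields).

import requests

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed",  # assumed host/port
    headers={
        "X-Onyx-Tenant-ID": "tenant_abc",      # only sent when a tenant id is known
        "X-Onyx-Request-ID": "CIX-a1b2c3d4",   # assumed shape of the request id
    },
    json={"texts": ["example passage"], "text_type": "passage"},  # simplified payload
)
response.raise_for_status()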
14 changes: 14 additions & 0 deletions backend/onyx/utils/middleware.py
@@ -11,9 +11,23 @@
 from fastapi import Request
 from fastapi import Response
 
+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
 
 
+def add_onyx_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+    @app.middleware("http")
+    async def set_tenant_id(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        """Captures and sets the context var for the tenant."""
+
+        onyx_tenant_id = request.headers.get("X-Onyx-Tenant-ID")
+        if onyx_tenant_id:
+            CURRENT_TENANT_ID_CONTEXTVAR.set(onyx_tenant_id)
+        return await call_next(request)
+
+
 def add_onyx_request_id_middleware(
     app: FastAPI, prefix: str, logger: logging.LoggerAdapter
 ) -> None:
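A small usage sketch (not part of the PR) of the new middleware picking the tenant out of the header and exposing it through the context var; the import paths follow the diff, while the app, route, and tenant value are made up.

import logging

from fastapi import FastAPI
from fastapi.testclient import TestClient

from onyx.utils.middleware import add_onyx_tenant_id_middleware
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

app = FastAPI()
add_onyx_tenant_id_middleware(
    app, logging.LoggerAdapter(logging.getLogger(__name__), {})
)


@app.get("/ping")
async def ping() -> dict[str, str | None]:
    # The middleware should already have set the context var from X-Onyx-Tenant-ID.
    return {"tenant_id": CURRENT_TENANT_ID_CONTEXTVAR.get(None)}


client = TestClient(app)
print(client.get("/ping", headers={"X-Onyx-Tenant-ID": "tenant_abc"}).json())
# expected: {"tenant_id": "tenant_abc"}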
8 changes: 6 additions & 2 deletions backend/tests/unit/onyx/indexing/test_embedder.py
@@ -88,14 +88,18 @@ def test_default_indexing_embedder_embed_chunks(
     )
     assert result[0].title_embedding == [7.0, 8.0, 9.0]
 
-    # Verify the embedding model was called correctly
+    # Verify the embedding model was called exactly as follows
     mock_embedding_model.return_value.encode.assert_any_call(
         texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
         text_type=EmbedTextType.PASSAGE,
         large_chunks_present=False,
+        tenant_id=None,
+        request_id=None,
     )
-    # title only embedding call
+    # Same for title only embedding call
     mock_embedding_model.return_value.encode.assert_any_call(
         ["Test Document"],
         text_type=EmbedTextType.PASSAGE,
+        tenant_id=None,
+        request_id=None,
     )