
Commit e1d7582

Richard Kuo (Onyx) committed:

pass through various id's and log them in the model server for better tracking

1 parent: 9b6c762

File tree

10 files changed: +120 −14 lines


Diff for: backend/ee/onyx/main.py

+4 −2

```diff
@@ -17,7 +17,9 @@
     basic_router as enterprise_settings_router,
 )
 from ee.onyx.server.manage.standard_answer import router as standard_answer_router
-from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
+from ee.onyx.server.middleware.tenant_tracking import (
+    add_api_server_tenant_id_middleware,
+)
 from ee.onyx.server.oauth.api import router as ee_oauth_router
 from ee.onyx.server.query_and_chat.chat_backend import (
     router as chat_router,
@@ -79,7 +81,7 @@ def get_application() -> FastAPI:
     application = get_application_base(lifespan_override=lifespan)
 
     if MULTI_TENANT:
-        add_tenant_id_middleware(application, logger)
+        add_api_server_tenant_id_middleware(application, logger)
 
     if AUTH_TYPE == AuthType.CLOUD:
         # For Google OAuth, refresh tokens are requested by:
```

Diff for: backend/ee/onyx/server/middleware/tenant_tracking.py

+8 −1

```diff
@@ -18,11 +18,18 @@
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 
 
-def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+def add_api_server_tenant_id_middleware(
+    app: FastAPI, logger: logging.LoggerAdapter
+) -> None:
     @app.middleware("http")
     async def set_tenant_id(
         request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
+        """Extracts the tenant id from multiple locations and sets the context var.
+
+        This is very specific to the api server and probably not something you'd want
+        to use elsewhere.
+        """
         try:
             if MULTI_TENANT:
                 tenant_id = await _get_tenant_id_from_request(request, logger)
```
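For background, the middleware relies on Python's contextvars: once the tenant id is stored in CURRENT_TENANT_ID_CONTEXTVAR, any code running in the same request context can read it without the id being threaded through every call. A minimal sketch of the mechanism, assuming a definition like the one below (the real variable lives in shared_configs.contextvars and may use a different default):

```python
from contextvars import ContextVar

# Assumed shape of the shared context var; the real definition is in
# shared_configs.contextvars and may differ.
CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str | None] = ContextVar(
    "current_tenant_id", default=None
)


def handle_request(tenant_id: str) -> None:
    # set() is scoped to the current task/thread context, so concurrent
    # requests each see their own tenant id.
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        do_work()  # downstream code reads the var instead of taking a parameter
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


def do_work() -> None:
    print(f"tenant={CURRENT_TENANT_ID_CONTEXTVAR.get()}")


handle_request("tenant_abc")  # prints "tenant=tenant_abc"
```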

Diff for: backend/model_server/main.py

+2 −0

```diff
@@ -24,6 +24,7 @@
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import setup_uvicorn_logger
 from onyx.utils.middleware import add_onyx_request_id_middleware
+from onyx.utils.middleware import add_onyx_tenant_id_middleware
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import MIN_THREADS_ML_MODELS
 from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -126,6 +127,7 @@ def get_model_app() -> FastAPI:
     if INDEXING_ONLY:
         request_id_prefix = "IDX"
 
+    add_onyx_tenant_id_middleware(application, logger)
     add_onyx_request_id_middleware(application, request_id_prefix, logger)
 
     # Initialize and instrument the app
```
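Since the commit's stated goal is logging these ids in the model server, one plausible way to surface them in log lines is a logging.Filter that reads the context var set by the middleware. This is a hedged sketch: the actual Onyx logger setup differs, and the filter class and format string below are invented for illustration.

```python
import logging

from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


class TenantIdLogFilter(logging.Filter):
    """Stamps each record with the tenant id set by the middleware (illustrative)."""

    def filter(self, record: logging.LogRecord) -> bool:
        record.tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get(None) or "-"
        return True


handler = logging.StreamHandler()
handler.addFilter(TenantIdLogFilter())
handler.setFormatter(logging.Formatter("%(asctime)s [%(tenant_id)s] %(message)s"))
logging.getLogger("model_server").addHandler(handler)
```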

Diff for: backend/onyx/background/indexing/run_indexing.py

+8 −0

```diff
@@ -58,6 +58,7 @@
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
+from onyx.utils.middleware import make_randomized_onyx_request_id
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
@@ -379,6 +380,7 @@ def _run_indexing(
     memory_tracer.start()
 
     index_attempt_md = IndexAttemptMetadata(
+        attempt_id=index_attempt_id,
         connector_id=ctx.connector_id,
         credential_id=ctx.credential_id,
     )
@@ -481,6 +483,8 @@ def _run_indexing(
 
             batch_description = []
 
+            # Generate an ID that can be used to correlate activity between here
+            # and the embedding model server
             doc_batch_cleaned = strip_null_characters(document_batch)
             for doc in doc_batch_cleaned:
                 batch_description.append(doc.to_short_descriptor())
@@ -502,6 +506,10 @@ def _run_indexing(
 
             logger.debug(f"Indexing batch of documents: {batch_description}")
 
+            index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
+            index_attempt_md.structured_id = (
+                f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
+            )
             index_attempt_md.batch_num = batch_num + 1  # use 1-index for this
 
             # real work happens here!
```
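make_randomized_onyx_request_id itself is not shown in this commit. Judging only from how it is called here (and the "IDX" prefix used in the model server), a plausible, purely hypothetical shape is a short caller prefix plus a random suffix; treat this as an illustration, not the real implementation:

```python
import uuid


def make_randomized_onyx_request_id(prefix: str) -> str:
    # Hypothetical sketch: the prefix identifies the caller (presumably
    # "CIX" = connector indexing), and the random hex suffix makes each
    # batch's call to the model server distinguishable in logs.
    return f"{prefix}:{uuid.uuid4().hex[:16]}"


print(make_randomized_onyx_request_id("CIX"))  # e.g. "CIX:9f1c2b7a4e5d6f80"
```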

Diff for: backend/onyx/connectors/models.py

+6 −1

```diff
@@ -272,9 +272,14 @@ class SlimDocument(BaseModel):
 
 
 class IndexAttemptMetadata(BaseModel):
-    batch_num: int | None = None
     connector_id: int
     credential_id: int
+    batch_num: int | None = None
+    attempt_id: int | None = None
+    request_id: str | None = None
+
+    # Work in progress: will likely contain metadata about cc pair / index attempt
+    structured_id: str | None = None
 
 
 class ConnectorCheckpoint(BaseModel):
```
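Combined with run_indexing.py above, the expanded model gives each batch enough identity to trace a model-server log line back to tenant, cc pair, attempt, and batch. A small example with made-up values, mirroring how _run_indexing fills the fields:

```python
from onyx.connectors.models import IndexAttemptMetadata

# Illustrative values only
tenant_id, cc_pair_id, index_attempt_id, batch_num = "tenant_abc", 7, 567, 0

md = IndexAttemptMetadata(
    connector_id=12,
    credential_id=34,
    attempt_id=index_attempt_id,
)
md.structured_id = f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}"
md.batch_num = batch_num + 1  # 1-indexed, as in _run_indexing

print(md.structured_id)  # "tenant_abc:7:567:0"
```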

Diff for: backend/onyx/indexing/chunker.py

+1 −0

```diff
@@ -135,6 +135,7 @@ def __init__(
         mini_chunk_size: int = MINI_CHUNK_SIZE,
         callback: IndexingHeartbeatInterface | None = None,
     ) -> None:
+        # from llama_index.core.node_parser import SentenceSplitter
         from llama_index.text_splitter import SentenceSplitter
 
         self.include_metadata = include_metadata
```
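The commented-out line appears to point at the newer llama_index package layout (SentenceSplitter moved under llama_index.core.node_parser as of llama_index 0.10), presumably left as a breadcrumb for a future upgrade.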

Diff for: backend/onyx/indexing/embedder.py

+27 −4

```diff
@@ -73,6 +73,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         raise NotImplementedError
 
@@ -110,6 +112,8 @@ def __init__(
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         """Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
         if they exist. If there is no space for it, it would have been thrown out at the chunking step.
@@ -143,6 +147,8 @@ def embed_chunks(
             texts=flat_chunk_texts,
             text_type=EmbedTextType.PASSAGE,
             large_chunks_present=large_chunks_present,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
         chunk_titles = {
@@ -158,7 +164,10 @@ def embed_chunks(
         title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
-                chunk_titles_list, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list,
+                text_type=EmbedTextType.PASSAGE,
+                tenant_id=tenant_id,
+                request_id=request_id,
             )
             title_embed_dict.update(
                 {
@@ -190,7 +199,9 @@ def embed_chunks(
                     "Title had to be embedded separately, this should not happen!"
                 )
                 title_embedding = self.embedding_model.encode(
-                    [title], text_type=EmbedTextType.PASSAGE
+                    [title],
+                    text_type=EmbedTextType.PASSAGE,
+                    request_id=request_id,
                 )[0]
                 title_embed_dict[title] = title_embedding
 
@@ -231,14 +242,24 @@ def from_db_search_settings(
 def embed_chunks_with_failure_handling(
     chunks: list[DocAwareChunk],
     embedder: IndexingEmbedder,
+    tenant_id: str | None = None,
+    request_id: str | None = None,
 ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
     """Tries to embed all chunks in one large batch. If that batch fails for any reason,
     goes document by document to isolate the failure(s).
     """
 
+    # TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
+    # Improve this if needed.
+
     # First try to embed all chunks in one batch
     try:
-        return embedder.embed_chunks(chunks=chunks), []
+        return (
+            embedder.embed_chunks(
+                chunks=chunks, tenant_id=tenant_id, request_id=request_id
+            ),
+            [],
+        )
     except Exception:
         logger.exception("Failed to embed chunk batch. Trying individual docs.")
         # wait a couple seconds to let any rate limits or temporary issues resolve
@@ -254,7 +275,9 @@ def embed_chunks_with_failure_handling(
 
     for doc_id, chunks_for_doc in chunks_by_doc.items():
         try:
-            doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
+            doc_embedded_chunks = embedder.embed_chunks(
+                chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
+            )
             embedded_chunks.extend(doc_embedded_chunks)
         except Exception as e:
             logger.exception(f"Failed to embed chunks for document '{doc_id}'")
```

Diff for: backend/onyx/indexing/indexing_pipeline.py

+2 −0

```diff
@@ -791,6 +791,8 @@ def index_doc_batch(
         embed_chunks_with_failure_handling(
             chunks=chunks,
             embedder=embedder,
+            tenant_id=tenant_id,
+            request_id=index_attempt_metadata.request_id,
         )
         if chunks
         else ([], [])
```
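Taken together with the previous files, the correlation chain is now: _run_indexing stamps request_id (and structured_id) on IndexAttemptMetadata per batch, index_doc_batch forwards tenant_id and request_id into embed_chunks_with_failure_handling, the embedder threads them down to EmbeddingModel.encode, and _make_model_server_request finally emits them as X-Onyx-Tenant-ID / X-Onyx-Request-ID headers that the model server's middleware picks up.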

Diff for: backend/onyx/natural_language_processing/search_nlp_models.py

+48 −6

```diff
@@ -3,6 +3,7 @@
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from functools import wraps
 from typing import Any
 
@@ -114,10 +115,24 @@ def __init__(
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
-    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+    def _make_model_server_request(
+        self,
+        embed_request: EmbedRequest,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
+    ) -> EmbedResponse:
         def _make_request() -> Response:
+            headers = {}
+            if tenant_id:
+                headers["X-Onyx-Tenant-ID"] = tenant_id
+
+            if request_id:
+                headers["X-Onyx-Request-ID"] = request_id
+
             response = requests.post(
-                self.embed_server_endpoint, json=embed_request.model_dump()
+                self.embed_server_endpoint,
+                headers=headers,
+                json=embed_request.model_dump(),
             )
             # signify that this is a rate limit error
             if response.status_code == 429:
@@ -165,6 +180,8 @@ def _batch_encode_texts(
         batch_size: int,
         max_seq_length: int,
         num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)
 
@@ -175,7 +192,11 @@
         embeddings: list[Embedding] = []
 
         def process_batch(
-            batch_idx: int, batch_len: int, text_batch: list[str]
+            batch_idx: int,
+            batch_len: int,
+            text_batch: list[str],
+            tenant_id: str | None = None,
+            request_id: str | None = None,
         ) -> tuple[int, list[Embedding]]:
             if self.callback:
                 if self.callback.should_stop():
@@ -198,7 +219,9 @@ def process_batch(
             )
 
             start_time = time.time()
-            response = self._make_model_server_request(embed_request)
+            response = self._make_model_server_request(
+                embed_request, tenant_id=tenant_id, request_id=request_id
+            )
             end_time = time.time()
 
             processing_time = end_time - start_time
@@ -215,7 +238,16 @@ def process_batch(
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
                 future_to_batch = {
-                    executor.submit(process_batch, idx, len(text_batches), batch): idx
+                    executor.submit(
+                        partial(
+                            process_batch,
+                            idx,
+                            len(text_batches),
+                            batch,
+                            tenant_id=tenant_id,
+                            request_id=request_id,
+                        )
+                    ): idx
                     for idx, batch in enumerate(text_batches, start=1)
                 }
 
@@ -238,7 +270,13 @@ def process_batch(
         else:
             # Original sequential processing
             for idx, text_batch in enumerate(text_batches, start=1):
-                _, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
+                _, batch_embeddings = process_batch(
+                    idx,
+                    len(text_batches),
+                    text_batch,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
+                )
                 embeddings.extend(batch_embeddings)
                 if self.callback:
                     self.callback.progress("_batch_encode_texts", 1)
@@ -253,6 +291,8 @@ def encode(
         local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -284,6 +324,8 @@
             text_type=text_type,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )
 
     @classmethod
```
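A note on the partial wrapper in the threaded path: it bundles the keyword arguments so each submitted task carries its own tenant_id/request_id. ThreadPoolExecutor.submit also forwards keyword arguments directly, so the two forms below are equivalent (standalone toy example, not Onyx code):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def process(idx: int, total: int, batch: list[str], *, request_id: str | None = None) -> int:
    return idx


with ThreadPoolExecutor(max_workers=2) as executor:
    # As in the diff: wrap args and kwargs into a zero-arg callable first
    f1 = executor.submit(partial(process, 1, 2, ["a"], request_id="CIX:123"))
    # Equivalent: submit forwards *args and **kwargs itself
    f2 = executor.submit(process, 2, 2, ["b"], request_id="CIX:123")
    assert (f1.result(), f2.result()) == (1, 2)
```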

Diff for: backend/onyx/utils/middleware.py

+14 −0

```diff
@@ -11,9 +11,23 @@
 from fastapi import Request
 from fastapi import Response
 
+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
 
 
+def add_onyx_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+    @app.middleware("http")
+    async def set_tenant_id(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        """Captures and sets the context var for the tenant."""
+
+        onyx_tenant_id = request.headers.get("X-Onyx-Tenant-ID")
+        if onyx_tenant_id:
+            CURRENT_TENANT_ID_CONTEXTVAR.set(onyx_tenant_id)
+        return await call_next(request)
+
+
 def add_onyx_request_id_middleware(
     app: FastAPI, prefix: str, logger: logging.LoggerAdapter
 ) -> None:
```
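To close the loop, here is one way the new middleware could be exercised in isolation with FastAPI's TestClient. This is a hedged sketch: the endpoint and wiring are invented for the demonstration.

```python
import logging

from fastapi import FastAPI
from fastapi.testclient import TestClient

from onyx.utils.middleware import add_onyx_tenant_id_middleware
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

app = FastAPI()
add_onyx_tenant_id_middleware(app, logging.LoggerAdapter(logging.getLogger(), {}))


@app.get("/ping")
def ping() -> dict[str, str | None]:
    # The middleware has already copied the header into the context var
    return {"tenant": CURRENT_TENANT_ID_CONTEXTVAR.get(None)}


client = TestClient(app)
resp = client.get("/ping", headers={"X-Onyx-Tenant-ID": "tenant_abc"})
assert resp.json() == {"tenant": "tenant_abc"}
```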
