Commit 54b883d

fix large docs selected in chat pruning (#4412)
* fix large docs selected in chat pruning
* better approach to length restriction
* comments
* comments
* fix unit tests and minor pruning bug
* remove prints
1 parent 91faac5

File tree: 4 files changed, +63 -29 lines changed

backend/onyx/chat/process_message.py (+7 -1)

@@ -43,6 +43,7 @@
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
+from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
 from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
 from onyx.configs.constants import BASIC_KEY
 from onyx.configs.constants import MessageType
@@ -692,8 +693,13 @@ def stream_chat_message_objects(
             doc_identifiers=identifier_tuples,
             document_index=document_index,
         )
+
+        # Add a maximum context size in the case of user-selected docs to prevent
+        # slight inaccuracies in context window size pruning from causing
+        # the entire query to fail
         document_pruning_config = DocumentPruningConfig(
-            is_manually_selected_docs=True
+            is_manually_selected_docs=True,
+            max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
         )
 
         # In case the search doc is deleted, just don't include it
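For context on what the new cap protects against: with user-selected documents the pruning step previously tried to fill the whole context window, so a slight mismatch between its token accounting and the model's own tokenizer could push the request over the limit and fail the entire query. Reserving only a fraction of the window leaves headroom for that drift. Below is a minimal sketch of the capping idea; `cap_selected_sections` and its signature are hypothetical, not Onyx's API.

```python
# Hypothetical sketch of capping selected-document content to a fraction of
# the context window; not Onyx's actual implementation.
def cap_selected_sections(
    section_token_counts: list[int],
    context_window_tokens: int,
    max_window_percentage: float = 0.8,
) -> list[int]:
    """Return per-section token allowances whose total stays under the cap."""
    budget = int(context_window_tokens * max_window_percentage)
    allowances = []
    for count in section_token_counts:
        allowed = min(count, max(budget, 0))  # never exceed the remaining budget
        allowances.append(allowed)
        budget -= allowed
    return allowances

# With a 4096-token window, an 0.8 cap keeps at most 3276 tokens of selected
# content, leaving ~820 tokens of slack for prompts and tokenizer drift:
print(cap_selected_sections([1500, 1500, 1500], 4096))  # [1500, 1500, 276]
```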

backend/onyx/chat/prune_and_merge.py (+44 -12)

@@ -312,45 +312,63 @@ def prune_sections(
     )
 
 
-def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
+def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
     assert (
         len(set([chunk.document_id for chunk in chunks])) == 1
     ), "One distinct document must be passed into merge_doc_chunks"
 
+    ADJACENT_CHUNK_SEP = "\n"
+    DISTANT_CHUNK_SEP = "\n\n...\n\n"
+
     # Assuming there are no duplicates by this point
     sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
 
     center_chunk = max(
         chunks, key=lambda x: x.score if x.score is not None else float("-inf")
     )
 
+    added_chars = 0
     merged_content = []
     for i, chunk in enumerate(sorted_chunks):
         if i > 0:
             prev_chunk_id = sorted_chunks[i - 1].chunk_id
-            if chunk.chunk_id == prev_chunk_id + 1:
-                merged_content.append("\n")
-            else:
-                merged_content.append("\n\n...\n\n")
+            sep = (
+                ADJACENT_CHUNK_SEP
+                if chunk.chunk_id == prev_chunk_id + 1
+                else DISTANT_CHUNK_SEP
+            )
+            merged_content.append(sep)
+            added_chars += len(sep)
         merged_content.append(chunk.content)
 
     combined_content = "".join(merged_content)
 
-    return InferenceSection(
-        center_chunk=center_chunk,
-        chunks=sorted_chunks,
-        combined_content=combined_content,
+    return (
+        InferenceSection(
+            center_chunk=center_chunk,
+            chunks=sorted_chunks,
+            combined_content=combined_content,
+        ),
+        added_chars,
     )
 
 
 def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
     docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
     doc_order: dict[str, int] = {}
+    combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
+
+    # chunk de-duping and doc ordering
     for index, section in enumerate(sections):
         if section.center_chunk.document_id not in doc_order:
             doc_order[section.center_chunk.document_id] = index
+
+        combined_section_lengths[section.center_chunk.document_id] += len(
+            section.combined_content
+        )
+
+        chunks_map = docs_map[section.center_chunk.document_id]
         for chunk in [section.center_chunk] + section.chunks:
-            chunks_map = docs_map[section.center_chunk.document_id]
             existing_chunk = chunks_map.get(chunk.chunk_id)
             if (
                 existing_chunk is None
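Why the merge now reports `added_chars`: the separators it splices between chunks were never part of any pruned section's `combined_content`, so a later length-based trim has to budget for them explicitly. Here is a self-contained toy version of that bookkeeping, using plain `(chunk_id, content)` tuples in place of `InferenceChunk`; it is illustrative, not the actual Onyx function.

```python
# Toy illustration of the separator/added_chars bookkeeping.
ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"

def merge_chunks(chunks: list[tuple[int, str]]) -> tuple[str, int]:
    sorted_chunks = sorted(chunks)
    parts: list[str] = []
    added_chars = 0
    for i, (chunk_id, content) in enumerate(sorted_chunks):
        if i > 0:
            prev_id = sorted_chunks[i - 1][0]
            # consecutive chunk ids get a newline; a gap gets an ellipsis marker
            sep = ADJACENT_CHUNK_SEP if chunk_id == prev_id + 1 else DISTANT_CHUNK_SEP
            parts.append(sep)
            added_chars += len(sep)
        parts.append(content)
    return "".join(parts), added_chars

merged, added = merge_chunks([(1, "aaa"), (2, "bbb"), (5, "ccc")])
assert merged == "aaa\nbbb\n\n...\n\nccc"
assert added == len("\n") + len("\n\n...\n\n")  # 1 + 7 separator characters
```

The second hunk of `_merge_sections` then folds this count into the per-document length cap: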
@@ -361,8 +379,22 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
                 chunks_map[chunk.chunk_id] = chunk
 
     new_sections = []
-    for section_chunks in docs_map.values():
-        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
+    for doc_id, section_chunks in docs_map.items():
+        section_chunks_list = list(section_chunks.values())
+        merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
+
+        previous_length = combined_section_lengths[doc_id] + added_chars
+        # After merging, ensure the content respects the pruning done earlier. Each
+        # combined section is restricted to the sum of the lengths of the sections
+        # from the pruning step. Technically the correct approach would be to prune based
+        # on tokens AGAIN, but this is a good approximation and worth not adding the
+        # tokenization overhead. This could also be fixed if we added a way of removing
+        # chunks from sections in the pruning step; at the moment this issue largely
+        # exists because we only trim the final section's combined_content.
+        merged_section.combined_content = merged_section.combined_content[
+            :previous_length
+        ]
+        new_sections.append(merged_section)
 
     # Sort by highest score, then by original document order
     # It is now 1 large section per doc, the center chunk being the one with the highest score
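A quick worked example of the cap, with toy numbers that are not from the codebase: if pruning kept two sections of a document with `combined_content` lengths 10 and 12, and the merge inserted one distant-chunk separator of 7 characters, the merged section is sliced back to 10 + 12 + 7 = 29 characters, so merging can never reintroduce text that pruning already discarded.

```python
# Toy numbers illustrating the post-merge length cap; not Onyx code.
pruned_section_lengths = [10, 12]  # combined_content lengths kept by pruning
added_chars = 7                    # separator characters spliced in by the merge

previous_length = sum(pruned_section_lengths) + added_chars  # 29

merged_combined_content = "x" * 40  # the merge may pull in more text than pruning kept
trimmed = merged_combined_content[:previous_length]
assert len(trimmed) == 29
```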

backend/onyx/configs/chat_configs.py (+3)

@@ -16,6 +16,9 @@
 # ~3k input, half for docs, half for chat history + prompts
 CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
 
+# Maximum percentage of the context window to fill with selected sections
+SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
+
 # 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
 # Capped in Vespa at 0.5
 DOC_TIME_DECAY = float(
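To make the 0.8 concrete: the cap reserves a fixed fraction of the model's context window for selected sections, so the absolute headroom left for prompts, history, and token-count drift grows with the window size. The window sizes below are illustrative examples, not values from the codebase.

```python
# Illustrative headroom arithmetic for SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE;
# the window sizes are examples, not Onyx configuration values.
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8

for window_tokens in (4096, 8192, 32768):
    budget = int(window_tokens * SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE)
    print(f"{window_tokens}-token window: {budget} for sections, "
          f"{window_tokens - budget} headroom")
# 4096-token window: 3276 for sections, 820 headroom
# 8192-token window: 6553 for sections, 1639 headroom
# 32768-token window: 26214 for sections, 6554 headroom
```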

backend/tests/unit/onyx/chat/test_prune_and_merge.py (+9 -16)

@@ -4,6 +4,7 @@
 from onyx.configs.constants import DocumentSource
 from onyx.context.search.models import InferenceChunk
 from onyx.context.search.models import InferenceSection
+from onyx.context.search.utils import inference_section_from_chunks
 
 
 # This large test accounts for all of the following:
@@ -111,7 +112,7 @@ def create_inference_chunk(
         # Sections
         [
             # Document 1, top/middle/bot connected + disconnected section
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_1_TOP_CHUNK,
                 chunks=[
                     DOC_1_FILLER_1,
@@ -120,9 +121,8 @@ def create_inference_chunk(
                     DOC_1_MID_CHUNK,
                     DOC_1_FILLER_3,
                 ],
-                combined_content="N/A",  # Not used
             ),
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_1_MID_CHUNK,
                 chunks=[
                     DOC_1_FILLER_2,
@@ -131,9 +131,8 @@ def create_inference_chunk(
                     DOC_1_FILLER_3,
                     DOC_1_FILLER_4,
                 ],
-                combined_content="N/A",
             ),
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_1_BOTTOM_CHUNK,
                 chunks=[
                     DOC_1_FILLER_3,
@@ -142,9 +141,8 @@ def create_inference_chunk(
                     DOC_1_FILLER_5,
                     DOC_1_FILLER_6,
                 ],
-                combined_content="N/A",
             ),
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_1_DISCONNECTED,
                 chunks=[
                     DOC_1_FILLER_7,
@@ -153,9 +151,8 @@ def create_inference_chunk(
                     DOC_1_FILLER_9,
                     DOC_1_FILLER_10,
                 ],
-                combined_content="N/A",
             ),
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_2_TOP_CHUNK,
                 chunks=[
                     DOC_2_FILLER_1,
@@ -164,9 +161,8 @@ def create_inference_chunk(
                     DOC_2_FILLER_3,
                     DOC_2_BOTTOM_CHUNK,
                 ],
-                combined_content="N/A",
             ),
-            InferenceSection(
+            inference_section_from_chunks(
                 center_chunk=DOC_2_BOTTOM_CHUNK,
                 chunks=[
                     DOC_2_TOP_CHUNK,
@@ -175,7 +171,6 @@ def create_inference_chunk(
                     DOC_2_FILLER_4,
                     DOC_2_FILLER_5,
                 ],
-                combined_content="N/A",
             ),
         ],
         # Expected Content
@@ -204,15 +199,13 @@ def test_merge_sections(
         (
             # Sections
             [
-                InferenceSection(
+                inference_section_from_chunks(
                     center_chunk=DOC_1_TOP_CHUNK,
                     chunks=[DOC_1_TOP_CHUNK],
-                    combined_content="N/A",  # Not used
                 ),
-                InferenceSection(
+                inference_section_from_chunks(
                     center_chunk=DOC_1_MID_CHUNK,
                     chunks=[DOC_1_MID_CHUNK],
-                    combined_content="N/A",
                 ),
             ],
             # Expected Content