@@ -312,45 +312,63 @@ def prune_sections(
     )


-def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
+def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
     assert (
         len(set([chunk.document_id for chunk in chunks])) == 1
     ), "One distinct document must be passed into merge_doc_chunks"

+    ADJACENT_CHUNK_SEP = "\n"
+    DISTANT_CHUNK_SEP = "\n\n...\n\n"
+
     # Assuming there are no duplicates by this point
     sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

     center_chunk = max(
         chunks, key=lambda x: x.score if x.score is not None else float("-inf")
     )

+    added_chars = 0
     merged_content = []
     for i, chunk in enumerate(sorted_chunks):
         if i > 0:
             prev_chunk_id = sorted_chunks[i - 1].chunk_id
-            if chunk.chunk_id == prev_chunk_id + 1:
-                merged_content.append("\n")
-            else:
-                merged_content.append("\n\n...\n\n")
+            sep = (
+                ADJACENT_CHUNK_SEP
+                if chunk.chunk_id == prev_chunk_id + 1
+                else DISTANT_CHUNK_SEP
+            )
+            merged_content.append(sep)
+            added_chars += len(sep)
         merged_content.append(chunk.content)

     combined_content = "".join(merged_content)

-    return InferenceSection(
-        center_chunk=center_chunk,
-        chunks=sorted_chunks,
-        combined_content=combined_content,
+    return (
+        InferenceSection(
+            center_chunk=center_chunk,
+            chunks=sorted_chunks,
+            combined_content=combined_content,
+        ),
+        added_chars,
     )


 def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
     docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
     doc_order: dict[str, int] = {}
+    combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
+
+    # chunk de-duping and doc ordering
     for index, section in enumerate(sections):
         if section.center_chunk.document_id not in doc_order:
             doc_order[section.center_chunk.document_id] = index
+
+        combined_section_lengths[section.center_chunk.document_id] += len(
+            section.combined_content
+        )
+
+        chunks_map = docs_map[section.center_chunk.document_id]
         for chunk in [section.center_chunk] + section.chunks:
-            chunks_map = docs_map[section.center_chunk.document_id]
             existing_chunk = chunks_map.get(chunk.chunk_id)
             if (
                 existing_chunk is None
@@ -361,8 +379,22 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
                 chunks_map[chunk.chunk_id] = chunk

     new_sections = []
-    for section_chunks in docs_map.values():
-        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
+    for doc_id, section_chunks in docs_map.items():
+        section_chunks_list = list(section_chunks.values())
+        merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
+
+        previous_length = combined_section_lengths[doc_id] + added_chars
+        # After merging, ensure the content respects the pruning done earlier. Each
+        # combined section is restricted to the sum of the lengths of the sections
+        # from the pruning step. Technically the correct approach would be to prune based
+        # on tokens AGAIN, but this is a good approximation and worth not adding the
+        # tokenization overhead. This could also be fixed if we added a way of removing
+        # chunks from sections in the pruning step; at the moment this issue largely
+        # exists because we only trim the final section's combined_content.
+        merged_section.combined_content = merged_section.combined_content[
+            :previous_length
+        ]
+        new_sections.append(merged_section)

     # Sort by highest score, then by original document order
     # It is now 1 large section per doc, the center chunk being the one with the highest score
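For reference, here is a minimal, self-contained sketch of the idea behind the new length cap, with plain strings standing in for InferenceChunk contents and a made-up helper name (merge_and_cap is illustrative only, not part of the PR). The merge rebuilds text from the raw chunk contents, so the result is trimmed back to the total length the pruning step allowed for that document, plus the separator characters the merge itself inserts:

ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"


def merge_and_cap(
    chunk_contents: dict[int, str], pruned_section_lengths: list[int]
) -> str:
    """Merge per-chunk text in chunk-id order, then cap the merged result.

    pruned_section_lengths plays the role of combined_section_lengths[doc_id]
    in the PR: lengths of the (possibly trimmed) sections that survived pruning.
    """
    sorted_ids = sorted(chunk_contents)
    pieces: list[str] = []
    added_chars = 0
    for i, chunk_id in enumerate(sorted_ids):
        if i > 0:
            # Adjacent chunks get a plain newline; a gap gets the "..." marker.
            sep = (
                ADJACENT_CHUNK_SEP
                if chunk_id == sorted_ids[i - 1] + 1
                else DISTANT_CHUNK_SEP
            )
            pieces.append(sep)
            added_chars += len(sep)
        pieces.append(chunk_contents[chunk_id])

    merged = "".join(pieces)
    # The merge uses the full chunk contents, so it can exceed what pruning
    # allowed; cap it at the pruned total plus the separators added above.
    budget = sum(pruned_section_lengths) + added_chars
    return merged[:budget]


# Chunk 7's raw content is 11 chars, but suppose its section was trimmed to 4
# chars during pruning: the merged output is cut back to respect that budget.
print(repr(merge_and_cap({3: "alpha beta", 7: "gamma delta"}, [10, 4])))
# 'alpha beta\n\n...\n\ngamm'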