Skip to content

Commit 00b2f7a

Browse files
Merge pull request #40 from ls1intum/optimize-reranking: Optimize reranking
2 parents 82093be + 069e45e — commit 00b2f7a

File tree

7 files changed

+351
-310
lines changed

7 files changed

+351
-310
lines changed

rag/app/api/question_router.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,21 @@ async def ask(request: UserRequest):
1717

1818
if not question or not classification:
1919
raise HTTPException(status_code=400, detail="No question or classification provided")
20+
21+
answer = request_handler.handle_question(question, classification, language, org_id=org_id)
22+
return {"answer": answer}
2023

21-
if config.TEST_MODE == "true":
22-
answer, used_tokens, general_context, specific_context, sq_context = request_handler.handle_question_test_mode(question,
23-
classification,
24-
language,
25-
org_id=org_id)
26-
return {"answer": answer, "used_tokens": used_tokens, "general_context": general_context,
27-
"specific_context": specific_context, "sq_context": sq_context}
28-
else:
29-
answer = request_handler.handle_question(question, classification, language, org_id=org_id)
30-
return {"answer": answer}
24+
# Uncomment to use test mode and calculate RAG metrics
25+
# if config.TEST_MODE == "true":
26+
# answer, used_tokens, general_context, specific_context, sq_context = request_handler.handle_question_test_mode(question,
27+
# classification,
28+
# language,
29+
# org_id=org_id)
30+
# return {"answer": answer, "used_tokens": used_tokens, "general_context": general_context,
31+
# "specific_context": specific_context, "sq_context": sq_context}
32+
# else:
33+
# answer = request_handler.handle_question(question, classification, language, org_id=org_id)
34+
# return {"answer": answer}
3135

3236

3337
@question_router.post("/chat", tags=["chatbot"], dependencies=[Depends(auth_handler.verify_api_key)])

rag/app/managers/request_handler.py

Lines changed: 158 additions & 107 deletions
Large diffs are not rendered by default.

rag/app/managers/weaviate_manager.py

Lines changed: 45 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from enum import Enum
3-
from typing import List, Union, Tuple, Optional
3+
from typing import List, Union, Tuple, Optional, Dict
44

55
import weaviate
66
import weaviate.classes as wvc
@@ -239,28 +239,23 @@ def get_question_embedding(self, question: str) -> List[float]:
239239
question_embedding = self.model.embed(question)
240240
return question_embedding
241241

242-
def get_relevant_context(self, question: str, question_embedding: str, study_program: str, language: str, org_id: Optional[int],
243-
test_mode: bool = False,
244-
limit=10, top_n=5, filter_by_org: bool = True) -> Union[str, Tuple[str, List[str]]]:
242+
243+
def get_relevant_context(self, question_embedding: List[float], study_program: str, org_id: Optional[int],
244+
limit=10, filter_by_org: bool = True) -> List[Dict]:
245245
"""
246-
Retrieve relevant documents based on the question embedding and study program.
247-
Optionally returns both the concatenated context and the sorted context list for testing purposes.
246+
Retrieves relevant context documents based on the given question embedding and study program.
248247
249248
Args:
250-
question (str): The student's question.
251-
study_program (str): The study program of the student.
252-
keywords (str, optional): Extracted keywords for boosting. Defaults to None.
253-
test_mode (bool, optional): If True, returns both context and sorted_context. Defaults to False.
249+
question_embedding (List[float]): The vector embedding representing the student's question.
250+
study_program (str): The name of the study program to filter documents.
251+
org_id (Optional[int]): The organization ID to filter documents (if applicable).
252+
limit (int, optional): The maximum number of documents to retrieve. Defaults to 10.
253+
filter_by_org (bool, optional): Whether to filter results by organization ID. Defaults to True.
254254
255255
Returns:
256-
Union[str, Tuple[str, List[str]]]:
257-
- If test_mode is False: Returns the concatenated context string.
258-
- If test_mode is True: Returns a tuple of (context, sorted_context list).
256+
List[Dict]: A list of document dictionaries relevant to the query.
259257
"""
260258
try:
261-
# Define the number of documents to retrieve
262-
min_relevance_score = 0.35
263-
264259
# Normalize the study program name
265260
study_program = WeaviateManager.normalize_study_program_name(study_program)
266261

@@ -273,11 +268,6 @@ def get_relevant_context(self, question: str, question_embedding: str, study_pro
273268
else:
274269
filters = Filter.by_property(DocumentSchema.STUDY_PROGRAMS.value).contains_any([study_program])
275270

276-
# If getting general context, adjust the parameters
277-
if study_program.lower() != "general":
278-
limit = 10
279-
min_relevance_score = 0.25
280-
281271

282272
# Perform the vector-based query with filters
283273
query_result = self.documents.query.near_vector(
@@ -295,51 +285,28 @@ def get_relevant_context(self, question: str, question_embedding: str, study_pro
295285
}
296286
for result in query_result.objects
297287
]
298-
content_content_list: List[str] = [doc['content'] for doc in context_list]
299-
300-
# Remove exact duplicates from context_list
301-
content_content_list = WeaviateManager.remove_exact_duplicates(content_content_list)
302-
303-
# Rerank the unique contexts using Cohere
304-
sorted_context = self.reranker.rerank_with_cohere(context_list=content_content_list, query=question,
305-
language=language,
306-
min_relevance_score=min_relevance_score, top_n=top_n)
307-
# Integrate links
308-
sorted_context_with_links = []
309-
for sorted_content in sorted_context:
310-
for doc in context_list:
311-
if doc['content'] == sorted_content:
312-
if doc['link']:
313-
sorted_context_with_links.append(f'Link: {doc["link"]}\nContent: {doc["content"]}')
314-
else:
315-
sorted_context_with_links.append(f'Link: -\nContent: {doc["content"]}')
316-
break
317-
318-
context = "\n-----\n".join(sorted_context_with_links)
319-
320-
# Return based on test_mode
321-
if test_mode:
322-
return context, sorted_context_with_links
323-
else:
324-
return context
288+
289+
return context_list
325290

326291
except Exception as e:
327292
logging.error(f"Error retrieving relevant context: {e}")
328293
# tb = traceback.format_exc()
329294
# logging.error("Traceback:\n%s", tb)
330-
return "" if not test_mode else ("", [])
295+
return []
331296

332-
def get_relevant_sample_questions(self, question: str, question_embedding: str, language: str, org_id: int) -> List[SampleQuestion]:
297+
298+
def get_relevant_sample_questions(self, question: str, question_embedding: List[float], language: str, org_id: int) -> List[SampleQuestion]:
333299
"""
334-
Retrieve relevant sample questions and answers based on the question embedding.
300+
Retrieves relevant sample questions and their answers based on the provided question and its embedding.
335301
336302
Args:
337-
question (str): The student's question.
303+
question (str): The original student question.
304+
question_embedding (List[float]): The vector embedding of the question.
338305
language (str): The language of the question.
339-
top_k (int): The number of top relevant sample questions to return.
306+
org_id (int): The organization ID to filter sample questions.
340307
341308
Returns:
342-
List[SampleQuestion]: A list of SampleQuestion objects, sorted based on reranking results.
309+
List[SampleQuestion]: A list of SampleQuestion objects, sorted by relevance.
343310
"""
344311
try:
345312
limit = 5
@@ -364,19 +331,22 @@ def get_relevant_sample_questions(self, question: str, question_embedding: str,
364331
study_programs=study_programs))
365332

366333
# Rerank the sample questions using the reranker
367-
context_list = [sq.question for sq in sample_questions]
368-
sorted_questions = self.reranker.rerank_with_cohere(
334+
context_list = [
335+
(f"Question: {sq.question}\nAnswer: {sq.answer}" if language == "English"
336+
else f"Frage: {sq.question}\nAntwort: {sq.answer}")
337+
for sq in sample_questions
338+
]
339+
340+
rerank_results = self.reranker.rerank_with_cohere(
369341
context_list=context_list, query=question, language=language, top_n=top_n,
370-
min_relevance_score=min_relevance_score
371342
)
372343

373-
# Map the sorted questions back to SampleQuestion objects
374344
sorted_sample_questions: List[SampleQuestion] = []
375-
for sorted_question in sorted_questions:
376-
for sq in sample_questions:
377-
if sq.question == sorted_question:
378-
sorted_sample_questions.append(sq)
379-
break
345+
for result in rerank_results:
346+
idx = result['index']
347+
score = result['relevance_score']
348+
if score >= min_relevance_score and idx < len(sample_questions):
349+
sorted_sample_questions.append(sample_questions[idx])
380350

381351
return sorted_sample_questions
382352

@@ -673,3 +643,15 @@ def remove_exact_duplicates(context_list: List[str]) -> List[str]:
673643
unique_context.append(context)
674644
seen.add(context)
675645
return unique_context
646+
647+
@staticmethod
648+
def remove_exact_duplicates_from_dict(dicts: List[Dict], key: str = 'content') -> list:
649+
"""Remove dicts with duplicate values for given key, preserving order."""
650+
seen = set()
651+
deduped = []
652+
for d in dicts:
653+
val = d.get(key)
654+
if val not in seen:
655+
deduped.append(d)
656+
seen.add(val)
657+
return deduped

rag/app/post_retrieval/reranker.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import cohere
44
from app.models.base_model import BaseModelClient
55
from sklearn.metrics.pairwise import cosine_similarity
6-
from typing import List
6+
from typing import List, Dict
77
import requests
88

99
class DocumentWithEmbedding:
@@ -64,23 +64,22 @@ def rerank_with_embeddings(self, context_list: List[DocumentWithEmbedding], keyw
6464

6565
return ranked_context_list
6666

67-
def rerank_with_cohere(self, context_list: List[str], query: str, language: str, min_relevance_score: float, top_n: int = 5) -> List[str]:
67+
def rerank_with_cohere(self, context_list: List[str], query: str, language: str, top_n: int = 5) -> List[Dict]:
6868
"""
6969
Re-ranks the context list using the Cohere reranking model deployed on Azure.
7070
7171
Args:
7272
context_list (List[str]): List of document texts to be re-ranked.
7373
query (str): The query string to rerank the documents against.
7474
language (str): The language of the documents ('english' or other).
75-
min_relevance_score (float): The minimum relevance score to consider.
7675
top_n (int): The number of top results to return after re-ranking.
7776
7877
Returns:
79-
List[str]: A list of the re-ranked document contents based on relevance.
78+
List[Dict]: A list of the re-ranked document contents based on relevance.
8079
"""
8180
try:
82-
if len(context_list) == 0:
83-
return context_list
81+
if not context_list:
82+
return []
8483

8584
# Determine the correct endpoint URL and API key based on language
8685
if language.lower() == "english":
@@ -103,30 +102,13 @@ def rerank_with_cohere(self, context_list: List[str], query: str, language: str,
103102
}
104103

105104
response = requests.post(rerank_url, headers=headers, json=payload)
106-
107105
if response.status_code != 200:
108106
logging.error(f"Error during Cohere re-ranking: {response.status_code} {response.text}")
109-
return context_list[:top_n]
110-
111-
response_json = response.json()
112-
113-
# Log the full response from the API for debugging
114-
results = response_json.get('results', [])
115-
116-
# Log the ranked documents that are in the top_n
117-
ranked_indices = []
118-
for i, result in enumerate(results):
119-
index = result['index']
120-
relevance_score = result.get('relevance_score')
121-
# Filter results based on min_relevance_score
122-
if relevance_score >= min_relevance_score:
123-
ranked_indices.append(index)
124-
125-
# Get the ranked documents based on the indices
126-
ranked_context_list = [context_list[result['index']] for result in results]
107+
return [{'index': i, 'relevance_score': 1.0} for i in range(min(top_n, len(context_list)))]
127108

128-
return ranked_context_list
109+
results = response.json().get('results', [])
110+
return results
129111

130112
except Exception as e:
131113
logging.error(f"Error during Cohere re-ranking: {e}")
132-
return context_list[:top_n]
114+
return [{'index': i, 'relevance_score': 1.0} for i in range(min(top_n, len(context_list)))]

Comments (0)