11import logging
22from enum import Enum
3- from typing import List , Union , Tuple , Optional
3+ from typing import List , Union , Tuple , Optional , Dict
44
55import weaviate
66import weaviate .classes as wvc
@@ -239,28 +239,23 @@ def get_question_embedding(self, question: str) -> List[float]:
239239 question_embedding = self .model .embed (question )
240240 return question_embedding
241241
242- def get_relevant_context ( self , question : str , question_embedding : str , study_program : str , language : str , org_id : Optional [ int ],
243- test_mode : bool = False ,
244- limit = 10 , top_n = 5 , filter_by_org : bool = True ) -> Union [ str , Tuple [ str , List [str ]] ]:
242+
243+ def get_relevant_context ( self , question_embedding : List [ float ], study_program : str , org_id : Optional [ int ] ,
244+ limit = 10 , filter_by_org : bool = True ) -> List [Dict ]:
245245 """
246- Retrieve relevant documents based on the question embedding and study program.
247- Optionally returns both the concatenated context and the sorted context list for testing purposes.
246+ Retrieves relevant context documents based on the given question embedding and study program.
248247
249248 Args:
250- question (str): The student's question.
251- study_program (str): The study program of the student.
252- keywords (str, optional): Extracted keywords for boosting. Defaults to None.
253- test_mode (bool, optional): If True, returns both context and sorted_context. Defaults to False.
249+ question_embedding (List[float]): The vector embedding representing the student's question.
250+ study_program (str): The name of the study program to filter documents.
251+ org_id (Optional[int]): The organization ID to filter documents (if applicable).
252+ limit (int, optional): The maximum number of documents to retrieve. Defaults to 10.
253+ filter_by_org (bool, optional): Whether to filter results by organization ID. Defaults to True.
254254
255255 Returns:
256- Union[str, Tuple[str, List[str]]]:
257- - If test_mode is False: Returns the concatenated context string.
258- - If test_mode is True: Returns a tuple of (context, sorted_context list).
256+ List[Dict]: A list of document dictionaries relevant to the query.
259257 """
260258 try :
261- # Define the number of documents to retrieve
262- min_relevance_score = 0.35
263-
264259 # Normalize the study program name
265260 study_program = WeaviateManager .normalize_study_program_name (study_program )
266261
@@ -273,11 +268,6 @@ def get_relevant_context(self, question: str, question_embedding: str, study_pro
273268 else :
274269 filters = Filter .by_property (DocumentSchema .STUDY_PROGRAMS .value ).contains_any ([study_program ])
275270
276- # If getting general context, adjust the parameters
277- if study_program .lower () != "general" :
278- limit = 10
279- min_relevance_score = 0.25
280-
281271
282272 # Perform the vector-based query with filters
283273 query_result = self .documents .query .near_vector (
@@ -295,51 +285,28 @@ def get_relevant_context(self, question: str, question_embedding: str, study_pro
295285 }
296286 for result in query_result .objects
297287 ]
298- content_content_list : List [str ] = [doc ['content' ] for doc in context_list ]
299-
300- # Remove exact duplicates from context_list
301- content_content_list = WeaviateManager .remove_exact_duplicates (content_content_list )
302-
303- # Rerank the unique contexts using Cohere
304- sorted_context = self .reranker .rerank_with_cohere (context_list = content_content_list , query = question ,
305- language = language ,
306- min_relevance_score = min_relevance_score , top_n = top_n )
307- # Integrate links
308- sorted_context_with_links = []
309- for sorted_content in sorted_context :
310- for doc in context_list :
311- if doc ['content' ] == sorted_content :
312- if doc ['link' ]:
313- sorted_context_with_links .append (f'Link: { doc ["link" ]} \n Content: { doc ["content" ]} ' )
314- else :
315- sorted_context_with_links .append (f'Link: -\n Content: { doc ["content" ]} ' )
316- break
317-
318- context = "\n -----\n " .join (sorted_context_with_links )
319-
320- # Return based on test_mode
321- if test_mode :
322- return context , sorted_context_with_links
323- else :
324- return context
288+
289+ return context_list
325290
326291 except Exception as e :
327292 logging .error (f"Error retrieving relevant context: { e } " )
328293 # tb = traceback.format_exc()
329294 # logging.error("Traceback:\n%s", tb)
330- return "" if not test_mode else ( "" , [])
295+ return []
331296
332- def get_relevant_sample_questions (self , question : str , question_embedding : str , language : str , org_id : int ) -> List [SampleQuestion ]:
297+
298+ def get_relevant_sample_questions (self , question : str , question_embedding : List [float ], language : str , org_id : int ) -> List [SampleQuestion ]:
333299 """
334- Retrieve relevant sample questions and answers based on the question embedding.
300+ Retrieves relevant sample questions and their answers based on the provided question and its embedding.
335301
336302 Args:
337- question (str): The student's question.
303+ question (str): The original student question.
304+ question_embedding (List[float]): The vector embedding of the question.
338305 language (str): The language of the question.
339- top_k (int): The number of top relevant sample questions to return .
306+ org_id (int): The organization ID to filter sample questions.
340307
341308 Returns:
342- List[SampleQuestion]: A list of SampleQuestion objects, sorted based on reranking results .
309+ List[SampleQuestion]: A list of SampleQuestion objects, sorted by relevance .
343310 """
344311 try :
345312 limit = 5
@@ -364,19 +331,22 @@ def get_relevant_sample_questions(self, question: str, question_embedding: str,
364331 study_programs = study_programs ))
365332
366333 # Rerank the sample questions using the reranker
367- context_list = [sq .question for sq in sample_questions ]
368- sorted_questions = self .reranker .rerank_with_cohere (
334+ context_list = [
335+ (f"Question: { sq .question } \n Answer: { sq .answer } " if language == "English"
336+ else f"Frage: { sq .question } \n Antwort: { sq .answer } " )
337+ for sq in sample_questions
338+ ]
339+
340+ rerank_results = self .reranker .rerank_with_cohere (
369341 context_list = context_list , query = question , language = language , top_n = top_n ,
370- min_relevance_score = min_relevance_score
371342 )
372343
373- # Map the sorted questions back to SampleQuestion objects
374344 sorted_sample_questions : List [SampleQuestion ] = []
375- for sorted_question in sorted_questions :
376- for sq in sample_questions :
377- if sq . question == sorted_question :
378- sorted_sample_questions . append ( sq )
379- break
345+ for result in rerank_results :
346+ idx = result [ 'index' ]
347+ score = result [ 'relevance_score' ]
348+ if score >= min_relevance_score and idx < len ( sample_questions ):
349+ sorted_sample_questions . append ( sample_questions [ idx ])
380350
381351 return sorted_sample_questions
382352
@@ -673,3 +643,15 @@ def remove_exact_duplicates(context_list: List[str]) -> List[str]:
673643 unique_context .append (context )
674644 seen .add (context )
675645 return unique_context
646+
647+ @staticmethod
648+ def remove_exact_duplicates_from_dict (dicts : List [Dict ], key : str = 'content' ) -> list :
649+ """Remove dicts with duplicate values for given key, preserving order."""
650+ seen = set ()
651+ deduped = []
652+ for d in dicts :
653+ val = d .get (key )
654+ if val not in seen :
655+ deduped .append (d )
656+ seen .add (val )
657+ return deduped
0 commit comments