11import json
2- from json import JSONDecodeError
2+ import logging
33
44from langchain_core .language_models import BaseChatModel
55from langchain_core .messages import HumanMessage , SystemMessage
6+ from pydantic import BaseModel , ValidationError
7+
68from commons .model import LLMQueryResponse , LLMScoreResponse , Document
7- import logging
9+ from commons .model .query_schema import create_queries_schema
10+ from commons .model .score_schema import BinaryScore , GradedScore
811
912log = logging .getLogger (__name__ )
1013
@@ -16,13 +19,14 @@ def __init__(self, chat_model: BaseChatModel):
1619 def generate_queries (self , document : Document , num_queries_generate_per_doc : int ) -> LLMQueryResponse :
1720 """
1821 Generate queries based on the given document and num_queries_generate_per_doc and
19- Returns a list of generated queries or just a generated string in case of LLM hallucination
22+ Returns a list of generated `num_queries_generate_per_doc` queries or throws an exception if LLM hallucinates
2023 """
24+ schema : type [BaseModel ] = create_queries_schema (num_queries_generate_per_doc )
25+
2126 system_prompt = (
2227 f"You are a helpful assistant! Generate { num_queries_generate_per_doc } "
23- "queries based on the given document below. "
24- "**Output only** a JSON array of strings—nothing else. "
25- "Example format: [\" first query\" , \" second query\" ]"
28+ "natural language search queries based strictly on the given document. "
29+ "Avoid duplicates. Return a structured object matching the provided schema."
2630 )
2731
2832 doc_json = document .model_dump_json (exclude = {"is_used_to_generate_queries" })
@@ -32,54 +36,49 @@ def generate_queries(self, document: Document, num_queries_generate_per_doc: int
3236 HumanMessage (content = f"Document:\n { doc_json } " )
3337 ]
3438
35- # The response from invoke is an AIMessage object which contains all the needed info
36- response = self .chat_model .invoke (messages )
37- response_content = response .content
38- if not isinstance (response_content , str ):
39- response_content = json .dumps (response_content )
40-
39+ # Use LangChain structured output
40+ structured_llm = self .chat_model .with_structured_output (schema )
4141 try :
42- output = LLMQueryResponse ( response_content = response_content )
43- except (KeyError , JSONDecodeError , ValueError ) as e :
44- log .warning ( f" LLM unexpected response. Raw output: { response . content } " )
42+ model_response = structured_llm . invoke ( messages )
43+ except (ValidationError , KeyError ) as e :
44+ log .debug ( "Invalid LLM response." )
4545 raise ValueError (f"Invalid LLM response: { e } " )
4646
47- return output
47+ # Remove duplicate generated-queries
48+ seen = set ()
49+ unique_queries : list [str ] = []
50+ for query in model_response .queries :
51+ if query not in seen :
52+ seen .add (query )
53+ unique_queries .append (query )
54+ unique_queries_len = len (unique_queries )
55+ if unique_queries_len != num_queries_generate_per_doc :
56+ log .info (f"Expected { num_queries_generate_per_doc } unique queries, got { unique_queries_len } " )
57+
58+ return LLMQueryResponse (response_content = json .dumps (unique_queries ))
4859
4960 def generate_score (self , document : Document , query : str , relevance_scale : str ,
5061 explanation : bool = False ) -> LLMScoreResponse :
5162 """
5263 Generates a relevance score for a given document-query pair using a specified relevance scale.
5364 If explanation flag is set to true, score explanation is generated as well.
5465 """
55- if relevance_scale == "binary" :
56- description = (" - 0: the query is NOT relevant to the given document\n "
57- " - 1: the query is relevant to the given document" )
58- elif relevance_scale == "graded" :
59- description = (" - 0: the query is NOT relevant to the given document\n "
60- " - 1: the query may be relevant to the given document\n "
61- " - 2: the document proposed is the answer to the query" )
62- else :
63- msg = f"Invalid relevance scale: { relevance_scale } "
64- log .error (msg )
65- raise ValueError (msg )
66+ if relevance_scale not in {"binary" , "graded" }:
67+ raise ValueError (f"Invalid relevance scale: { relevance_scale } " )
68+
69+ schema : type [BaseModel ] = BinaryScore if relevance_scale == "binary" else GradedScore
6670
6771 system_prompt = (f"You are a professional data labeler and, given a document with a set of fields and a query "
6872 f"and you need to return the relevance score in a scale called { relevance_scale .upper ()} . "
69- f"The scores of this scale are built as follows:\n { description } \n " )
70-
73+ "Return a structured object matching the provided schema." )
7174 if explanation :
7275 system_prompt += (
73- "Return ONLY a **valid JSON** object with two keys:"
74- " `score`: the related score as an integer value\n "
75- " `explanation`: your concise explanation for that score\n "
76- "As an example, I expect a JSON response like the following: "
77- "{\" score\" : \" integer value\" ,\" explanation\" : \" I rated this score because...\" }"
76+ " Include a clear explanation justifying your score "
77+ "in the `explanation` field based on the provided schema."
7878 )
7979 else :
8080 system_prompt += (
81- "Return ONLY a **valid JSON** object with key 'score' and the related score as an integer value."
82- "I expect a JSON response like the following: {\" score\" : \" integer value\" }"
81+ " Do not include any explanation."
8382 )
8483
8584 messages = [
@@ -92,24 +91,16 @@ def generate_score(self, document: Document, query: str, relevance_scale: str,
9291 )
9392 ]
9493
95- response_content = self .chat_model .invoke (messages ).content
96- if isinstance (response_content , str ):
97- raw = response_content .strip ()
98- else :
99- raw = json .dumps (response_content )
100-
94+ # Use LangChain structured output
95+ structured_llm = self .chat_model .with_structured_output (schema )
10196 try :
102- score = json .loads (raw )['score' ]
103- score_explanation = None
104- if explanation :
105- score_explanation = json .loads (raw )['explanation' ]
106- except (JSONDecodeError , KeyError ) as e :
107- log .debug (f"LLM unexpected response. Raw output: { raw } " )
97+ model_response = structured_llm .invoke (messages )
98+ except (ValidationError , KeyError ) as e :
99+ log .debug ("Invalid LLM response." )
108100 raise ValueError (f"Invalid LLM response: { e } " )
109101
110- try :
111- parsed = LLMScoreResponse (score = score , scale = relevance_scale , explanation = score_explanation )
112- return parsed
113- except ValueError as e :
114- log .warning (f"Validation error for score '{ score } ' on scale '{ relevance_scale } ': { e } " )
115- raise e
102+ return LLMScoreResponse (
103+ score = model_response .score ,
104+ scale = relevance_scale ,
105+ explanation = (model_response .explanation if explanation else None )
106+ )
0 commit comments