Skip to content

Commit 2b00efe

Browse files
committed
DAGE-90: Support structured output in LLM service
1 parent c9b2cdb commit 2b00efe

File tree

7 files changed

+168
-96
lines changed

7 files changed

+168
-96
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Type
2+
3+
from pydantic import BaseModel, Field, create_model, conlist, constr
4+
5+
6+
def create_queries_schema(num_queries_generate: int) -> Type[BaseModel]:
    """
    Build a dynamic Pydantic model for validating LLM query output.

    The returned model exposes a single ``queries`` field: a list of exactly
    `num_queries_generate` whitespace-stripped, non-empty strings.
    """
    # Each query must be non-empty once surrounding whitespace is stripped.
    query_str = constr(strip_whitespace=True, min_length=1)

    # Pin the list length to exactly the requested number of queries.
    fixed_len_list = conlist(
        query_str,
        min_length=num_queries_generate,
        max_length=num_queries_generate,
    )

    queries_field = Field(
        ...,
        description=f"Return exactly {num_queries_generate} "
                    f"distinct queries as plain strings.",
    )
    return create_model("LLMQueries", queries=(fixed_len_list, queries_field))
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional, Literal
2+
from pydantic import BaseModel, Field
3+
4+
5+
class BinaryScore(BaseModel):
    """Structured LLM output for a binary (0/1) relevance judgement."""

    # Literal restricts accepted values to exactly 0 or 1 at validation time.
    score: Literal[0, 1] = Field(
        ..., description="0 = not relevant, 1 = relevant"
    )
    # Populated only when the caller asked the LLM for a rationale.
    explanation: Optional[str] = Field(
        None, description="Explanation for why this score"
    )
9+
10+
11+
class GradedScore(BaseModel):
    """Structured LLM output for a graded (0/1/2) relevance judgement."""

    # Literal restricts accepted values to exactly 0, 1 or 2 at validation time.
    score: Literal[0, 1, 2] = Field(
        ..., description="0 = not relevant, 1 = maybe, 2 = is the answer"
    )
    # Populated only when the caller asked the LLM for a rationale.
    explanation: Optional[str] = Field(
        None, description="Explanation for why this score"
    )
Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import json
2-
from json import JSONDecodeError
2+
import logging
33

44
from langchain_core.language_models import BaseChatModel
55
from langchain_core.messages import HumanMessage, SystemMessage
6+
from pydantic import BaseModel, ValidationError
7+
68
from commons.model import LLMQueryResponse, LLMScoreResponse, Document
7-
import logging
9+
from commons.model.query_schema import create_queries_schema
10+
from commons.model.score_schema import BinaryScore, GradedScore
811

912
log = logging.getLogger(__name__)
1013

@@ -16,13 +19,14 @@ def __init__(self, chat_model: BaseChatModel):
1619
def generate_queries(self, document: Document, num_queries_generate_per_doc: int) -> LLMQueryResponse:
    """
    Generate queries for the given document using LangChain structured output.

    :param document: source Document the queries must be based on
    :param num_queries_generate_per_doc: exact number of queries the schema enforces
    :return: LLMQueryResponse whose content is a JSON array of unique query strings
    :raises ValueError: if the LLM response does not validate against the schema
    """
    schema: type[BaseModel] = create_queries_schema(num_queries_generate_per_doc)

    # FIX: the previous concatenation was missing a separating space, producing
    # "...given document.Avoid duplicates..." in the prompt sent to the model.
    system_prompt = (
        f"You are a helpful assistant! Generate {num_queries_generate_per_doc} "
        "natural language search queries based strictly on the given document. "
        "Avoid duplicates. Return a structured object matching the provided schema."
    )

    # `is_used_to_generate_queries` is internal bookkeeping, not document content.
    doc_json = document.model_dump_json(exclude={"is_used_to_generate_queries"})
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Document:\n{doc_json}")
    ]

    # Use LangChain structured output: the model must return an object matching `schema`.
    structured_llm = self.chat_model.with_structured_output(schema)
    try:
        model_response = structured_llm.invoke(messages)
    except (ValidationError, KeyError) as e:
        log.debug("Invalid LLM response: %s", e)
        # Chain the original exception so the root cause is preserved for callers.
        raise ValueError(f"Invalid LLM response: {e}") from e

    # Deduplicate while preserving order; dict.fromkeys keeps the first occurrence.
    unique_queries = list(dict.fromkeys(model_response.queries))
    if len(unique_queries) != num_queries_generate_per_doc:
        # The schema pins the raw count, so any mismatch here means duplicates were dropped.
        log.info("Expected %d unique queries, got %d",
                 num_queries_generate_per_doc, len(unique_queries))

    return LLMQueryResponse(response_content=json.dumps(unique_queries))
4859

4960
def generate_score(self, document: Document, query: str, relevance_scale: str,
5061
explanation: bool = False) -> LLMScoreResponse:
5162
"""
5263
Generates a relevance score for a given document-query pair using a specified relevance scale.
5364
If explanation flag is set to true, score explanation is generated as well.
5465
"""
55-
if relevance_scale == "binary":
56-
description = (" - 0: the query is NOT relevant to the given document\n"
57-
" - 1: the query is relevant to the given document")
58-
elif relevance_scale == "graded":
59-
description = (" - 0: the query is NOT relevant to the given document\n"
60-
" - 1: the query may be relevant to the given document\n"
61-
" - 2: the document proposed is the answer to the query")
62-
else:
63-
msg = f"Invalid relevance scale: {relevance_scale}"
64-
log.error(msg)
65-
raise ValueError(msg)
66+
if relevance_scale not in {"binary", "graded"}:
67+
raise ValueError(f"Invalid relevance scale: {relevance_scale}")
68+
69+
schema: type[BaseModel] = BinaryScore if relevance_scale == "binary" else GradedScore
6670

6771
system_prompt = (f"You are a professional data labeler and, given a document with a set of fields and a query "
6872
f"and you need to return the relevance score in a scale called {relevance_scale.upper()}. "
69-
f"The scores of this scale are built as follows:\n{description}\n")
70-
73+
" Return a structured object matching the provided schema.")
7174
if explanation:
7275
system_prompt += (
73-
"Return ONLY a **valid JSON** object with two keys:"
74-
" `score`: the related score as an integer value\n"
75-
" `explanation`: your concise explanation for that score\n"
76-
"As an example, I expect a JSON response like the following: "
77-
"{\"score\": \"integer value\",\"explanation\": \"I rated this score because...\" }"
76+
" Include a clear explanation justifying your score "
77+
"in the `explanation` field based on the provided schema."
7878
)
7979
else:
8080
system_prompt += (
81-
"Return ONLY a **valid JSON** object with key 'score' and the related score as an integer value."
82-
"I expect a JSON response like the following: {\"score\": \"integer value\"}"
81+
" Do not include any explanation."
8382
)
8483

8584
messages = [
@@ -92,24 +91,16 @@ def generate_score(self, document: Document, query: str, relevance_scale: str,
9291
)
9392
]
9493

95-
response_content = self.chat_model.invoke(messages).content
96-
if isinstance(response_content, str):
97-
raw = response_content.strip()
98-
else:
99-
raw = json.dumps(response_content)
100-
94+
# Use LangChain structured output
95+
structured_llm = self.chat_model.with_structured_output(schema)
10196
try:
102-
score = json.loads(raw)['score']
103-
score_explanation = None
104-
if explanation:
105-
score_explanation = json.loads(raw)['explanation']
106-
except (JSONDecodeError, KeyError) as e:
107-
log.debug(f"LLM unexpected response. Raw output: {raw}")
97+
model_response = structured_llm.invoke(messages)
98+
except (ValidationError, KeyError) as e:
99+
log.debug("Invalid LLM response.")
108100
raise ValueError(f"Invalid LLM response: {e}")
109101

110-
try:
111-
parsed = LLMScoreResponse(score=score, scale=relevance_scale, explanation=score_explanation)
112-
return parsed
113-
except ValueError as e:
114-
log.warning(f"Validation error for score '{score}' on scale '{relevance_scale}': {e}")
115-
raise e
102+
return LLMScoreResponse(
103+
score=model_response.score,
104+
scale=relevance_scale,
105+
explanation=(model_response.explanation if explanation else None)
106+
)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import List, Optional
2+
3+
from langchain_core.language_models import BaseChatModel
4+
from langchain_core.outputs import ChatResult
5+
from pydantic import BaseModel
6+
7+
8+
class _StructuredOutputMockLLM:
    """Test double returned by ChatModelAdapter.with_structured_output.

    Pops canned payloads off the wrapped fake chat model and coerces each
    one into the requested schema, mimicking a structured-output runnable.
    """

    def __init__(self, fake_chat_model, schema: type[BaseModel]):
        self._fake_chat_model = fake_chat_model
        self._schema = schema

    def invoke(self, messages):
        # Consume the next canned response, in FIFO order.
        payload = self._fake_chat_model.responses.pop(0)

        if isinstance(payload, self._schema):
            # Already a schema instance: pass it straight through.
            result = payload
        elif isinstance(payload, dict):
            result = self._schema.model_validate(payload)
        elif isinstance(payload, str):
            result = self._schema.model_validate_json(payload)
        else:
            raise TypeError(f"Unexpected fake payload type: {type(payload)}")
        return result
26+
27+
28+
class ChatModelAdapter(BaseChatModel):
    """Wraps a FakeListChatModel to provide the `with_structured_output` hook,
    which the stock fake chat model does not implement."""

    def __init__(self, fake_chat_model):
        super().__init__()
        self._fake_chat_model = fake_chat_model

    @property
    def _llm_type(self) -> str:
        # Identifier LangChain uses to tag this (fake) model type.
        return "fake_adapter"

    def _generate(self, messages: List, stop: Optional[List[str]] = None, **kwargs) -> ChatResult:
        # The structured-output path never reaches _generate in these tests.
        raise NotImplementedError("_generate is not used in the test")

    def with_structured_output(self, schema: type[BaseModel]):
        # Route structured calls to the mock that replays canned payloads.
        return _StructuredOutputMockLLM(self._fake_chat_model, schema)

rre-tools/dataset-generator/tests/unit/test_llm_service.py renamed to rre-tools/dataset-generator/tests/unit/llm/test_llm_service.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import pytest
12
from langchain_core.language_models.fake_chat_models import FakeListChatModel
2-
from dataset_generator.llm import LLMService
3+
34
from commons.model import Document, LLMQueryResponse, LLMScoreResponse
4-
import pytest
5+
from dataset_generator.llm import LLMService
6+
from llm_mock import ChatModelAdapter
57

68

79
@pytest.fixture
@@ -18,18 +20,18 @@ def example_doc():
1820

1921
def test_llm_service_generate_queries__expects__response(example_doc):
    # The adapter replays a canned structured payload matching the queries schema.
    canned = '{"queries": ["Car","Auto","Vehicle","Sedan","Toyota"]}'
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[canned])))

    response = service.generate_queries(example_doc, 5)

    assert isinstance(response, LLMQueryResponse)
    assert response.get_queries() == ["Car", "Auto", "Vehicle", "Sedan", "Toyota"]
2830

2931

3032
def test_llm_service_generate_score__expects__response(example_doc):
31-
fake_llm = FakeListChatModel(responses=["{\"score\": 1}"])
32-
service = LLMService(chat_model=fake_llm)
33+
fake_llm = FakeListChatModel(responses=['{"score": 1}'])
34+
service = LLMService(chat_model=ChatModelAdapter(fake_llm))
3335

3436
query = "Is a Toyota the car of the year?"
3537

@@ -45,7 +47,7 @@ def test_llm_service_generate_score__expects__response(example_doc):
4547
])
4648
def test_llm_service_generate_score_with_invalid_responses__expects__raises_value_error(example_doc, invalid_response):
4749
fake_llm = FakeListChatModel(responses=[invalid_response])
48-
service = LLMService(chat_model=fake_llm)
50+
service = LLMService(chat_model=ChatModelAdapter(fake_llm))
4951

5052
query = "Is a Toyota the car of the year?"
5153
with pytest.raises(ValueError):

rre-tools/dataset-generator/tests/unit/llm/test_llm_service_queries.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from langchain_core.language_models.fake_chat_models import FakeListChatModel
33
from dataset_generator.llm import LLMService
44
from commons.model import Document, LLMQueryResponse
5+
from llm_mock import ChatModelAdapter
56

67

78
@pytest.fixture
@@ -16,44 +17,42 @@ def example_doc():
1617

1718

1819
def test_llm_service_generate_queries__expects__valid(example_doc):
    # Happy path: two queries requested, two returned in the structured payload.
    payload = '{"queries": ["Toyota", "Best Car"]}'
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[payload])))

    response = service.generate_queries(example_doc, 2)

    assert isinstance(response, LLMQueryResponse)
    assert response.get_queries() == ["Toyota", "Best Car"]
2526

2627

2728
def test_llm_service_generate_queries__expects__empty_list(example_doc):
    # Requesting zero queries is valid: the schema pins the list length to 0.
    payload = '{"queries":[]}'
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[payload])))

    response = service.generate_queries(example_doc, 0)

    assert response.get_queries() == []
3233

3334

3435
@pytest.mark.parametrize("invalid_response, expected_error", [
    # Not JSON at all -> pydantic JSON parse failure.
    ('not a json', r"Invalid JSON"),
    # Empty / whitespace-only queries violate the min_length constraint.
    ('{"queries":["", " ", "Valid"]}', r"(at least 1 character|min_length|String should have at least 1)"),
    # Non-string items violate the string type constraint.
    ('{"queries":["Good", 123, null]}', r"(valid string|string_type)"),
])
def test_llm_service_generate_queries_with_invalid_responses__expects__error(invalid_response, expected_error, example_doc):
    # Each malformed payload must surface as ValueError with a matching message.
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[invalid_response])))

    with pytest.raises(ValueError, match=expected_error):
        service.generate_queries(example_doc, 3)
4445

4546

4647
def test_generate_queries_with_unicode_strings__expects__list_of_unicode_strings(example_doc):
    # Non-ASCII queries must round-trip through validation untouched.
    payload = '{"queries":["こんにちは", "你好", "¡Hola!"]}'
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[payload])))

    response = service.generate_queries(example_doc, 3)

    assert response.get_queries() == ["こんにちは", "你好", "¡Hola!"]
5252

5353

54-
def test_generate_queries_with_leading_trailing_whitespace__expects__whitespace_stripped(example_doc):
    # constr(strip_whitespace=True) in the queries schema trims each string.
    payload = '{"queries":[" hello ", " world "]}'
    service = LLMService(chat_model=ChatModelAdapter(FakeListChatModel(responses=[payload])))

    response = service.generate_queries(example_doc, 2)

    assert response.get_queries() == ["hello", "world"]

0 commit comments

Comments
 (0)