Skip to content

Commit 17c2962

Browse files
fgebhartWieslerAA
authored andcommitted
feat: adjust json schema for structured output
1 parent b3b2089 commit 17c2962

File tree

2 files changed

+69
-57
lines changed

2 files changed

+69
-57
lines changed
Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import asdict, dataclass
2-
from typing import Any, Mapping, Union
2+
from typing import Any, Mapping, Optional, Union
3+
34

45
@dataclass(frozen=True)
56
class JSONSchema:
@@ -11,15 +12,21 @@ class JSONSchema:
1112
JSON schema that structured output must adhere to.
1213
1314
Examples:
14-
>>> schema = [
15-
>>> JSONSchema(
16-
>>> json_schema={'properties': {'bar': {'type': 'integer'}, 'type': 'object'}}
17-
>>> )
15+
>>> schema = JSONSchema(
16+
>>> schema={'type': 'object', 'properties': {'bar': {'type': 'integer'}}},
17+
>>> name="example_schema",
18+
>>> description="Example schema with a bar integer property",
19+
>>> strict=True
20+
>>> )
1821
"""
1922

20-
json_schema: Mapping[str, Any]
23+
schema: Mapping[str, Any]
24+
name: str
25+
description: Optional[str] = None
26+
strict: Optional[bool] = False
2127

2228
def to_json(self) -> Mapping[str, Any]:
23-
return {"type": "json_schema", **asdict(self)}
29+
return {"type": "json_schema", "json_schema": asdict(self)}
30+
2431

2532
ResponseFormat = Union[JSONSchema]

tests/test_chat.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -197,49 +197,52 @@ def test_steering_chat(sync_client: Client, chat_model_name: str):
197197
assert base_completion_result != steered_completion_result
198198

199199

200-
def test_response_format_json_schema(sync_client: Client, dummy_model_name: str):
201-
# This example is taken from json-schema.org:
202-
example_json_schema = {
203-
"$schema": "https://json-schema.org/draft/2020-12/schema",
204-
"$id": "https://example.com/product.schema.json",
205-
"title": "Product",
206-
"description": "A product from Acme's catalog",
207-
"type": "object",
208-
"properties": {
209-
"productId": {
210-
"description": "The unique identifier for a product",
211-
"type": "integer"
212-
}
213-
}
214-
}
200+
def test_response_format_json_schema(sync_client: Client, chat_model_name: str):
201+
chat_model_name = "pharia-chat-qwen3-32b-0801"
202+
example_json_schema = {"properties": {"nemo": {"type": "string"}}}
215203

216204
request = ChatRequest(
217-
messages=[Message(role=Role.User, content="Give me JSON!")],
218-
model=dummy_model_name,
219-
response_format=JSONSchema(example_json_schema),
205+
messages=[
206+
Message(role=Role.System, content="You are a helpful assistant."),
207+
Message(
208+
role=Role.User,
209+
content=f"Give me JSON {example_json_schema}! Tell me about nemo",
210+
),
211+
],
212+
model=chat_model_name,
213+
response_format=JSONSchema(
214+
schema=example_json_schema,
215+
name="test_schema",
216+
description="Test schema for JSON response",
217+
strict=False,
218+
),
220219
)
221220

222-
response = sync_client.chat(request, model=dummy_model_name)
223-
224-
# Dummy worker simply returns the JSON schema that the user has submitted
225-
assert json.loads(response.message.content) == example_json_schema
221+
response = sync_client.chat(request, model=chat_model_name)
222+
json_response = json.loads(response.message.content)
223+
assert "nemo" in json_response.keys()
224+
assert isinstance(json_response["nemo"], str)
226225

227226

228227
@pytest.mark.parametrize(
229228
"generic_client", ["sync_client", "async_client"], indirect=True
230229
)
231-
async def test_can_chat_with_images(generic_client: GenericClient, dummy_model_name: str):
230+
async def test_can_chat_with_images(
231+
generic_client: GenericClient, dummy_model_name: str
232+
):
232233
image_path = Path(__file__).parent / "dog-and-cat-cover.jpg"
233234
image = Image.open(image_path)
234235

235236
request = ChatRequest(
236-
messages=[Message(
237-
role=Role.User,
238-
content=[
239-
"Describe the following image.",
240-
image,
241-
],
242-
)],
237+
messages=[
238+
Message(
239+
role=Role.User,
240+
content=[
241+
"Describe the following image.",
242+
image,
243+
],
244+
)
245+
],
243246
model=dummy_model_name,
244247
maximum_tokens=200,
245248
)
@@ -263,14 +266,17 @@ def test_multimodal_message_serialization() -> None:
263266
content=[
264267
"Describe the following image.",
265268
Image.open(image_path),
266-
]
269+
],
267270
)
268271
assert message.to_json() == {
269272
"role": "user",
270273
"content": [
271274
{"type": "text", "text": "Describe the following image."},
272-
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{TINY_PNG}"}}
273-
]
275+
{
276+
"type": "image_url",
277+
"image_url": {"url": f"data:image/png;base64,{TINY_PNG}"},
278+
},
279+
],
274280
}
275281

276282

@@ -281,11 +287,15 @@ def test_multimodal_message_serialization_unknown_type() -> None:
281287
content=[
282288
"Describe the following image.",
283289
Path(image_path), # type: ignore
284-
]
290+
],
285291
)
286292
with pytest.raises(ValueError) as e:
287293
message.to_json()
288-
assert str(e.value) == "The item in the prompt is not valid. Try either a string or an Image."
294+
assert (
295+
str(e.value)
296+
== "The item in the prompt is not valid. Try either a string or an Image."
297+
)
298+
289299

290300
def test_request_serialization_no_default_values() -> None:
291301
request = ChatRequest(
@@ -294,44 +304,39 @@ def test_request_serialization_no_default_values() -> None:
294304
)
295305
assert request.to_json() == {
296306
"model": "dummy-model",
297-
"messages": [
298-
{
299-
"role": "user",
300-
"content": "Hello, how are you?"
301-
}
302-
]
307+
"messages": [{"role": "user", "content": "Hello, how are you?"}],
303308
}
304309

305310

306-
def test_multi_turn_chat_serialization(sync_client: Client, dummy_model_name: str):
311+
def test_multi_turn_chat_serialization(sync_client: Client, chat_model_name: str):
307312
"""
308313
Test that TextMessage can be serialized when included in multi-turn chat history.
309314
310315
We previously encountered an error in a multi-turn chat conversation.
311-
The returned TextMessage could not be made part of the chat history for the
312-
next request as the method for serialization was missing.
316+
The returned TextMessage could not be made part of the chat history for the
317+
next request as the method for serialization was missing.
313318
This test should catch such conversion issues.
314319
"""
315320

316321
# First turn
317322
first_request = ChatRequest(
318323
messages=[Message(role=Role.User, content="Hello")],
319-
model=dummy_model_name,
324+
model=chat_model_name,
320325
)
321-
first_response = sync_client.chat(first_request, model=dummy_model_name)
322-
326+
first_response = sync_client.chat(first_request, model=chat_model_name)
327+
323328
# Second turn - includes the TextMessage from first response in history
324329
messages_with_history: List[Union[Message, TextMessage]] = [
325330
Message(role=Role.User, content="Hello"),
326331
first_response.message, # This TextMessage must be serializable
327332
Message(role=Role.User, content="Follow up question"),
328333
]
329-
334+
330335
second_request = ChatRequest(
331336
messages=messages_with_history,
332-
model=dummy_model_name,
337+
model=chat_model_name,
333338
)
334-
339+
335340
# This would fail if TextMessage.to_json() doesn't exist
336341
serialized = second_request.to_json()
337342
assert len(serialized["messages"]) == 3

0 commit comments

Comments
 (0)