Skip to content

Commit 77398ba

Browse files
authored
feat: added json schema option for response_format (#217)
1 parent a420ec9 commit 77398ba

3 files changed

Lines changed: 86 additions & 3 deletions

File tree

aidial_sdk/chat_completion/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
MessageContentTextPart,
1515
Request,
1616
ResponseFormat,
17+
ResponseFormatJsonObject,
18+
ResponseFormatJsonSchema,
19+
ResponseFormatJsonSchemaObject,
20+
ResponseFormatText,
1721
Role,
1822
)
1923
from aidial_sdk.chat_completion.request import Stage as RequestStage

aidial_sdk/chat_completion/request.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ConstrainedFloat,
1111
ConstrainedInt,
1212
ConstrainedList,
13+
Field,
1314
PositiveInt,
1415
StrictBool,
1516
StrictInt,
@@ -170,8 +171,35 @@ class ToolChoice(ExtraForbidModel):
170171
function: FunctionChoice
171172

172173

173-
class ResponseFormat(ExtraForbidModel):
174-
type: Literal["text", "json_object"]
174+
class ResponseFormatText(ExtraForbidModel):
175+
type: Literal["text"]
176+
177+
178+
class ResponseFormatJsonObject(ExtraForbidModel):
179+
type: Literal["json_object"]
180+
181+
182+
class ResponseFormatJsonSchemaObject(ExtraForbidModel):
183+
description: Optional[StrictStr] = None
184+
name: StrictStr
185+
schema_: Dict[str, Any] = Field(..., alias="schema")
186+
strict: Optional[StrictBool] = False
187+
188+
def dict(self, *args, **kwargs):
189+
kwargs["by_alias"] = True
190+
return super().dict(*args, **kwargs)
191+
192+
193+
class ResponseFormatJsonSchema(ExtraForbidModel):
194+
type: Literal["json_schema"]
195+
json_schema: ResponseFormatJsonSchemaObject
196+
197+
198+
ResponseFormat = Union[
199+
ResponseFormatText,
200+
ResponseFormatJsonObject,
201+
ResponseFormatJsonSchema,
202+
]
175203

176204

177205
class AzureChatCompletionRequest(ExtraForbidModel):

tests/test_serialization.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22

3-
from aidial_sdk.chat_completion import Message, Role
3+
from aidial_sdk.chat_completion import Message, ResponseFormatJsonSchema, Role
4+
from aidial_sdk.chat_completion.request import ResponseFormatJsonSchemaObject
45

56

67
def test_message_ser():
@@ -17,3 +18,53 @@ def test_message_deser():
1718
expected_obj = Message(role=Role.SYSTEM, content="test")
1819

1920
assert actual_obj == expected_obj
21+
22+
23+
def test_response_format_serialization():
24+
format_obj = ResponseFormatJsonSchema(
25+
type="json_schema",
26+
json_schema=ResponseFormatJsonSchemaObject(
27+
description="desc",
28+
name="name",
29+
schema={"key": "value"},
30+
),
31+
)
32+
33+
actual_dict = format_obj.dict()
34+
35+
expected_dict = {
36+
"type": "json_schema",
37+
"json_schema": {
38+
"description": "desc",
39+
"name": "name",
40+
"schema": {"key": "value"},
41+
"strict": False,
42+
},
43+
}
44+
45+
assert actual_dict == expected_dict
46+
47+
48+
def test_response_format_deserialization():
49+
format_dict = {
50+
"type": "json_schema",
51+
"json_schema": {
52+
"description": "desc",
53+
"name": "name",
54+
"schema": {"key": "value"},
55+
},
56+
}
57+
58+
actual_obj = ResponseFormatJsonSchema.parse_obj(format_dict)
59+
60+
expected_obj = ResponseFormatJsonSchema(
61+
type="json_schema",
62+
json_schema=ResponseFormatJsonSchemaObject(
63+
description="desc",
64+
name="name",
65+
schema={"key": "value"},
66+
strict=False,
67+
),
68+
)
69+
70+
assert actual_obj == expected_obj

0 commit comments

Comments
 (0)