File tree Expand file tree Collapse file tree
aidial_sdk/chat_completion Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1414 MessageContentTextPart ,
1515 Request ,
1616 ResponseFormat ,
17+ ResponseFormatJsonObject ,
18+ ResponseFormatJsonSchema ,
19+ ResponseFormatJsonSchemaObject ,
20+ ResponseFormatText ,
1721 Role ,
1822)
1923from aidial_sdk .chat_completion .request import Stage as RequestStage
Original file line number Diff line number Diff line change 1010 ConstrainedFloat ,
1111 ConstrainedInt ,
1212 ConstrainedList ,
13+ Field ,
1314 PositiveInt ,
1415 StrictBool ,
1516 StrictInt ,
@@ -170,8 +171,35 @@ class ToolChoice(ExtraForbidModel):
170171 function : FunctionChoice
171172
172173
173- class ResponseFormat (ExtraForbidModel ):
174- type : Literal ["text" , "json_object" ]
174+ class ResponseFormatText (ExtraForbidModel ):
175+ type : Literal ["text" ]
176+
177+
178+ class ResponseFormatJsonObject (ExtraForbidModel ):
179+ type : Literal ["json_object" ]
180+
181+
182+ class ResponseFormatJsonSchemaObject (ExtraForbidModel ):
183+ description : Optional [StrictStr ] = None
184+ name : StrictStr
185+ schema_ : Dict [str , Any ] = Field (..., alias = "schema" )
186+ strict : Optional [StrictBool ] = False
187+
188+ def dict (self , * args , ** kwargs ):
189+ kwargs ["by_alias" ] = True
190+ return super ().dict (* args , ** kwargs )
191+
192+
193+ class ResponseFormatJsonSchema (ExtraForbidModel ):
194+ type : Literal ["json_schema" ]
195+ json_schema : ResponseFormatJsonSchemaObject
196+
197+
198+ ResponseFormat = Union [
199+ ResponseFormatText ,
200+ ResponseFormatJsonObject ,
201+ ResponseFormatJsonSchema ,
202+ ]
175203
176204
177205class AzureChatCompletionRequest (ExtraForbidModel ):
Original file line number Diff line number Diff line change 11import json
22
3- from aidial_sdk .chat_completion import Message , Role
3+ from aidial_sdk .chat_completion import Message , ResponseFormatJsonSchema , Role
4+ from aidial_sdk .chat_completion .request import ResponseFormatJsonSchemaObject
45
56
67def test_message_ser ():
@@ -17,3 +18,53 @@ def test_message_deser():
1718 expected_obj = Message (role = Role .SYSTEM , content = "test" )
1819
1920 assert actual_obj == expected_obj
21+
22+
23+ def test_response_format_serialization ():
24+ format_obj = ResponseFormatJsonSchema (
25+ type = "json_schema" ,
26+ json_schema = ResponseFormatJsonSchemaObject (
27+ description = "desc" ,
28+ name = "name" ,
29+ schema = {"key" : "value" },
30+ ),
31+ )
32+
33+ actual_dict = format_obj .dict ()
34+
35+ expected_dict = {
36+ "type" : "json_schema" ,
37+ "json_schema" : {
38+ "description" : "desc" ,
39+ "name" : "name" ,
40+ "schema" : {"key" : "value" },
41+ "strict" : False ,
42+ },
43+ }
44+
45+ assert actual_dict == expected_dict
46+
47+
48+ def test_response_format_deserialization ():
49+ format_dict = {
50+ "type" : "json_schema" ,
51+ "json_schema" : {
52+ "description" : "desc" ,
53+ "name" : "name" ,
54+ "schema" : {"key" : "value" },
55+ },
56+ }
57+
58+ actual_obj = ResponseFormatJsonSchema .parse_obj (format_dict )
59+
60+ expected_obj = ResponseFormatJsonSchema (
61+ type = "json_schema" ,
62+ json_schema = ResponseFormatJsonSchemaObject (
63+ description = "desc" ,
64+ name = "name" ,
65+ schema = {"key" : "value" },
66+ strict = False ,
67+ ),
68+ )
69+
70+ assert actual_obj == expected_obj
You can’t perform that action at this time.
0 commit comments