22from dataclasses import asdict , dataclass
33from enum import Enum
44from io import BytesIO
5- from typing import Any , Dict , List , Mapping , Optional , Sequence , Union
5+ from typing import Any , Dict , List , Mapping , Optional , Sequence , Union , Literal
66
77from pydantic import BaseModel
88from aleph_alpha_client .structured_output import ResponseFormat
@@ -41,6 +41,40 @@ def to_json(self) -> Mapping[str, Any]:
4141 return result
4242
4343
@dataclass(frozen=True)
class FunctionCall:
    """The function invocation requested by the model.

    `arguments` is the raw JSON-encoded argument string as returned by the
    API; it is not parsed here.
    """

    name: str
    arguments: str


@dataclass(frozen=True)
class ToolCall:
    """One tool invocation emitted by the model in a chat response.

    Mirrors the wire format: an `id`, a `type` discriminator, and the
    nested function call payload.
    """

    id: str
    type: str
    function: FunctionCall

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "ToolCall":
        """Build a ToolCall from its API JSON representation."""
        fn = json["function"]
        call = FunctionCall(name=fn["name"], arguments=fn["arguments"])
        return ToolCall(id=json["id"], type=json["type"], function=call)

    def to_json(self) -> Mapping[str, Any]:
        """Serialize back to the wire shape accepted by the chat API.

        `asdict` recurses into the nested FunctionCall dataclass, yielding
        the same {"id", "type", "function": {"name", "arguments"}} mapping
        as a hand-written dict literal.
        """
        return asdict(self)
77+
4478# We introduce a more specific message type because chat responses can only
4579# contain text at the moment. This enables static type checking to proof that
4680# `content` is always a string.
@@ -59,12 +93,17 @@ class TextMessage:
5993
6094 role : Role
6195 content : str
96+ tool_calls : Optional [List [ToolCall ]] = None
6297
6398 @staticmethod
6499 def from_json (json : Dict [str , Any ]) -> "TextMessage" :
100+ tool_calls = json .get ("tool_calls" )
65101 return TextMessage (
66102 role = Role (json ["role" ]),
67103 content = json ["content" ],
104+ tool_calls = None
105+ if tool_calls is None
106+ else [ToolCall .from_json (tool_call ) for tool_call in tool_calls ],
68107 )
69108
70109 # In multi-turn conversations the returned TextMessage is part of the chat
@@ -76,6 +115,8 @@ def to_json(self) -> Mapping[str, Any]:
76115 "role" : self .role .value ,
77116 "content" : _message_content_to_json (self .content ),
78117 }
118+ if self .tool_calls is not None :
119+ result ["tool_calls" ] = [t .to_json () for t in self .tool_calls ]
79120 return result
80121
81122
@@ -122,6 +163,12 @@ class StreamOptions:
122163 include_usage : bool
123164
124165
@dataclass(frozen=True)
class ToolFunction:
    """Selects one specific tool for the model to call.

    Used as the ``tool_choice`` value on a ChatRequest when, instead of the
    string modes "auto"/"required"/"none", a particular function must be
    invoked.
    """

    # Discriminator; only "function" tools are expressible here.
    type: Literal["function"]
    # NOTE(review): typed Any in SOURCE — presumably the function selector
    # payload (e.g. {"name": ...}); confirm against the chat API schema.
    function: Any
171+
125172@dataclass (frozen = True )
126173class ChatRequest :
127174 """
@@ -141,6 +188,12 @@ class ChatRequest:
141188 steering_concepts : Optional [List [str ]] = None
142189 response_format : Optional [ResponseFormat ] = None
143190
191+ tools : Optional [List [Any ]] = None
192+ tool_choice : Optional [Union [Literal ["auto" , "required" , "none" ], ToolFunction ]] = (
193+ None
194+ )
195+ parallel_tool_calls : Optional [bool ] = None
196+
144197 def to_json (self ) -> Mapping [str , Any ]:
145198 payload = {k : v for k , v in asdict (self ).items () if v is not None }
146199 payload ["messages" ] = [message .to_json () for message in self .messages ]
@@ -164,7 +217,7 @@ class FinishReason(str, Enum):
164217 """
165218 The reason the model stopped generating tokens.
166219
167- This will be stop if the model hit a natural stop point or a provided stop
220+ This will be ` stop` if the model hit a natural stop point or a provided stop
168221 sequence or length if the maximum number of tokens specified in the request
169222 was reached. If the API is unable to understand the stop reason emitted by
170223 one of the workers, content_filter is returned.
@@ -173,6 +226,7 @@ class FinishReason(str, Enum):
173226 Stop = "stop"
174227 Length = "length"
175228 ContentFilter = "content_filter"
229+ ToolCalls = "tool_calls"
176230
177231
178232@dataclass (frozen = True )
0 commit comments