1313 Tool ,
1414 FunctionDeclaration ,
1515)
16+ import pprint
1617
17- from aisuite .framework import ProviderInterface , ChatCompletionResponse
18+ from aisuite .framework import ProviderInterface , ChatCompletionResponse , Message
1819
1920
2021DEFAULT_TEMPERATURE = 0.7
22+ ENABLE_DEBUG_MESSAGES = False
23+
24+ # Links.
25+ # https://codelabs.developers.google.com/codelabs/gemini-function-calling#6
26+ # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#chat-samples
2127
2228
2329class GoogleMessageConverter :
@@ -30,22 +36,47 @@ def convert_user_role_message(message: Dict[str, Any]) -> Content:
3036 @staticmethod
3137 def convert_assistant_role_message (message : Dict [str , Any ]) -> Content :
3238 """Convert assistant messages to Google Vertex AI format."""
33- parts = [Part .from_text (message ["content" ])]
39+ if "tool_calls" in message and message ["tool_calls" ]:
40+ # Handle function calls
41+ tool_call = message ["tool_calls" ][
42+ 0
43+ ] # Assuming single function call for now
44+ function_call = tool_call ["function" ]
45+
46+ # Create a Part from the function call
47+ parts = [
48+ Part .from_dict (
49+ {
50+ "function_call" : {
51+ "name" : function_call ["name" ],
52+ # "arguments": json.loads(function_call["arguments"])
53+ }
54+ }
55+ )
56+ ]
57+ # return Content(role="function", parts=parts)
58+ else :
59+ # Handle regular text messages
60+ parts = [Part .from_text (message ["content" ])]
61+ # return Content(role="model", parts=parts)
62+
3463 return Content (role = "model" , parts = parts )
3564
3665 @staticmethod
37- def convert_tool_role_message (message : Dict [str , Any ]) -> Optional [ Content ] :
66+ def convert_tool_role_message (message : Dict [str , Any ]) -> Part :
3867 """Convert tool messages to Google Vertex AI format."""
3968 if "content" not in message :
40- return None
69+ raise ValueError ( "Tool result message must have a content field" )
4170
4271 try :
4372 content_json = json .loads (message ["content" ])
44- parts = [Part .from_function_response (content_json )]
73+ part = Part .from_function_response (
74+ name = message ["name" ], response = content_json
75+ )
76+ # TODO: Return Content instead of Part. But returning Content is not working.
77+ return part
4578 except json .JSONDecodeError :
46- parts = [Part .from_text (message ["content" ])]
47-
48- return Content (role = "user" , parts = parts )
79+ raise ValueError ("Tool result message must be valid JSON" )
4980
5081 @staticmethod
5182 def convert_request (messages : List [Dict [str , Any ]]) -> List [Content ]:
@@ -80,9 +111,55 @@ def convert_response(response) -> ChatCompletionResponse:
80111 """Normalize the response from Vertex AI to match OpenAI's response format."""
81112 openai_response = ChatCompletionResponse ()
82113
114+ if ENABLE_DEBUG_MESSAGES :
115+ print ("Dumping the response" )
116+ pprint .pprint (response )
117+
118+ # TODO: We need to go through each part, because function call may not be the first part.
119+ # Currently, we are only handling the first part, but this is not enough.
120+ #
121+ # This is a valid response:
122+ # candidates {
123+ # content {
124+ # role: "model"
125+ # parts {
126+ # text: "The current temperature in San Francisco is 72 degrees Celsius. \n\n"
127+ # }
128+ # parts {
129+ # function_call {
130+ # name: "is_it_raining"
131+ # args {
132+ # fields {
133+ # key: "location"
134+ # value {
135+ # string_value: "San Francisco"
136+ # }
137+ # }
138+ # }
139+ # }
140+ # }
141+ # }
142+ # finish_reason: STOP
143+
83144 # Check if the response contains function calls
84- if hasattr (response .candidates [0 ].content , "function_call" ):
85- function_call = response .candidates [0 ].content .function_call
145+ # Note: Just checking if the function_call attribute exists is not enough,
146+ # it is important to check if the function_call is not None.
147+ if (
148+ hasattr (response .candidates [0 ].content .parts [0 ], "function_call" )
149+ and response .candidates [0 ].content .parts [0 ].function_call
150+ ):
151+ function_call = response .candidates [0 ].content .parts [0 ].function_call
152+
153+ # args is a MapComposite.
154+ # Convert the MapComposite to a dictionary
155+ args_dict = {}
156+ # Another way to try is: args_dict = dict(function_call.args)
157+ for key , value in function_call .args .items ():
158+ args_dict [key ] = value
159+ if ENABLE_DEBUG_MESSAGES :
160+ print ("Dumping the args_dict" )
161+ pprint .pprint (args_dict )
162+
86163 openai_response .choices [0 ].message = {
87164 "role" : "assistant" ,
88165 "content" : None ,
@@ -92,11 +169,15 @@ def convert_response(response) -> ChatCompletionResponse:
92169 "id" : f"call_{ hash (function_call .name )} " , # Generate a unique ID
93170 "function" : {
94171 "name" : function_call .name ,
95- "arguments" : json .dumps (function_call . args ),
172+ "arguments" : json .dumps (args_dict ),
96173 },
97174 }
98175 ],
176+ "refusal" : None ,
99177 }
178+ openai_response .choices [0 ].message = Message (
179+ ** openai_response .choices [0 ].message
180+ )
100181 openai_response .choices [0 ].finish_reason = "tool_calls"
101182 else :
102183 # Handle regular text response
@@ -160,26 +241,58 @@ def chat_completions_create(self, model, messages, **kwargs):
160241 FunctionDeclaration (
161242 name = tool ["function" ]["name" ],
162243 description = tool ["function" ].get ("description" , "" ),
163- parameters = tool ["function" ]["parameters" ],
244+ parameters = {
245+ "type" : "object" ,
246+ "properties" : {
247+ param_name : {
248+ "type" : param_info .get ("type" , "string" ),
249+ "description" : param_info .get (
250+ "description" , ""
251+ ),
252+ ** (
253+ {"enum" : param_info ["enum" ]}
254+ if "enum" in param_info
255+ else {}
256+ ),
257+ }
258+ for param_name , param_info in tool ["function" ][
259+ "parameters"
260+ ]["properties" ].items ()
261+ },
262+ "required" : tool ["function" ]["parameters" ].get (
263+ "required" , []
264+ ),
265+ },
164266 )
267+ for tool in kwargs ["tools" ]
165268 ]
166269 )
167- for tool in kwargs ["tools" ]
168270 ]
169271
170- print (tools )
171272 # Create the GenerativeModel
172273 model = GenerativeModel (
173274 model ,
174275 generation_config = GenerationConfig (temperature = temperature ),
175- # tools=tools
276+ tools = tools ,
176277 )
177278
279+ if ENABLE_DEBUG_MESSAGES :
280+ print ("Dumping the message_history" )
281+ pprint .pprint (message_history )
282+
178283 # Start chat and get response
179284 chat = model .start_chat (history = message_history [:- 1 ])
180- response = chat .send_message (
181- message_history [- 1 ].parts [0 ].text ,
285+ last_message = message_history [- 1 ]
286+
287+ # If the last message is a function response, send the Part object directly
288+ # Otherwise, send just the text content
289+ message_to_send = (
290+ Content (role = "function" , parts = [last_message ])
291+ if isinstance (last_message , Part )
292+ else last_message .parts [0 ].text
182293 )
294+ # response = chat.send_message(message_to_send)
295+ response = chat .send_message (message_to_send )
183296
184297 # Convert and return the response
185298 return self .transformer .convert_response (response )
0 commit comments