Skip to content

Commit 2688d30

Browse files
committed
Gemini working.
1 parent 25898d4 commit 2688d30

File tree

1 file changed

+130
-17
lines changed

1 file changed

+130
-17
lines changed

aisuite/providers/google_provider.py

Lines changed: 130 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,17 @@
1313
Tool,
1414
FunctionDeclaration,
1515
)
16+
import pprint
1617

17-
from aisuite.framework import ProviderInterface, ChatCompletionResponse
18+
from aisuite.framework import ProviderInterface, ChatCompletionResponse, Message
1819

1920

2021
DEFAULT_TEMPERATURE = 0.7
22+
ENABLE_DEBUG_MESSAGES = False
23+
24+
# Links.
25+
# https://codelabs.developers.google.com/codelabs/gemini-function-calling#6
26+
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/function-calling#chat-samples
2127

2228

2329
class GoogleMessageConverter:
@@ -30,22 +36,47 @@ def convert_user_role_message(message: Dict[str, Any]) -> Content:
3036
@staticmethod
3137
def convert_assistant_role_message(message: Dict[str, Any]) -> Content:
3238
"""Convert assistant messages to Google Vertex AI format."""
33-
parts = [Part.from_text(message["content"])]
39+
if "tool_calls" in message and message["tool_calls"]:
40+
# Handle function calls
41+
tool_call = message["tool_calls"][
42+
0
43+
] # Assuming single function call for now
44+
function_call = tool_call["function"]
45+
46+
# Create a Part from the function call
47+
parts = [
48+
Part.from_dict(
49+
{
50+
"function_call": {
51+
"name": function_call["name"],
52+
# "arguments": json.loads(function_call["arguments"])
53+
}
54+
}
55+
)
56+
]
57+
# return Content(role="function", parts=parts)
58+
else:
59+
# Handle regular text messages
60+
parts = [Part.from_text(message["content"])]
61+
# return Content(role="model", parts=parts)
62+
3463
return Content(role="model", parts=parts)
3564

3665
@staticmethod
37-
def convert_tool_role_message(message: Dict[str, Any]) -> Optional[Content]:
66+
def convert_tool_role_message(message: Dict[str, Any]) -> Part:
3867
"""Convert tool messages to Google Vertex AI format."""
3968
if "content" not in message:
40-
return None
69+
raise ValueError("Tool result message must have a content field")
4170

4271
try:
4372
content_json = json.loads(message["content"])
44-
parts = [Part.from_function_response(content_json)]
73+
part = Part.from_function_response(
74+
name=message["name"], response=content_json
75+
)
76+
# TODO: Return Content instead of Part. But returning Content is not working.
77+
return part
4578
except json.JSONDecodeError:
46-
parts = [Part.from_text(message["content"])]
47-
48-
return Content(role="user", parts=parts)
79+
raise ValueError("Tool result message must be valid JSON")
4980

5081
@staticmethod
5182
def convert_request(messages: List[Dict[str, Any]]) -> List[Content]:
@@ -80,9 +111,55 @@ def convert_response(response) -> ChatCompletionResponse:
80111
"""Normalize the response from Vertex AI to match OpenAI's response format."""
81112
openai_response = ChatCompletionResponse()
82113

114+
if ENABLE_DEBUG_MESSAGES:
115+
print("Dumping the response")
116+
pprint.pprint(response)
117+
118+
# TODO: We need to go through each part, because function call may not be the first part.
119+
# Currently, we are only handling the first part, but this is not enough.
120+
#
121+
# This is a valid response:
122+
# candidates {
123+
# content {
124+
# role: "model"
125+
# parts {
126+
# text: "The current temperature in San Francisco is 72 degrees Celsius. \n\n"
127+
# }
128+
# parts {
129+
# function_call {
130+
# name: "is_it_raining"
131+
# args {
132+
# fields {
133+
# key: "location"
134+
# value {
135+
# string_value: "San Francisco"
136+
# }
137+
# }
138+
# }
139+
# }
140+
# }
141+
# }
142+
# finish_reason: STOP
143+
83144
# Check if the response contains function calls
84-
if hasattr(response.candidates[0].content, "function_call"):
85-
function_call = response.candidates[0].content.function_call
145+
# Note: Just checking if the function_call attribute exists is not enough,
146+
# it is important to check if the function_call is not None.
147+
if (
148+
hasattr(response.candidates[0].content.parts[0], "function_call")
149+
and response.candidates[0].content.parts[0].function_call
150+
):
151+
function_call = response.candidates[0].content.parts[0].function_call
152+
153+
# args is a MapComposite.
154+
# Convert the MapComposite to a dictionary
155+
args_dict = {}
156+
# Another way to try is: args_dict = dict(function_call.args)
157+
for key, value in function_call.args.items():
158+
args_dict[key] = value
159+
if ENABLE_DEBUG_MESSAGES:
160+
print("Dumping the args_dict")
161+
pprint.pprint(args_dict)
162+
86163
openai_response.choices[0].message = {
87164
"role": "assistant",
88165
"content": None,
@@ -92,11 +169,15 @@ def convert_response(response) -> ChatCompletionResponse:
92169
"id": f"call_{hash(function_call.name)}", # Generate a unique ID
93170
"function": {
94171
"name": function_call.name,
95-
"arguments": json.dumps(function_call.args),
172+
"arguments": json.dumps(args_dict),
96173
},
97174
}
98175
],
176+
"refusal": None,
99177
}
178+
openai_response.choices[0].message = Message(
179+
**openai_response.choices[0].message
180+
)
100181
openai_response.choices[0].finish_reason = "tool_calls"
101182
else:
102183
# Handle regular text response
@@ -160,26 +241,58 @@ def chat_completions_create(self, model, messages, **kwargs):
160241
FunctionDeclaration(
161242
name=tool["function"]["name"],
162243
description=tool["function"].get("description", ""),
163-
parameters=tool["function"]["parameters"],
244+
parameters={
245+
"type": "object",
246+
"properties": {
247+
param_name: {
248+
"type": param_info.get("type", "string"),
249+
"description": param_info.get(
250+
"description", ""
251+
),
252+
**(
253+
{"enum": param_info["enum"]}
254+
if "enum" in param_info
255+
else {}
256+
),
257+
}
258+
for param_name, param_info in tool["function"][
259+
"parameters"
260+
]["properties"].items()
261+
},
262+
"required": tool["function"]["parameters"].get(
263+
"required", []
264+
),
265+
},
164266
)
267+
for tool in kwargs["tools"]
165268
]
166269
)
167-
for tool in kwargs["tools"]
168270
]
169271

170-
print(tools)
171272
# Create the GenerativeModel
172273
model = GenerativeModel(
173274
model,
174275
generation_config=GenerationConfig(temperature=temperature),
175-
# tools=tools
276+
tools=tools,
176277
)
177278

279+
if ENABLE_DEBUG_MESSAGES:
280+
print("Dumping the message_history")
281+
pprint.pprint(message_history)
282+
178283
# Start chat and get response
179284
chat = model.start_chat(history=message_history[:-1])
180-
response = chat.send_message(
181-
message_history[-1].parts[0].text,
285+
last_message = message_history[-1]
286+
287+
# If the last message is a function response, send the Part object directly
288+
# Otherwise, send just the text content
289+
message_to_send = (
290+
Content(role="function", parts=[last_message])
291+
if isinstance(last_message, Part)
292+
else last_message.parts[0].text
182293
)
294+
# response = chat.send_message(message_to_send)
295+
response = chat.send_message(message_to_send)
183296

184297
# Convert and return the response
185298
return self.transformer.convert_response(response)

0 commit comments

Comments (0)