camel/models/gemini_model.py (107 additions, 14 deletions)

@@ -124,15 +124,47 @@ def _run(
"response_format", None
)
messages = self._process_messages(messages)

if response_format:
if tools:
raise ValueError(
"Gemini does not support function calling with "
"response format."
)
return self._request_parse(messages, response_format)
else:
return self._request_chat_completion(messages, tools)
# For tool calls, first get the tool response
completion = self._request_chat_completion(messages, tools)
if hasattr(completion, 'choices') and completion.choices:
choice = completion.choices[0]
if hasattr(choice, 'message') and (
isinstance(choice.message, dict)
and choice.message.get('tool_calls')
or hasattr(choice.message, 'tool_calls')
and getattr(choice.message, 'tool_calls', None)
):
# make a new request with response format
tool_messages = messages.copy()
tool_messages.append(
{
'role': 'assistant',
'content': None,
'tool_calls': (
choice.message['tool_calls']
if isinstance(choice.message, dict)
else (
getattr(
choice.message, 'tool_calls', None
)
or []
)
),
}
)
# Now request with response format
return self._request_parse(
tool_messages, response_format
)
else:
# No tools, directly use response format
return self._request_parse(messages, response_format)

# No response format, just do normal completion
return self._request_chat_completion(messages, tools)
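
Note: for context, here is a minimal sketch of how a caller might exercise this combined tools-plus-structured-output path. The model choice and the CityInfo schema are illustrative assumptions, not part of this diff:

from pydantic import BaseModel

from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType


# Hypothetical response schema, used only for illustration.
class CityInfo(BaseModel):
    city: str
    population: int


gemini = ModelFactory.create(
    model_platform=ModelPlatformType.GEMINI,
    model_type=ModelType.GEMINI_1_5_FLASH,
)

# With tools supplied, _run first requests a plain completion; if the
# reply contains tool calls, it re-requests with the response format.
response = gemini._run(
    messages=[{"role": "user", "content": "Tell me about Paris."}],
    response_format=CityInfo,
    tools=None,  # or a list of OpenAI-style tool schemas
)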

     async def _arun(
         self,
@@ -161,13 +193,32 @@ async def _arun(
         messages = self._process_messages(messages)
         if response_format:
             if tools:
-                raise ValueError(
-                    "Gemini does not support function calling with "
-                    "response format."
-                )
-            return await self._arequest_parse(messages, response_format)
-        else:
-            return await self._arequest_chat_completion(messages, tools)
+                # Tools and a response format were both given: first run
+                # a normal completion to obtain any tool calls.
+                completion = await self._arequest_chat_completion(
+                    messages, tools
+                )

[Review comment (Member): use await for async function]

+                if hasattr(completion, 'choices') and completion.choices:
+                    choice = completion.choices[0]
+                    if hasattr(choice, 'message') and (
+                        isinstance(choice.message, dict)
+                        and choice.message.get('tool_calls')
+                        or hasattr(choice.message, 'tool_calls')
+                        and getattr(choice.message, 'tool_calls', None)
+                    ):
+                        # Make a new request that carries the tool calls
+                        tool_messages = messages.copy()
+                        tool_messages.append(
+                            {
+                                'role': 'assistant',
+                                'content': None,
+                                'tool_calls': (
+                                    choice.message['tool_calls']
+                                    if isinstance(choice.message, dict)
+                                    else (
+                                        getattr(
+                                            choice.message, 'tool_calls', None
+                                        )
+                                        or []
+                                    )
+                                ),
+                            }
+                        )
+                        # Now request again with the response format applied
+                        return await self._arequest_parse(
+                            tool_messages, response_format
+                        )
+                # No tool calls came back; return the completion instead of
+                # issuing the same request a second time.
+                return completion
+            else:
+                # No tools, directly use the response format
+                return await self._arequest_parse(messages, response_format)
+
+        # No response format, just do a normal completion
+        return await self._arequest_chat_completion(messages, tools)
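
Note: the async path mirrors the sync one. A minimal sketch of driving it from an event loop, under the same illustrative assumptions as the sketch above:

import asyncio


async def main() -> None:
    response = await gemini._arun(
        messages=[{"role": "user", "content": "Tell me about Paris."}],
        response_format=CityInfo,
        tools=None,
    )
    print(response.choices[0].message)


asyncio.run(main())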

     def _request_chat_completion(
         self,
@@ -249,6 +300,48 @@ async def _arequest_chat_completion(
             **request_config,
         )

+    def _request_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        import copy
+
+        # Deep-copy the config so the shared dict is not mutated, then
+        # constrain the output to the Pydantic model's JSON schema.
+        request_config = copy.deepcopy(self.model_config_dict)
+        request_config["response_format"] = {
+            "type": "json_object",
+            "schema": response_format.schema(),
+        }
+        # A parsed response is returned in one piece, so drop streaming.
+        request_config.pop("stream", None)
+
+        return self._client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
+
+    async def _arequest_parse(
+        self,
+        messages: List[OpenAIMessage],
+        response_format: Type[BaseModel],
+        tools: Optional[List[Dict[str, Any]]] = None,
+    ) -> ChatCompletion:
+        import copy
+
+        # Deep-copy the config so the shared dict is not mutated, then
+        # constrain the output to the Pydantic model's JSON schema.
+        request_config = copy.deepcopy(self.model_config_dict)
+        request_config["response_format"] = {
+            "type": "json_object",
+            "schema": response_format.schema(),
+        }
+        # A parsed response is returned in one piece, so drop streaming.
+        request_config.pop("stream", None)
+
+        return await self._async_client.chat.completions.create(
+            messages=messages,
+            model=self.model_type,
+            **request_config,
+        )
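
Note: for reference, a sketch of the response_format payload these helpers build for a hypothetical Pydantic model; schema() emits a standard JSON Schema dict, roughly as shown in the comments:

from pydantic import BaseModel


class CityInfo(BaseModel):  # illustrative only
    city: str
    population: int


response_format = {
    "type": "json_object",
    "schema": CityInfo.schema(),
}
# CityInfo.schema() is roughly:
# {
#     "title": "CityInfo",
#     "type": "object",
#     "properties": {
#         "city": {"title": "City", "type": "string"},
#         "population": {"title": "Population", "type": "integer"},
#     },
#     "required": ["city", "population"],
# }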

     def check_model_config(self):
         r"""Check whether the model configuration contains any
         unexpected arguments to Gemini API.