Skip to content

Commit 72d7376

Browse files
committed
support gemini
1 parent 7970d89 commit 72d7376

File tree

1 file changed

+130
-103
lines changed

1 file changed

+130
-103
lines changed

camel/models/gemini_model.py

Lines changed: 130 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,10 @@
2424
)
2525

2626
from openai import AsyncStream, Stream
27+
from openai.lib.streaming.chat import (
28+
AsyncChatCompletionStreamManager,
29+
ChatCompletionStreamManager,
30+
)
2731
from pydantic import BaseModel
2832

2933
from camel.configs import GeminiConfig
@@ -451,6 +455,47 @@ async def async_thought_preserving_generator():
451455

452456
return async_thought_preserving_generator()
453457

458+
@staticmethod
459+
def _clean_gemini_tools(
460+
tools: Optional[List[Dict[str, Any]]],
461+
) -> Optional[List[Dict[str, Any]]]:
462+
r"""Clean tools for Gemini API compatibility.
463+
464+
Removes unsupported fields like strict, anyOf, and restricts
465+
enum/format to allowed types.
466+
"""
467+
if not tools:
468+
return tools
469+
import copy
470+
471+
tools = copy.deepcopy(tools)
472+
for tool in tools:
473+
function_dict = tool.get('function', {})
474+
function_dict.pop("strict", None)
475+
476+
if 'parameters' in function_dict:
477+
params = function_dict['parameters']
478+
if 'properties' in params:
479+
for prop_name, prop_value in params['properties'].items():
480+
if 'anyOf' in prop_value:
481+
first_type = prop_value['anyOf'][0]
482+
params['properties'][prop_name] = first_type
483+
if 'description' in prop_value:
484+
params['properties'][prop_name][
485+
'description'
486+
] = prop_value['description']
487+
488+
if prop_value.get('type') != 'string':
489+
prop_value.pop('enum', None)
490+
491+
if prop_value.get('type') not in [
492+
'string',
493+
'integer',
494+
'number',
495+
]:
496+
prop_value.pop('format', None)
497+
return tools
498+
454499
@observe()
455500
def _run(
456501
self,
@@ -479,19 +524,18 @@ def _run(
479524
"response_format", None
480525
)
481526
messages = self._process_messages(messages)
527+
is_streaming = self.model_config_dict.get("stream", False)
528+
482529
if response_format:
483-
if tools:
484-
raise ValueError(
485-
"Gemini does not support function calling with "
486-
"response format."
530+
tools = self._clean_gemini_tools(tools)
531+
if is_streaming:
532+
return self._request_stream_parse( # type: ignore[return-value]
533+
messages, response_format, tools
487534
)
488-
result: Union[ChatCompletion, Stream[ChatCompletionChunk]] = (
489-
self._request_parse(messages, response_format)
490-
)
535+
else:
536+
return self._request_parse(messages, response_format, tools)
491537
else:
492-
result = self._request_chat_completion(messages, tools)
493-
494-
return result
538+
return self._request_chat_completion(messages, tools)
495539

496540
@observe()
497541
async def _arun(
@@ -521,67 +565,90 @@ async def _arun(
521565
"response_format", None
522566
)
523567
messages = self._process_messages(messages)
568+
is_streaming = self.model_config_dict.get("stream", False)
569+
524570
if response_format:
525-
if tools:
526-
raise ValueError(
527-
"Gemini does not support function calling with "
528-
"response format."
571+
tools = self._clean_gemini_tools(tools)
572+
if is_streaming:
573+
return await self._arequest_stream_parse( # type: ignore[return-value]
574+
messages, response_format, tools
575+
)
576+
else:
577+
return await self._arequest_parse(
578+
messages, response_format, tools
529579
)
530-
result: Union[
531-
ChatCompletion, AsyncStream[ChatCompletionChunk]
532-
] = await self._arequest_parse(messages, response_format)
533580
else:
534-
result = await self._arequest_chat_completion(messages, tools)
581+
return await self._arequest_chat_completion(messages, tools)
582+
583+
@staticmethod
584+
def _build_gemini_response_format(
585+
response_format: Type[BaseModel],
586+
) -> Dict[str, Any]:
587+
r"""Convert a Pydantic model to Gemini-compatible response_format."""
588+
schema = response_format.model_json_schema()
589+
# Remove $defs and other unsupported fields for Gemini
590+
schema.pop("$defs", None)
591+
schema.pop("definitions", None)
592+
return {
593+
"type": "json_schema",
594+
"json_schema": {
595+
"name": response_format.__name__,
596+
"schema": schema,
597+
},
598+
}
599+
600+
def _request_stream_parse(
601+
self,
602+
messages: List[OpenAIMessage],
603+
response_format: Type[BaseModel],
604+
tools: Optional[List[Dict[str, Any]]] = None,
605+
) -> ChatCompletionStreamManager[BaseModel]:
606+
r"""Gemini-specific streaming structured output.
607+
608+
Uses regular streaming with response_format as JSON schema
609+
instead of OpenAI's beta streaming API which is incompatible
610+
with Gemini's tool call delta format.
611+
"""
612+
request_config = self._prepare_request_config(tools)
613+
request_config["stream"] = True
614+
request_config["response_format"] = self._build_gemini_response_format(
615+
response_format
616+
)
617+
618+
response = self._client.chat.completions.create(
619+
messages=messages,
620+
model=self.model_type,
621+
**request_config,
622+
)
623+
return self._preserve_thought_signatures(response) # type: ignore[return-value]
535624

536-
return result
625+
async def _arequest_stream_parse(
626+
self,
627+
messages: List[OpenAIMessage],
628+
response_format: Type[BaseModel],
629+
tools: Optional[List[Dict[str, Any]]] = None,
630+
) -> AsyncChatCompletionStreamManager[BaseModel]:
631+
r"""Gemini-specific async streaming structured output."""
632+
request_config = self._prepare_request_config(tools)
633+
request_config["stream"] = True
634+
request_config["response_format"] = self._build_gemini_response_format(
635+
response_format
636+
)
637+
638+
response = await self._async_client.chat.completions.create(
639+
messages=messages,
640+
model=self.model_type,
641+
**request_config,
642+
)
643+
return self._preserve_thought_signatures(response) # type: ignore[return-value]
537644

538645
def _request_chat_completion(
539646
self,
540647
messages: List[OpenAIMessage],
541648
tools: Optional[List[Dict[str, Any]]] = None,
542649
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
543-
import copy
544-
545-
request_config = copy.deepcopy(self.model_config_dict)
546-
# Remove strict and anyOf from each tool's function parameters since
547-
# Gemini does not support them
548-
if tools:
549-
for tool in tools:
550-
function_dict = tool.get('function', {})
551-
function_dict.pop("strict", None)
552-
553-
# Process parameters to remove anyOf and handle enum/format
554-
if 'parameters' in function_dict:
555-
params = function_dict['parameters']
556-
if 'properties' in params:
557-
for prop_name, prop_value in params[
558-
'properties'
559-
].items():
560-
if 'anyOf' in prop_value:
561-
# Replace anyOf with the first type in the list
562-
first_type = prop_value['anyOf'][0]
563-
params['properties'][prop_name] = first_type
564-
# Preserve description if it exists
565-
if 'description' in prop_value:
566-
params['properties'][prop_name][
567-
'description'
568-
] = prop_value['description']
569-
570-
# Handle enum and format restrictions for Gemini
571-
# API enum: only allowed for string type
572-
if prop_value.get('type') != 'string':
573-
prop_value.pop('enum', None)
574-
575-
# format: only allowed for string, integer, and
576-
# number types
577-
if prop_value.get('type') not in [
578-
'string',
579-
'integer',
580-
'number',
581-
]:
582-
prop_value.pop('format', None)
583-
584-
request_config["tools"] = tools
650+
tools = self._clean_gemini_tools(tools)
651+
request_config = self._prepare_request_config(tools)
585652

586653
response = self._client.chat.completions.create(
587654
messages=messages,
@@ -597,48 +664,8 @@ async def _arequest_chat_completion(
597664
messages: List[OpenAIMessage],
598665
tools: Optional[List[Dict[str, Any]]] = None,
599666
) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]:
600-
import copy
601-
602-
request_config = copy.deepcopy(self.model_config_dict)
603-
# Remove strict and anyOf from each tool's function parameters since
604-
# Gemini does not support them
605-
if tools:
606-
for tool in tools:
607-
function_dict = tool.get('function', {})
608-
function_dict.pop("strict", None)
609-
610-
# Process parameters to remove anyOf and handle enum/format
611-
if 'parameters' in function_dict:
612-
params = function_dict['parameters']
613-
if 'properties' in params:
614-
for prop_name, prop_value in params[
615-
'properties'
616-
].items():
617-
if 'anyOf' in prop_value:
618-
# Replace anyOf with the first type in the list
619-
first_type = prop_value['anyOf'][0]
620-
params['properties'][prop_name] = first_type
621-
# Preserve description if it exists
622-
if 'description' in prop_value:
623-
params['properties'][prop_name][
624-
'description'
625-
] = prop_value['description']
626-
627-
# Handle enum and format restrictions for Gemini
628-
# API enum: only allowed for string type
629-
if prop_value.get('type') != 'string':
630-
prop_value.pop('enum', None)
631-
632-
# format: only allowed for string, integer, and
633-
# number types
634-
if prop_value.get('type') not in [
635-
'string',
636-
'integer',
637-
'number',
638-
]:
639-
prop_value.pop('format', None)
640-
641-
request_config["tools"] = tools
667+
tools = self._clean_gemini_tools(tools)
668+
request_config = self._prepare_request_config(tools)
642669

643670
response = await self._async_client.chat.completions.create(
644671
messages=messages,

0 commit comments

Comments (0)