Skip to content

Commit 9583a12

Browse files
committed
enhance: gemini stream mode support
1 parent 7c2edc8 commit 9583a12

File tree

4 files changed

+306
-44
lines changed

4 files changed

+306
-44
lines changed

camel/agents/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ToolCallRequest(BaseModel):
2727
tool_name: str
2828
args: Dict[str, Any]
2929
tool_call_id: str
30+
extra_content: Optional[Dict[str, Any]] = None
3031

3132

3233
class ModelResponse(BaseModel):

camel/agents/chat_agent.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3672,8 +3672,13 @@ def _handle_batch_response(
36723672
tool_name = tool_call.function.name # type: ignore[union-attr]
36733673
tool_call_id = tool_call.id
36743674
args = json.loads(tool_call.function.arguments) # type: ignore[union-attr]
3675+
extra_content = getattr(tool_call, 'extra_content', None)
3676+
36753677
tool_call_request = ToolCallRequest(
3676-
tool_name=tool_name, args=args, tool_call_id=tool_call_id
3678+
tool_name=tool_name,
3679+
args=args,
3680+
tool_call_id=tool_call_id,
3681+
extra_content=extra_content,
36773682
)
36783683
tool_call_requests.append(tool_call_request)
36793684

@@ -3766,7 +3771,12 @@ def _execute_tool(
37663771
logger.warning(f"{error_msg} with result: {result}")
37673772

37683773
return self._record_tool_calling(
3769-
func_name, args, result, tool_call_id, mask_output=mask_flag
3774+
func_name,
3775+
args,
3776+
result,
3777+
tool_call_id,
3778+
mask_output=mask_flag,
3779+
extra_content=tool_call_request.extra_content,
37703780
)
37713781

37723782
async def _aexecute_tool(
@@ -3808,7 +3818,13 @@ async def _aexecute_tool(
38083818
error_msg = f"Error executing async tool '{func_name}': {e!s}"
38093819
result = f"Tool execution failed: {error_msg}"
38103820
logger.warning(error_msg)
3811-
return self._record_tool_calling(func_name, args, result, tool_call_id)
3821+
return self._record_tool_calling(
3822+
func_name,
3823+
args,
3824+
result,
3825+
tool_call_id,
3826+
extra_content=tool_call_request.extra_content,
3827+
)
38123828

38133829
def _record_tool_calling(
38143830
self,
@@ -3817,6 +3833,7 @@ def _record_tool_calling(
38173833
result: Any,
38183834
tool_call_id: str,
38193835
mask_output: bool = False,
3836+
extra_content: Optional[Dict[str, Any]] = None,
38203837
):
38213838
r"""Record the tool calling information in the memory, and return the
38223839
tool calling record.
@@ -3829,6 +3846,9 @@ def _record_tool_calling(
38293846
mask_output (bool, optional): Whether to return a sanitized
38303847
placeholder instead of the raw tool output.
38313848
(default: :obj:`False`)
3849+
extra_content (Optional[Dict[str, Any]], optional): Additional
3850+
content associated with the tool call.
3851+
(default: :obj:`None`)
38323852
38333853
Returns:
38343854
ToolCallingRecord: A struct containing information about
@@ -3842,6 +3862,7 @@ def _record_tool_calling(
38423862
func_name=func_name,
38433863
args=args,
38443864
tool_call_id=tool_call_id,
3865+
extra_content=extra_content,
38453866
)
38463867
func_msg = FunctionCallingMessage(
38473868
role_name=self.role_name,
@@ -3852,6 +3873,7 @@ def _record_tool_calling(
38523873
result=result,
38533874
tool_call_id=tool_call_id,
38543875
mask_output=mask_output,
3876+
extra_content=extra_content,
38553877
)
38563878

38573879
# Use precise timestamps to ensure correct ordering
@@ -3986,7 +4008,7 @@ def _stream_response(
39864008
return
39874009

39884010
# Handle streaming response
3989-
if isinstance(response, Stream):
4011+
if isinstance(response, Stream) or inspect.isgenerator(response):
39904012
(
39914013
stream_completed,
39924014
tool_calls_complete,
@@ -4283,6 +4305,7 @@ def _accumulate_tool_calls(
42834305
'id': '',
42844306
'type': 'function',
42854307
'function': {'name': '', 'arguments': ''},
4308+
'extra_content': None,
42864309
'complete': False,
42874310
}
42884311

@@ -4306,6 +4329,14 @@ def _accumulate_tool_calls(
43064329
tool_call_entry['function']['arguments'] += (
43074330
delta_tool_call.function.arguments
43084331
)
4332+
# Handle extra_content if present
4333+
if (
4334+
hasattr(delta_tool_call, 'extra_content')
4335+
and delta_tool_call.extra_content
4336+
):
4337+
tool_call_entry['extra_content'] = (
4338+
delta_tool_call.extra_content
4339+
)
43094340

43104341
# Check if any tool calls are complete
43114342
any_complete = False
@@ -4410,6 +4441,7 @@ def _execute_tool_from_stream_data(
44104441
function_name = tool_call_data['function']['name']
44114442
args = json.loads(tool_call_data['function']['arguments'])
44124443
tool_call_id = tool_call_data['id']
4444+
extra_content = tool_call_data.get('extra_content')
44134445

44144446
if function_name in self._internal_tools:
44154447
tool = self._internal_tools[function_name]
@@ -4425,6 +4457,7 @@ def _execute_tool_from_stream_data(
44254457
func_name=function_name,
44264458
args=args,
44274459
tool_call_id=tool_call_id,
4460+
extra_content=extra_content,
44284461
)
44294462

44304463
# Then create the tool response message
@@ -4436,6 +4469,7 @@ def _execute_tool_from_stream_data(
44364469
func_name=function_name,
44374470
result=result,
44384471
tool_call_id=tool_call_id,
4472+
extra_content=extra_content,
44394473
)
44404474

44414475
# Record both messages with precise timestamps to ensure
@@ -4481,6 +4515,7 @@ def _execute_tool_from_stream_data(
44814515
func_name=function_name,
44824516
result=result,
44834517
tool_call_id=tool_call_id,
4518+
extra_content=extra_content,
44844519
)
44854520

44864521
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
@@ -4512,6 +4547,7 @@ async def _aexecute_tool_from_stream_data(
45124547
function_name = tool_call_data['function']['name']
45134548
args = json.loads(tool_call_data['function']['arguments'])
45144549
tool_call_id = tool_call_data['id']
4550+
extra_content = tool_call_data.get('extra_content')
45154551

45164552
if function_name in self._internal_tools:
45174553
# Create the tool call message
@@ -4523,6 +4559,7 @@ async def _aexecute_tool_from_stream_data(
45234559
func_name=function_name,
45244560
args=args,
45254561
tool_call_id=tool_call_id,
4562+
extra_content=extra_content,
45264563
)
45274564
assist_ts = time.time_ns() / 1_000_000_000
45284565
self.update_memory(
@@ -4569,6 +4606,7 @@ async def _aexecute_tool_from_stream_data(
45694606
func_name=function_name,
45704607
result=result,
45714608
tool_call_id=tool_call_id,
4609+
extra_content=extra_content,
45724610
)
45734611
func_ts = time.time_ns() / 1_000_000_000
45744612
self.update_memory(
@@ -4602,6 +4640,7 @@ async def _aexecute_tool_from_stream_data(
46024640
func_name=function_name,
46034641
result=result,
46044642
tool_call_id=tool_call_id,
4643+
extra_content=extra_content,
46054644
)
46064645
func_ts = time.time_ns() / 1_000_000_000
46074646
self.update_memory(
@@ -4911,6 +4950,11 @@ def _record_assistant_tool_calls_message(
49114950
"arguments": tool_call_data["function"]["arguments"],
49124951
},
49134952
}
4953+
# Include extra_content if present
4954+
if tool_call_data.get('extra_content'):
4955+
tool_call_dict["extra_content"] = tool_call_data[
4956+
"extra_content"
4957+
]
49144958
tool_calls_list.append(tool_call_dict)
49154959

49164960
# Create an assistant message with tool calls

camel/messages/func_message.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,17 @@ class FunctionCallingMessage(BaseMessage):
5050
mask_output (Optional[bool]): Whether to return a sanitized placeholder
5151
instead of the raw tool output.
5252
(default: :obj:`False`)
53+
extra_content (Optional[Dict[str, Any]]): Additional content
54+
associated with the tool call.
55+
(default: :obj:`None`)
5356
"""
5457

5558
func_name: Optional[str] = None
5659
args: Optional[Dict] = None
5760
result: Optional[Any] = None
5861
tool_call_id: Optional[str] = None
5962
mask_output: Optional[bool] = False
63+
extra_content: Optional[Dict[str, Any]] = None
6064

6165
def to_openai_message(
6266
self,
@@ -131,19 +135,23 @@ def to_openai_assistant_message(self) -> OpenAIAssistantMessage:
131135
" due to missing function name or arguments."
132136
)
133137

138+
tool_call = {
139+
"id": self.tool_call_id or "null",
140+
"type": "function",
141+
"function": {
142+
"name": self.func_name,
143+
"arguments": json.dumps(self.args, ensure_ascii=False),
144+
},
145+
}
146+
147+
# Include extra_content if available
148+
if self.extra_content is not None:
149+
tool_call["extra_content"] = self.extra_content
150+
134151
return {
135152
"role": "assistant",
136153
"content": self.content or "",
137-
"tool_calls": [
138-
{
139-
"id": self.tool_call_id or "null",
140-
"type": "function",
141-
"function": {
142-
"name": self.func_name,
143-
"arguments": json.dumps(self.args, ensure_ascii=False),
144-
},
145-
}
146-
],
154+
"tool_calls": [tool_call], # type: ignore[list-item]
147155
}
148156

149157
def to_openai_tool_message(self) -> OpenAIToolMessageParam:
@@ -187,4 +195,6 @@ def to_dict(self) -> Dict:
187195
if self.tool_call_id is not None:
188196
base["tool_call_id"] = self.tool_call_id
189197
base["mask_output"] = self.mask_output
198+
if self.extra_content is not None:
199+
base["extra_content"] = self.extra_content
190200
return base

0 commit comments

Comments
 (0)