Skip to content

Commit 8fa61f9

Browse files
authored
fix: handle thought_signature in parallel function calling (#3462)
1 parent 151d708 commit 8fa61f9

File tree

2 files changed

+239
-6
lines changed

2 files changed

+239
-6
lines changed

camel/models/gemini_model.py

Lines changed: 103 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,19 +122,116 @@ def __init__(
122122
def _process_messages(self, messages) -> List[OpenAIMessage]:
123123
r"""Process the messages for Gemini API to ensure no empty content,
124124
which is not accepted by Gemini. Also preserves thought signatures
125-
required for Gemini 3 Pro function calling and adds fallback signatures
126-
when they are missing.
125+
required for Gemini 3 Pro function calling.
126+
127+
This method also merges consecutive assistant messages with single
128+
tool calls into a single assistant message with multiple tool calls,
129+
as required by Gemini's OpenAI-compatible API for parallel function
130+
calling.
127131
"""
128132
import copy
129133

130-
processed_messages = []
131-
for msg in messages:
132-
# Use deep copy to preserve all nested structures including
133-
# thought signatures in extra_content
134+
processed_messages: List[OpenAIMessage] = []
135+
i = 0
136+
n = len(messages)
137+
138+
while i < n:
139+
msg = messages[i]
140+
141+
# Check if this is an assistant message with a single tool_call
142+
# that might need to be merged with subsequent ones
143+
if (
144+
msg.get('role') == 'assistant'
145+
and 'tool_calls' in msg
146+
and isinstance(msg['tool_calls'], list)
147+
and len(msg['tool_calls']) == 1
148+
):
149+
# Look ahead to check if there are more assistant messages
150+
# with single tool calls (interleaved with their tool results)
151+
j = i + 1
152+
has_more_tool_calls = False
153+
154+
# Collect tool_call_ids we've seen so far
155+
first_tool_call_id = msg['tool_calls'][0].get('id')
156+
seen_tool_call_ids = (
157+
{first_tool_call_id} if first_tool_call_id else set()
158+
)
159+
160+
# Scan ahead to find pattern: tool_result, assistant,
161+
# tool_result, ...
162+
while j < n:
163+
next_msg = messages[j]
164+
next_role = next_msg.get('role')
165+
166+
if next_role == 'tool':
167+
# Tool result - check if it belongs to our batch
168+
if next_msg.get('tool_call_id') in seen_tool_call_ids:
169+
j += 1
170+
continue
171+
else:
172+
# Tool result for unknown call, stop scanning
173+
break
174+
elif (
175+
next_role == 'assistant'
176+
and 'tool_calls' in next_msg
177+
and isinstance(next_msg['tool_calls'], list)
178+
and len(next_msg['tool_calls']) == 1
179+
):
180+
# Another single tool call - mark for merging
181+
has_more_tool_calls = True
182+
tc_id = next_msg['tool_calls'][0].get('id')
183+
if tc_id:
184+
seen_tool_call_ids.add(tc_id)
185+
j += 1
186+
continue
187+
else:
188+
# Something else, stop scanning
189+
break
190+
191+
if has_more_tool_calls:
192+
# Need to merge: collect all tool calls and results
193+
merged_tool_calls = []
194+
tool_results = []
195+
is_first = True
196+
197+
for k in range(i, j):
198+
m = messages[k]
199+
if m.get('role') == 'assistant' and 'tool_calls' in m:
200+
tc = m['tool_calls'][0]
201+
if is_first:
202+
# Keep extra_content only on first tool call
203+
merged_tool_calls.append(copy.deepcopy(tc))
204+
is_first = False
205+
else:
206+
# Remove extra_content from subsequent tool
207+
# calls
208+
tc_copy = {
209+
k: v
210+
for k, v in tc.items()
211+
if k != 'extra_content'
212+
}
213+
merged_tool_calls.append(tc_copy)
214+
elif m.get('role') == 'tool':
215+
tool_results.append(copy.deepcopy(m))
216+
217+
# Build merged assistant message
218+
merged_msg = copy.deepcopy(msg)
219+
merged_msg['tool_calls'] = merged_tool_calls
220+
if 'content' in merged_msg and merged_msg['content'] == '':
221+
merged_msg['content'] = 'null'
222+
223+
processed_messages.append(merged_msg)
224+
processed_messages.extend(tool_results)
225+
i = j
226+
continue
227+
228+
# Regular message processing (no merging needed)
134229
msg_copy = copy.deepcopy(msg)
135230
if 'content' in msg_copy and msg_copy['content'] == '':
136231
msg_copy['content'] = 'null'
137232
processed_messages.append(msg_copy)
233+
i += 1
234+
138235
return processed_messages
139236

140237
def _preserve_thought_signatures(

test/models/test_gemini_model.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,139 @@ def test_gemini_model(model_type: ModelType):
4141
assert isinstance(model.token_counter, OpenAITokenCounter)
4242
assert isinstance(model.model_type.value_for_tiktoken, str)
4343
assert isinstance(model.model_type.token_limit, int)
44+
45+
46+
@pytest.mark.model_backend
def test_gemini_process_messages_merges_parallel_tool_calls():
    r"""Verify that ``_process_messages`` collapses consecutive assistant
    messages carrying one tool call each into a single assistant message
    with multiple tool calls, as Gemini's OpenAI-compatible API requires
    for parallel function calling.
    """
    model = GeminiModel(ModelType.GEMINI_3_PRO, GeminiConfig().as_dict())

    def single_call_msg(call_id, fn_name, args, signature):
        # One assistant message holding exactly one tool call, the way
        # ChatAgent records parallel calls by default.
        return {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {"name": fn_name, "arguments": args},
                    "extra_content": {
                        "google": {"thought_signature": signature}
                    },
                }
            ],
        }

    history = [
        {"role": "user", "content": "Calculate 2+2 and 3*3"},
        single_call_msg("call_001", "math_add", '{"a": 2, "b": 2}', "sig_A"),
        {"role": "tool", "tool_call_id": "call_001", "content": "4"},
        single_call_msg(
            "call_002", "math_multiply", '{"a": 3, "b": 3}', "sig_B"
        ),
        {"role": "tool", "tool_call_id": "call_002", "content": "9"},
    ]

    result = model._process_messages(history)

    # Expected shape: user, merged assistant, then both tool results.
    assert len(result) == 4
    assert result[0]['role'] == 'user'

    merged = result[1]
    assert merged['role'] == 'assistant'
    calls = merged['tool_calls']
    assert len(calls) == 2

    # Only the first tool call keeps its thought signature.
    assert 'extra_content' in calls[0]
    assert (
        calls[0]['extra_content']['google']['thought_signature'] == 'sig_A'
    )
    # Gemini requires extra_content absent from the remaining calls.
    assert 'extra_content' not in calls[1]

    # Tool results follow the merged assistant message, in order.
    for position, call_id in ((2, 'call_001'), (3, 'call_002')):
        assert result[position]['role'] == 'tool'
        assert result[position]['tool_call_id'] == call_id
137+
138+
139+
@pytest.mark.model_backend
def test_gemini_process_messages_single_tool_call_unchanged():
    r"""A lone assistant tool call must pass through without merging."""
    model = GeminiModel(ModelType.GEMINI_3_PRO, GeminiConfig().as_dict())

    history = [
        {"role": "user", "content": "Calculate 2+2"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_001",
                    "type": "function",
                    "function": {
                        "name": "math_add",
                        "arguments": '{"a": 2, "b": 2}',
                    },
                    "extra_content": {
                        "google": {"thought_signature": "sig_A"}
                    },
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_001", "content": "4"},
    ]

    result = model._process_messages(history)

    # Still user, assistant, tool — nothing merged or dropped.
    assert len(result) == 3
    assistant = result[1]
    assert assistant['role'] == 'assistant'
    assert len(assistant['tool_calls']) == 1
    # The thought signature survives untouched on a single call.
    assert 'extra_content' in assistant['tool_calls'][0]

0 commit comments

Comments (0)