Skip to content

Commit a7e9fcb

Browse files
fix gemini default_api tool calls
1 parent 97cabc7 commit a7e9fcb

7 files changed

Lines changed: 713 additions & 42 deletions

File tree

crates/rig-core/src/agent/prompt_request/mod.rs

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use crate::{
1111
json_utils,
1212
memory::ConversationMemory,
1313
message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
14-
tool::server::ToolServerHandle,
14+
tool::{
15+
ToolSetError,
16+
server::{ToolServerError, ToolServerHandle},
17+
},
1518
wasm_compat::{WasmBoxedFuture, WasmCompatSend},
1619
};
1720
use futures::{StreamExt, stream};
@@ -732,6 +735,17 @@ where
732735
let output = match tool_server_handle.call_tool(tool_name, &args).await
733736
{
734737
Ok(res) => res,
738+
Err(ToolServerError::ToolsetError(
739+
ToolSetError::ToolNotFoundError(name),
740+
)) => {
741+
tracing::warn!(
742+
tool_name = name.as_str(),
743+
"Model requested an unknown tool"
744+
);
745+
return Err(PromptError::ToolError(
746+
ToolSetError::ToolNotFoundError(name),
747+
));
748+
}
735749
Err(e) => {
736750
tracing::warn!("Error while executing tool: {e}");
737751
e.to_string()
@@ -1015,8 +1029,10 @@ mod tests {
10151029
},
10161030
message::{Text, UserContent},
10171031
test_utils::{
1018-
AppendFailingMemory, CountingMemory, FailingMemory, MockCompletionModel, MockTurn,
1032+
AppendFailingMemory, CountingMemory, FailingMemory, MockCompletionModel,
1033+
MockStringOutputTool, MockTurn,
10191034
},
1035+
tool::ToolSetError,
10201036
};
10211037
use schemars::JsonSchema;
10221038
use serde::{Deserialize, Serialize};
@@ -1229,12 +1245,12 @@ mod tests {
12291245
reasoning_tokens: 0,
12301246
};
12311247
let model = MockCompletionModel::new([
1232-
MockTurn::tool_call("tool_call_1", "missing_tool", json!({"input": "value"}))
1248+
MockTurn::tool_call("tool_call_1", "string_output", json!({"input": "value"}))
12331249
.with_call_id("call_1")
12341250
.with_usage(first_call_usage),
12351251
MockTurn::text("").with_usage(second_call_usage),
12361252
]);
1237-
let agent = AgentBuilder::new(model).build();
1253+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
12381254

12391255
let response = agent
12401256
.prompt("do tool work")
@@ -1308,6 +1324,30 @@ mod tests {
13081324
validate_follow_up_tool_history(&requests[1]);
13091325
}
13101326

1327+
#[tokio::test]
1328+
async fn prompt_request_fails_fast_on_unknown_tool_call() {
1329+
let model = MockCompletionModel::new([
1330+
MockTurn::tool_call("tool_call_1", "default_api", json!({"input": "value"}))
1331+
.with_call_id("call_1"),
1332+
MockTurn::text("should not be called"),
1333+
]);
1334+
let recorded = model.clone();
1335+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
1336+
1337+
let err = agent
1338+
.prompt("do tool work")
1339+
.max_turns(3)
1340+
.await
1341+
.expect_err("unknown tool should fail");
1342+
1343+
assert!(matches!(
1344+
err,
1345+
PromptError::ToolError(ToolSetError::ToolNotFoundError(name))
1346+
if name == "default_api"
1347+
));
1348+
assert_eq!(recorded.requests().len(), 1);
1349+
}
1350+
13111351
#[tokio::test]
13121352
async fn prompt_request_concatenates_text_blocks_without_inserted_newlines() {
13131353
let model = MockCompletionModel::new([MockTurn::from_contents([

crates/rig-core/src/agent/prompt_request/streaming.rs

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
memory::ConversationMemory,
88
message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent},
99
streaming::{StreamedAssistantContent, StreamedUserContent},
10-
tool::server::ToolServerHandle,
10+
tool::server::{ToolServerError, ToolServerHandle},
1111
wasm_compat::{WasmBoxedFuture, WasmCompatSend},
1212
};
1313
use futures::{Stream, StreamExt};
@@ -707,6 +707,17 @@ where
707707
let tool_result = match
708708
tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await {
709709
Ok(thing) => thing,
710+
Err(ToolServerError::ToolsetError(
711+
ToolSetError::ToolNotFoundError(name),
712+
)) => {
713+
tracing::warn!(
714+
tool_name = name.as_str(),
715+
"Model requested an unknown tool"
716+
);
717+
return Err(StreamingError::Tool(
718+
ToolSetError::ToolNotFoundError(name),
719+
));
720+
}
710721
Err(e) => {
711722
tracing::warn!("Error while calling tool: {e}");
712723
e.to_string()
@@ -989,6 +1000,7 @@ mod tests {
9891000
use crate::streaming::{StreamingPrompt, ToolCallDeltaContent};
9901001
use crate::test_utils::{
9911002
AppendFailingMemory, FailingMemory, MockCompletionModel, MockResponse, MockStreamEvent,
1003+
MockStringOutputTool,
9921004
};
9931005
use futures::StreamExt;
9941006
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
@@ -1216,7 +1228,7 @@ mod tests {
12161228
vec![
12171229
MockStreamEvent::tool_call(
12181230
"tool_call_1",
1219-
"missing_tool",
1231+
"string_output",
12201232
serde_json::json!({"input": "value"}),
12211233
)
12221234
.with_call_id("call_1"),
@@ -1360,7 +1372,7 @@ mod tests {
13601372
MockStreamEvent::text("I need a tool. "),
13611373
MockStreamEvent::tool_call(
13621374
"tool_call_1",
1363-
"missing_tool",
1375+
"string_output",
13641376
serde_json::json!({"input": "value"}),
13651377
)
13661378
.with_call_id("call_1"),
@@ -1485,7 +1497,7 @@ mod tests {
14851497
async fn stream_prompt_continues_after_tool_call_turn() {
14861498
let model = streaming_tool_then_text_model();
14871499
let recorded = model.clone();
1488-
let agent = AgentBuilder::new(model).build();
1500+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
14891501
let empty_history: &[Message] = &[];
14901502

14911503
let mut stream = agent
@@ -1547,6 +1559,65 @@ mod tests {
15471559
assert!(validate_follow_up_tool_history(&requests[1]).is_ok());
15481560
}
15491561

1562+
#[tokio::test]
1563+
async fn stream_prompt_fails_fast_on_unknown_tool_call() {
1564+
let model = MockCompletionModel::from_stream_turns([
1565+
vec![
1566+
MockStreamEvent::tool_call(
1567+
"tool_call_1",
1568+
"default_api",
1569+
serde_json::json!({"input": "value"}),
1570+
)
1571+
.with_call_id("call_1"),
1572+
MockStreamEvent::final_response_with_total_tokens(4),
1573+
],
1574+
vec![
1575+
MockStreamEvent::text("should not be called"),
1576+
MockStreamEvent::final_response_with_total_tokens(6),
1577+
],
1578+
]);
1579+
let recorded = model.clone();
1580+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
1581+
let empty_history: &[Message] = &[];
1582+
1583+
let mut stream = agent
1584+
.stream_prompt("do tool work")
1585+
.with_history(empty_history)
1586+
.multi_turn(3)
1587+
.await;
1588+
let mut saw_tool_result = false;
1589+
let mut saw_final_response = false;
1590+
let mut error_message = None;
1591+
1592+
while let Some(item) = stream.next().await {
1593+
match item {
1594+
Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult {
1595+
..
1596+
})) => {
1597+
saw_tool_result = true;
1598+
}
1599+
Ok(MultiTurnStreamItem::FinalResponse(_)) => {
1600+
saw_final_response = true;
1601+
}
1602+
Err(err) => {
1603+
error_message = Some(err.to_string());
1604+
break;
1605+
}
1606+
Ok(_) => {}
1607+
}
1608+
}
1609+
1610+
assert!(!saw_tool_result);
1611+
assert!(!saw_final_response);
1612+
assert!(
1613+
error_message
1614+
.as_deref()
1615+
.is_some_and(|message| message.contains("ToolNotFoundError: default_api")),
1616+
"expected unknown-tool error, got {error_message:?}"
1617+
);
1618+
assert_eq!(recorded.requests().len(), 1);
1619+
}
1620+
15501621
#[tokio::test]
15511622
async fn stream_prompt_emits_tool_call_deltas_without_hook() {
15521623
let model = MockCompletionModel::from_stream_turns([[
@@ -1741,7 +1812,7 @@ mod tests {
17411812
vec![
17421813
MockStreamEvent::tool_call(
17431814
"tool_call_1",
1744-
"missing_tool",
1815+
"string_output",
17451816
serde_json::json!({"input": "value"}),
17461817
)
17471818
.with_call_id("call_1"),
@@ -1752,7 +1823,7 @@ mod tests {
17521823
MockStreamEvent::final_response(second_call_usage),
17531824
],
17541825
]);
1755-
let agent = AgentBuilder::new(model).build();
1826+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
17561827
let empty_history: &[Message] = &[];
17571828

17581829
let mut stream = agent
@@ -1852,7 +1923,7 @@ mod tests {
18521923
vec![
18531924
MockStreamEvent::tool_call(
18541925
"tool_call_1",
1855-
"missing_tool",
1926+
"string_output",
18561927
serde_json::json!({"input": "value"}),
18571928
)
18581929
.with_call_id("call_1"),
@@ -1862,7 +1933,7 @@ mod tests {
18621933
MockStreamEvent::final_response(second_call_usage),
18631934
],
18641935
]);
1865-
let agent = AgentBuilder::new(model).build();
1936+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
18661937
let empty_history: &[Message] = &[];
18671938

18681939
let mut stream = agent
@@ -1996,7 +2067,7 @@ mod tests {
19962067
async fn tool_follow_up_history_preserves_structured_text_metadata() {
19972068
let model = streaming_cited_text_then_tool_model();
19982069
let recorded = model.clone();
1999-
let agent = AgentBuilder::new(model).build();
2070+
let agent = AgentBuilder::new(model).tool(MockStringOutputTool).build();
20002071
let empty_history: &[Message] = &[];
20012072

20022073
let mut stream = agent

0 commit comments

Comments
 (0)