diff --git a/crates/rig-core/src/agent/builder.rs b/crates/rig-core/src/agent/builder.rs index 6bdd8b77c..5d9026015 100644 --- a/crates/rig-core/src/agent/builder.rs +++ b/crates/rig-core/src/agent/builder.rs @@ -239,6 +239,34 @@ where self.default_conversation_id = Some(id.into()); self } + + /// Set the default hook for the agent. + /// + /// This hook will be used for all prompt requests unless overridden + /// via `.with_hook()` on the request. + pub fn hook(self, hook: P2) -> AgentBuilder + where + P2: PromptHook, + { + AgentBuilder { + name: self.name, + description: self.description, + model: self.model, + preamble: self.preamble, + static_context: self.static_context, + additional_params: self.additional_params, + max_tokens: self.max_tokens, + dynamic_context: self.dynamic_context, + temperature: self.temperature, + tool_choice: self.tool_choice, + default_max_turns: self.default_max_turns, + tool_state: self.tool_state, + hook: Some(hook), + output_schema: self.output_schema, + memory: self.memory, + default_conversation_id: self.default_conversation_id, + } + } } impl AgentBuilder @@ -482,34 +510,6 @@ where } } - /// Set the default hook for the agent. - /// - /// This hook will be used for all prompt requests unless overridden - /// via `.with_hook()` on the request. - pub fn hook(self, hook: P2) -> AgentBuilder - where - P2: PromptHook, - { - AgentBuilder { - name: self.name, - description: self.description, - model: self.model, - preamble: self.preamble, - static_context: self.static_context, - additional_params: self.additional_params, - max_tokens: self.max_tokens, - dynamic_context: self.dynamic_context, - temperature: self.temperature, - tool_choice: self.tool_choice, - default_max_turns: self.default_max_turns, - tool_state: self.tool_state, - hook: Some(hook), - output_schema: self.output_schema, - memory: self.memory, - default_conversation_id: self.default_conversation_id, - } - } - /// Build the agent with no tools configured. /// /// An empty `ToolServer` will be created for the agent. @@ -651,3 +651,22 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{MockAddTool, MockCompletionModel}; + + #[derive(Clone)] + struct BuilderHook; + + impl PromptHook for BuilderHook {} + + #[test] + fn hook_can_be_set_after_tool_configuration() { + let _agent = AgentBuilder::new(MockCompletionModel::text("ok")) + .tool(MockAddTool) + .hook(BuilderHook) + .build(); + } +} diff --git a/crates/rig-core/src/agent/completion.rs b/crates/rig-core/src/agent/completion.rs index b7e0816e3..91dbd8c6f 100644 --- a/crates/rig-core/src/agent/completion.rs +++ b/crates/rig-core/src/agent/completion.rs @@ -11,7 +11,10 @@ use crate::{ vector_store::{VectorStoreError, request::VectorSearchRequest}, wasm_compat::WasmCompatSend, }; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{BTreeSet, HashMap}, + sync::Arc, +}; const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent"; @@ -22,8 +25,53 @@ pub type DynamicContextStore = Arc< )>, >; +/// A prepared completion request plus the executable Rig tool names advertised +/// to the provider for this turn. +pub(crate) struct PreparedCompletionRequest { + pub(crate) builder: CompletionRequestBuilder, + pub(crate) executable_tool_names: BTreeSet, + pub(crate) allowed_tool_names: BTreeSet, +} + +pub(crate) fn allowed_tool_names_for_choice( + executable_tool_names: &BTreeSet, + tool_choice: Option<&ToolChoice>, +) -> Result, CompletionError> { + let allowed = match tool_choice { + None | Some(ToolChoice::Auto | ToolChoice::Required) => executable_tool_names.clone(), + Some(ToolChoice::None) => BTreeSet::new(), + Some(ToolChoice::Specific { function_names }) => { + if function_names.is_empty() { + return Err(CompletionError::RequestError( + "ToolChoice::Specific requires at least one function name".into(), + )); + } + + let requested = function_names.iter().cloned().collect::>(); + let missing = requested + .difference(executable_tool_names) + .cloned() + .collect::>(); + + if !missing.is_empty() { + return Err(CompletionError::RequestError( + format!( + "ToolChoice::Specific requested unknown tool names: {missing:?}. Available tools: {:?}", + executable_tool_names.iter().collect::>() + ) + .into(), + )); + } + + requested + } + }; + + Ok(allowed) +} + /// Helper function to build a completion request from agent components. -/// This is used by both `Agent::completion()` and `PromptRequest::send()`. +/// This is used by `Agent::completion()` to preserve the public completion API. #[allow(clippy::too_many_arguments)] pub(crate) async fn build_completion_request( model: &Arc, @@ -39,6 +87,41 @@ pub(crate) async fn build_completion_request( dynamic_context: &DynamicContextStore, output_schema: Option<&schemars::Schema>, ) -> Result, CompletionError> { + Ok(build_prepared_completion_request( + model, + prompt, + chat_history, + preamble, + static_context, + temperature, + max_tokens, + additional_params, + tool_choice, + tool_server_handle, + dynamic_context, + output_schema, + ) + .await? + .builder) +} + +/// Helper function to build a completion request from agent components while +/// preserving the executable Rig tool names sent to the provider. +#[allow(clippy::too_many_arguments)] +pub(crate) async fn build_prepared_completion_request( + model: &Arc, + prompt: Message, + chat_history: &[Message], + preamble: Option<&str>, + static_context: &[Document], + temperature: Option, + max_tokens: Option, + additional_params: Option<&serde_json::Value>, + tool_choice: Option<&ToolChoice>, + tool_server_handle: &ToolServerHandle, + dynamic_context: &DynamicContextStore, + output_schema: Option<&schemars::Schema>, +) -> Result, CompletionError> { // Find the latest message in the chat history that contains RAG text let rag_text = prompt.rag_text(); let rag_text = rag_text.or_else(|| { @@ -73,7 +156,7 @@ pub(crate) async fn build_completion_request( }; // If the agent has RAG text, we need to fetch the dynamic context and tools - let result = match &rag_text { + let (builder, executable_tool_names) = match &rag_text { Some(text) => { // Map over the vector to create async tasks let search_futures = dynamic_context.iter().map(|(num_sample, index)| { @@ -123,21 +206,31 @@ pub(crate) async fn build_completion_request( .map_err(|_| { CompletionError::RequestError("Failed to get tool definitions".into()) })?; - - completion_request - .documents(fetched_context) - .tools(tooldefs) + let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect(); + + ( + completion_request + .documents(fetched_context) + .tools(tooldefs), + executable_tool_names, + ) } None => { let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| { CompletionError::RequestError("Failed to get tool definitions".into()) })?; + let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect(); - completion_request.tools(tooldefs) + (completion_request.tools(tooldefs), executable_tool_names) } }; + let allowed_tool_names = allowed_tool_names_for_choice(&executable_tool_names, tool_choice)?; - Ok(result) + Ok(PreparedCompletionRequest { + builder, + executable_tool_names, + allowed_tool_names, + }) } /// Struct representing an LLM agent. An agent is an LLM model combined with a preamble @@ -450,3 +543,95 @@ where TypedPromptRequest::from_agent(*self, prompt) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn tool_names(names: &[&str]) -> BTreeSet { + names.iter().map(|name| (*name).to_string()).collect() + } + + #[test] + fn allowed_tool_names_defaults_to_all_executable_tools() { + let executable = tool_names(&["add", "subtract"]); + + assert_eq!( + allowed_tool_names_for_choice(&executable, None).unwrap(), + executable + ); + } + + #[test] + fn allowed_tool_names_auto_and_required_allow_all_executable_tools() { + let executable = tool_names(&["add", "subtract"]); + + assert_eq!( + allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Auto)).unwrap(), + executable + ); + assert_eq!( + allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Required)).unwrap(), + executable + ); + } + + #[test] + fn allowed_tool_names_none_allows_no_tools() { + let executable = tool_names(&["add", "subtract"]); + + assert!( + allowed_tool_names_for_choice(&executable, Some(&ToolChoice::None)) + .unwrap() + .is_empty() + ); + } + + #[test] + fn allowed_tool_names_specific_allows_requested_executable_tools() { + let executable = tool_names(&["add", "subtract"]); + let choice = ToolChoice::Specific { + function_names: vec!["add".to_string()], + }; + + assert_eq!( + allowed_tool_names_for_choice(&executable, Some(&choice)).unwrap(), + tool_names(&["add"]) + ); + } + + #[test] + fn allowed_tool_names_specific_rejects_missing_tools() { + let executable = tool_names(&["add"]); + let choice = ToolChoice::Specific { + function_names: vec!["missing".to_string()], + }; + + let err = allowed_tool_names_for_choice(&executable, Some(&choice)) + .expect_err("missing specific tool should fail before provider request"); + + assert!(matches!( + err, + CompletionError::RequestError(err) + if err.to_string().contains("missing") + && err.to_string().contains("add") + )); + } + + #[test] + fn allowed_tool_names_specific_rejects_empty_names() { + let executable = tool_names(&["add"]); + let choice = ToolChoice::Specific { + function_names: vec![], + }; + + let err = allowed_tool_names_for_choice(&executable, Some(&choice)) + .expect_err("empty specific tool choice should fail before provider request"); + + assert!(matches!( + err, + CompletionError::RequestError(err) + if err.to_string().contains("requires at least one function name") + )); + } +} diff --git a/crates/rig-core/src/agent/prompt_request/mod.rs b/crates/rig-core/src/agent/prompt_request/mod.rs index 971317bce..cbcc66913 100644 --- a/crates/rig-core/src/agent/prompt_request/mod.rs +++ b/crates/rig-core/src/agent/prompt_request/mod.rs @@ -3,7 +3,7 @@ pub mod streaming; use super::{ Agent, - completion::{DynamicContextStore, build_completion_request}, + completion::{DynamicContextStore, build_prepared_completion_request}, }; use crate::{ OneOrMany, @@ -18,6 +18,7 @@ use futures::{StreamExt, stream}; use hooks::{HookAction, PromptHook, ToolCallHookAction}; use serde::{Deserialize, Serialize}; use std::{ + collections::BTreeSet, future::IntoFuture, marker::PhantomData, sync::{ @@ -408,6 +409,24 @@ fn build_full_history( input.iter().cloned().chain(new_messages).collect() } +pub(crate) fn validate_tool_call_name( + tool_name: &str, + executable_tool_names: &BTreeSet, + allowed_tool_names: &BTreeSet, + chat_history: Vec, +) -> Result<(), PromptError> { + if allowed_tool_names.contains(tool_name) { + return Ok(()); + } + + Err(PromptError::UnknownToolCall { + tool_name: tool_name.to_owned(), + available_tools: executable_tool_names.iter().cloned().collect(), + allowed_tools: allowed_tool_names.iter().cloned().collect(), + chat_history: Box::new(chat_history), + }) +} + fn is_empty_assistant_turn(choice: &OneOrMany) -> bool { choice.len() == 1 && matches!( @@ -558,7 +577,7 @@ where let history_for_request = build_history_for_request(chat_history.as_deref(), history_for_current_turn); - let resp = build_completion_request( + let prepared_request = build_prepared_completion_request( &self.model, prompt.clone(), &history_for_request, @@ -572,10 +591,15 @@ where &self.dynamic_context, self.output_schema.as_ref(), ) - .await? - .send() - .instrument(chat_span.clone()) .await?; + let executable_tool_names = prepared_request.executable_tool_names.clone(); + let allowed_tool_names = prepared_request.allowed_tool_names.clone(); + + let resp = prepared_request + .builder + .send() + .instrument(chat_span.clone()) + .await?; completion_calls.push(CompletionCall::from_reported_usage( completion_call_index, @@ -584,16 +608,6 @@ where completion_call_index += 1; usage += resp.usage; - if let Some(ref hook) = self.hook - && let HookAction::Terminate { reason } = - hook.on_completion_response(&prompt, &resp).await - { - return Err(PromptError::prompt_cancelled( - build_full_history(chat_history.as_deref(), new_messages), - reason, - )); - } - let tool_calls = resp .choice .iter() @@ -603,11 +617,45 @@ where // Some providers normalize textless terminal turns into a single empty text item // because the generic completion response cannot represent an empty choice. Treat // that sentinel as "no assistant output" so it does not pollute returned history. - if !is_empty_assistant_turn(&resp.choice) { - new_messages.push(Message::Assistant { + let assistant_response_message = + (!is_empty_assistant_turn(&resp.choice)).then(|| Message::Assistant { id: resp.message_id.clone(), content: resp.choice.clone(), }); + + if !tool_calls.is_empty() { + let mut diagnostic_messages = new_messages.clone(); + if let Some(message) = assistant_response_message.clone() { + diagnostic_messages.push(message); + } + + for choice in &tool_calls { + if let AssistantContent::ToolCall(tool_call) = choice { + validate_tool_call_name( + &tool_call.function.name, + &executable_tool_names, + &allowed_tool_names, + build_full_history( + chat_history.as_deref(), + diagnostic_messages.clone(), + ), + )?; + } + } + } + + if let Some(ref hook) = self.hook + && let HookAction::Terminate { reason } = + hook.on_completion_response(&prompt, &resp).await + { + return Err(PromptError::prompt_cancelled( + build_full_history(chat_history.as_deref(), new_messages), + reason, + )); + } + + if let Some(message) = assistant_response_message { + new_messages.push(message); } if tool_calls.is_empty() { @@ -1014,14 +1062,18 @@ where mod tests { use super::{CompletionCall, PromptResponse, TypedPromptResponse}; use crate::{ - agent::AgentBuilder, + agent::{ + AgentBuilder, + prompt_request::hooks::{HookAction, PromptHook, ToolCallHookAction}, + }, completion::{ - AssistantContent, CompletionError, CompletionRequest, Message, Prompt, PromptError, - TypedPrompt, Usage, + AssistantContent, CompletionError, CompletionModel, CompletionRequest, Message, Prompt, + PromptError, TypedPrompt, Usage, }, - message::{Text, UserContent}, + message::{Text, ToolChoice, UserContent}, test_utils::{ - AppendFailingMemory, CountingMemory, FailingMemory, MockCompletionModel, MockTurn, + AppendFailingMemory, CountingMemory, FailingMemory, MockAddTool, MockCompletionModel, + MockSubtractTool, MockTurn, }, }; use schemars::JsonSchema; @@ -1043,6 +1095,31 @@ mod tests { value: String, } + #[derive(Clone)] + struct PanicOnUnknownToolHook; + + impl PromptHook for PanicOnUnknownToolHook { + fn on_completion_response( + &self, + _prompt: &Message, + _response: &crate::completion::CompletionResponse< + ::Response, + >, + ) -> impl std::future::Future + Send { + async { panic!("unknown tool response should fail before response hooks run") } + } + + fn on_tool_call( + &self, + _tool_name: &str, + _tool_call_id: Option, + _internal_call_id: &str, + _args: &str, + ) -> impl std::future::Future + Send { + async { panic!("unknown tool call should fail before tool hooks run") } + } + } + fn usage(input_tokens: u64, output_tokens: u64) -> Usage { Usage { input_tokens, @@ -1220,6 +1297,182 @@ mod tests { )); } + fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool { + history.iter().any(|message| { + matches!( + message, + Message::Assistant { content, .. } + if content.iter().any(|item| matches!( + item, + AssistantContent::ToolCall(tool_call) + if tool_call.function.name == tool_name + )) + ) + }) + } + + #[tokio::test] + async fn unknown_tool_call_fails_before_non_streaming_second_request() { + let model = MockCompletionModel::new([ + MockTurn::tool_call("tool_call_1", "default_api", json!({"x": 1, "y": 2})), + MockTurn::text("should not be requested"), + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let err = agent + .prompt("use the tool") + .with_hook(PanicOnUnknownToolHook) + .max_turns(3) + .await + .expect_err("unknown model-emitted tool should fail"); + + match err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn disallowed_specific_tool_call_fails_before_non_streaming_second_request() { + let model = MockCompletionModel::new([ + MockTurn::tool_call("tool_call_1", "subtract", json!({"x": 3, "y": 1})), + MockTurn::text("should not be requested"), + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool(MockSubtractTool) + .tool_choice(ToolChoice::Specific { + function_names: vec!["add".to_string()], + }) + .build(); + + let err = agent + .prompt("use the allowed tool") + .with_hook(PanicOnUnknownToolHook) + .max_turns(3) + .await + .expect_err("disallowed model-emitted tool should fail"); + + match err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "subtract"); + assert_eq!( + available_tools, + vec!["add".to_string(), "subtract".to_string()] + ); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "subtract")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_choice_none_rejects_non_streaming_tool_call() { + let model = MockCompletionModel::new([ + MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})), + MockTurn::text("should not be requested"), + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::None) + .build(); + + let err = agent + .prompt("do not use tools") + .with_hook(PanicOnUnknownToolHook) + .max_turns(3) + .await + .expect_err("ToolChoice::None should reject returned tool calls"); + + match err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "add"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert!(allowed_tools.is_empty()); + assert!(history_contains_tool_call(&chat_history, "add")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn invalid_specific_tool_choice_fails_before_non_streaming_provider_request() { + let model = MockCompletionModel::text("should not be requested"); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::Specific { + function_names: vec!["missing".to_string()], + }) + .build(); + + let err = agent + .prompt("use the missing tool") + .await + .expect_err("invalid ToolChoice::Specific should fail before provider request"); + + match err { + PromptError::CompletionError(CompletionError::RequestError(err)) => { + let msg = err.to_string(); + assert!(msg.contains("missing"), "got: {msg}"); + assert!(msg.contains("add"), "got: {msg}"); + } + other => panic!("expected CompletionError::RequestError, got {other:?}"), + } + assert_eq!(recorded.request_count(), 0); + } + + #[tokio::test] + async fn allowed_specific_tool_call_executes_normally() { + let model = MockCompletionModel::new([ + MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})), + MockTurn::text("done"), + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::Specific { + function_names: vec!["add".to_string()], + }) + .build(); + + let response = agent + .prompt("use the allowed tool") + .max_turns(3) + .await + .expect("allowed specific tool should execute"); + + assert_eq!(response, "done"); + assert_eq!(recorded.request_count(), 2); + } + #[tokio::test] async fn prompt_request_stops_cleanly_on_empty_terminal_turn() { let first_call_usage = Usage { @@ -1241,12 +1494,12 @@ mod tests { reasoning_tokens: 0, }; let model = MockCompletionModel::new([ - MockTurn::tool_call("tool_call_1", "missing_tool", json!({"input": "value"})) + MockTurn::tool_call("tool_call_1", "add", json!({"x": 1, "y": 2})) .with_call_id("call_1") .with_usage(first_call_usage), MockTurn::text("").with_usage(second_call_usage), ]); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let response = agent .prompt("do tool work") diff --git a/crates/rig-core/src/agent/prompt_request/streaming.rs b/crates/rig-core/src/agent/prompt_request/streaming.rs index ad4bc3c39..db77a780f 100644 --- a/crates/rig-core/src/agent/prompt_request/streaming.rs +++ b/crates/rig-core/src/agent/prompt_request/streaming.rs @@ -1,18 +1,21 @@ use crate::{ OneOrMany, - agent::completion::{DynamicContextStore, build_completion_request}, - agent::prompt_request::{HookAction, hooks::PromptHook}, + agent::completion::{DynamicContextStore, build_prepared_completion_request}, + agent::prompt_request::{HookAction, hooks::PromptHook, validate_tool_call_name}, completion::{Document, GetTokenUsage}, json_utils, memory::ConversationMemory, - message::{AssistantContent, ToolChoice, ToolResult, ToolResultContent, UserContent}, - streaming::{StreamedAssistantContent, StreamedUserContent}, + message::{ + AssistantContent, ToolCall, ToolChoice, ToolFunction, ToolResult, ToolResultContent, + UserContent, + }, + streaming::{StreamedAssistantContent, StreamedUserContent, ToolCallDeltaContent}, tool::server::ToolServerHandle, wasm_compat::{WasmBoxedFuture, WasmCompatSend}, }; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; -use std::{pin::Pin, sync::Arc}; +use std::{collections::HashMap, pin::Pin, sync::Arc}; use tracing::info_span; use tracing_futures::Instrument; @@ -196,6 +199,52 @@ fn build_full_history( input.iter().cloned().chain(new_messages).collect() } +fn build_tool_call_validation_history( + chat_history: Option<&[Message]>, + new_messages: &[Message], + assistant_message_id: &Option, + final_turn_content: Option<&OneOrMany>, + text_delta_response: Option<&str>, + pending_tool_calls: &[(ToolCall, String)], + current_tool_call: Option, +) -> Vec { + let mut messages = new_messages.to_vec(); + + if let Some(final_turn_content) = final_turn_content + && !is_empty_assistant_choice(final_turn_content) + { + messages.push(Message::Assistant { + id: assistant_message_id.clone(), + content: final_turn_content.clone(), + }); + return build_full_history(chat_history, messages); + } + + let mut content_items = Vec::new(); + if let Some(text) = text_delta_response + && !text.is_empty() + { + content_items.push(AssistantContent::text(text.to_string())); + } + content_items.extend( + pending_tool_calls + .iter() + .map(|(tool_call, _)| AssistantContent::ToolCall(tool_call.clone())), + ); + if let Some(tool_call) = current_tool_call { + content_items.push(AssistantContent::ToolCall(tool_call)); + } + + if let Some(content) = OneOrMany::from_iter_optional(content_items) { + messages.push(Message::Assistant { + id: assistant_message_id.clone(), + content, + }); + } + + build_full_history(chat_history, messages) +} + /// Combine input history with new messages for building completion requests. fn build_history_for_request( chat_history: Option<&[Message]>, @@ -263,6 +312,26 @@ fn is_empty_assistant_choice(choice: &OneOrMany) -> bool { ) } +#[derive(Default)] +struct ToolCallDeltaState { + name_validated: bool, + buffered_arguments: Vec, +} + +fn pending_tool_call_delta_error( + states: &HashMap<(String, String), ToolCallDeltaState>, +) -> Option { + states + .iter() + .find(|(_, state)| !state.name_validated && !state.buffered_arguments.is_empty()) + .map(|((id, internal_call_id), state)| { + CompletionError::ResponseError(format!( + "streamed tool call arguments received before a validated tool name for id `{id}` and internal_call_id `{internal_call_id}` ({} buffered argument delta(s))", + state.buffered_arguments.len() + )) + }) +} + #[derive(Debug, thiserror::Error)] pub enum StreamingError { #[error("CompletionError: {0}")] @@ -612,31 +681,35 @@ where gen_ai.output.messages = tracing::field::Empty, ); - let mut stream = tracing::Instrument::instrument( - build_completion_request( - &model, - current_prompt.clone(), - &history_snapshot, - preamble.as_deref(), - &static_context, - temperature, - max_tokens, - additional_params.as_ref(), - tool_choice.as_ref(), - &tool_server_handle, - &dynamic_context, - output_schema.as_ref(), - ) - .await? - .stream(), chat_stream_span + let prepared_request = build_prepared_completion_request( + &model, + current_prompt.clone(), + &history_snapshot, + preamble.as_deref(), + &static_context, + temperature, + max_tokens, + additional_params.as_ref(), + tool_choice.as_ref(), + &tool_server_handle, + &dynamic_context, + output_schema.as_ref(), ) - .await?; + let executable_tool_names = prepared_request.executable_tool_names.clone(); + let allowed_tool_names = prepared_request.allowed_tool_names.clone(); + + let mut stream = prepared_request + .builder + .stream() + .instrument(chat_stream_span) + .await?; let call_index = completion_call_index; completion_call_index += 1; let mut current_call_usage = None; let mut completion_call_emitted = false; + let mut pending_tool_calls: Vec<(ToolCall, String)> = vec![]; let mut tool_calls = vec![]; let mut tool_results = vec![]; let mut accumulated_reasoning: Vec = vec![]; @@ -644,6 +717,8 @@ where // signatures (e.g. Anthropic) never see unsigned blocks. let mut pending_reasoning_delta_text = String::new(); let mut pending_reasoning_delta_id: Option = None; + let mut tool_call_delta_states: HashMap<(String, String), ToolCallDeltaState> = + HashMap::new(); let mut saw_tool_call_this_turn = false; while let Some(content) = stream.next().await { @@ -663,112 +738,115 @@ where yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text))); }, Ok(StreamedAssistantContent::ToolCall { tool_call, internal_call_id }) => { - let tool_span = info_span!( - parent: tracing::Span::current(), - "execute_tool", - gen_ai.operation.name = "execute_tool", - gen_ai.tool.type = "function", - gen_ai.tool.name = tracing::field::Empty, - gen_ai.tool.call.id = tracing::field::Empty, - gen_ai.tool.call.arguments = tracing::field::Empty, - gen_ai.tool.call.result = tracing::field::Empty + let diagnostic_history = build_tool_call_validation_history( + chat_history.as_deref(), + &new_messages, + &stream.message_id, + None, + saw_text_this_turn.then_some(text_delta_response.as_str()), + &pending_tool_calls, + Some(tool_call.clone()), ); - yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() })); - - let tc_result = async { - let tool_span = tracing::Span::current(); - let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments); - if let Some(ref hook) = self.hook { - let action = hook - .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args) - .await; + if let Err(err) = validate_tool_call_name( + &tool_call.function.name, + &executable_tool_names, + &allowed_tool_names, + diagnostic_history, + ) { + yield Err(Box::new(err).into()); + break 'outer; + } - if let ToolCallHookAction::Terminate { reason } = action { - return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); + pending_tool_calls.push((tool_call, internal_call_id)); + }, + Ok(StreamedAssistantContent::ToolCallDelta { + id, + internal_call_id, + content, + }) => { + let key = (id.clone(), internal_call_id.clone()); + let mut deltas_to_emit = Vec::new(); + + match content { + ToolCallDeltaContent::Name(name) => { + let diagnostic_tool_call = ToolCall::new( + id.clone(), + ToolFunction::new(name.clone(), serde_json::Value::Null), + ); + let diagnostic_history = build_tool_call_validation_history( + chat_history.as_deref(), + &new_messages, + &stream.message_id, + None, + None, + &[], + Some(diagnostic_tool_call), + ); + + if let Err(err) = validate_tool_call_name( + &name, + &executable_tool_names, + &allowed_tool_names, + diagnostic_history, + ) { + yield Err(Box::new(err).into()); + break 'outer; } - if let ToolCallHookAction::Skip { reason } = action { - // Tool execution rejected, return rejection message as tool result - tracing::info!( - tool_name = tool_call.function.name.as_str(), - reason = reason, - "Tool call rejected" - ); - let tool_call_msg = AssistantContent::ToolCall(tool_call.clone()); - tool_calls.push(tool_call_msg); - tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone())); - saw_tool_call_this_turn = true; - return Ok(reason); - } + let state = + tool_call_delta_states.entry(key.clone()).or_default(); + state.name_validated = true; + let buffered_arguments = + std::mem::take(&mut state.buffered_arguments); + + deltas_to_emit.push(ToolCallDeltaContent::Name(name)); + deltas_to_emit.extend( + buffered_arguments + .into_iter() + .map(ToolCallDeltaContent::Delta), + ); } - - tool_span.record("gen_ai.tool.name", &tool_call.function.name); - tool_span.record("gen_ai.tool.call.arguments", &tool_args); - - let tool_result = match - tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await { - Ok(thing) => thing, - Err(e) => { - tracing::warn!("Error while calling tool: {e}"); - e.to_string() - } - }; - - tool_span.record("gen_ai.tool.call.result", &tool_result); - - if let Some(ref hook) = self.hook && - let HookAction::Terminate { reason } = - hook.on_tool_result( - &tool_call.function.name, - tool_call.call_id.clone(), - &internal_call_id, - &tool_args, - &tool_result.to_string() - ) - .await { - return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); + ToolCallDeltaContent::Delta(arguments) => { + let state = + tool_call_delta_states.entry(key.clone()).or_default(); + if state.name_validated { + deltas_to_emit.push(ToolCallDeltaContent::Delta(arguments)); + } else { + state.buffered_arguments.push(arguments); } - - let tool_call_msg = AssistantContent::ToolCall(tool_call.clone()); - - tool_calls.push(tool_call_msg); - tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone())); - - saw_tool_call_this_turn = true; - Ok(tool_result) - }.instrument(tool_span).await; - - match tc_result { - Ok(text) => { - let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) }; - yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id })); - } - Err(e) => { - yield Err(e); - break 'outer; } } - }, - Ok(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content }) => { - if let Some(ref hook) = self.hook { - let (name, delta) = match &content { - rig::streaming::ToolCallDeltaContent::Name(n) => { - (Some(n.as_str()), "") - } - rig::streaming::ToolCallDeltaContent::Delta(d) => { - (None, d.as_str()) - } - }; - if let HookAction::Terminate { reason } = hook.on_tool_call_delta(&id, &internal_call_id, name, delta) - .await { - yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); - break 'outer; + for content in deltas_to_emit { + if let Some(ref hook) = self.hook { + let (name, delta) = match &content { + ToolCallDeltaContent::Name(n) => (Some(n.as_str()), ""), + ToolCallDeltaContent::Delta(d) => (None, d.as_str()), + }; + + if let HookAction::Terminate { reason } = hook + .on_tool_call_delta( + &id, + &internal_call_id, + name, + delta, + ) + .await + { + yield Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); + break 'outer; + } } - } - yield Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::ToolCallDelta { id, internal_call_id, content })); + yield Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { + id: id.clone(), + internal_call_id: internal_call_id.clone(), + content, + }, + )); + } } Ok(StreamedAssistantContent::Reasoning(reasoning)) => { // Accumulate reasoning for inclusion in chat history with tool calls. @@ -788,6 +866,13 @@ where yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ReasoningDelta { reasoning, id })); }, Ok(StreamedAssistantContent::Final(final_resp)) => { + if let Some(err) = + pending_tool_call_delta_error(&tool_call_delta_states) + { + yield Err(err.into()); + break 'outer; + } + if let Some(usage) = final_resp.token_usage() { current_call_usage = reported_usage(usage); } @@ -817,6 +902,11 @@ where } } + if let Some(err) = pending_tool_call_delta_error(&tool_call_delta_states) { + yield Err(err.into()); + break 'outer; + } + if !completion_call_emitted { let completion_call = CompletionCall::new(call_index, current_call_usage); completion_calls.push(completion_call); @@ -838,6 +928,119 @@ where let turn_text_response = assistant_text_from_choice(&final_turn_content); tracing::Span::current().record("gen_ai.completion", &turn_text_response); + if !pending_tool_calls.is_empty() { + let diagnostic_history = build_tool_call_validation_history( + chat_history.as_deref(), + &new_messages, + &stream.message_id, + Some(&final_turn_content), + None, + &pending_tool_calls, + None, + ); + + for (tool_call, _) in &pending_tool_calls { + if let Err(err) = validate_tool_call_name( + &tool_call.function.name, + &executable_tool_names, + &allowed_tool_names, + diagnostic_history.clone(), + ) { + yield Err(Box::new(err).into()); + break 'outer; + } + } + + for (tool_call, internal_call_id) in pending_tool_calls { + let tool_span = info_span!( + parent: tracing::Span::current(), + "execute_tool", + gen_ai.operation.name = "execute_tool", + gen_ai.tool.type = "function", + gen_ai.tool.name = tracing::field::Empty, + gen_ai.tool.call.id = tracing::field::Empty, + gen_ai.tool.call.arguments = tracing::field::Empty, + gen_ai.tool.call.result = tracing::field::Empty + ); + + yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::ToolCall { tool_call: tool_call.clone(), internal_call_id: internal_call_id.clone() })); + + let tc_result = async { + let tool_span = tracing::Span::current(); + let tool_args = json_utils::value_to_json_string(&tool_call.function.arguments); + if let Some(ref hook) = self.hook { + let action = hook + .on_tool_call(&tool_call.function.name, tool_call.call_id.clone(), &internal_call_id, &tool_args) + .await; + + if let ToolCallHookAction::Terminate { reason } = action { + return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); + } + + if let ToolCallHookAction::Skip { reason } = action { + // Tool execution rejected, return rejection message as tool result + tracing::info!( + tool_name = tool_call.function.name.as_str(), + reason = reason, + "Tool call rejected" + ); + let tool_call_msg = AssistantContent::ToolCall(tool_call.clone()); + tool_calls.push(tool_call_msg); + tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), reason.clone())); + saw_tool_call_this_turn = true; + return Ok(reason); + } + } + + tool_span.record("gen_ai.tool.name", &tool_call.function.name); + tool_span.record("gen_ai.tool.call.arguments", &tool_args); + + let tool_result = match + tool_server_handle.call_tool(&tool_call.function.name, &tool_args).await { + Ok(thing) => thing, + Err(e) => { + tracing::warn!("Error while calling tool: {e}"); + e.to_string() + } + }; + + tool_span.record("gen_ai.tool.call.result", &tool_result); + + if let Some(ref hook) = self.hook && + let HookAction::Terminate { reason } = + hook.on_tool_result( + &tool_call.function.name, + tool_call.call_id.clone(), + &internal_call_id, + &tool_args, + &tool_result.to_string() + ) + .await { + return Err(cancelled_prompt_error(chat_history.as_deref(), new_messages.clone(), reason).await); + } + + let tool_call_msg = AssistantContent::ToolCall(tool_call.clone()); + + tool_calls.push(tool_call_msg); + tool_results.push((tool_call.id.clone(), tool_call.call_id.clone(), tool_result.clone())); + + saw_tool_call_this_turn = true; + Ok(tool_result) + }.instrument(tool_span).await; + + match tc_result { + Ok(text) => { + let tr = ToolResult { id: tool_call.id, call_id: tool_call.call_id, content: ToolResultContent::from_tool_output(text) }; + yield Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult{ tool_result: tr, internal_call_id })); + } + Err(e) => { + yield Err(e); + break 'outer; + } + } + } + } + // Add text, reasoning, and tool calls to chat history. // OpenAI Responses API requires reasoning items to precede function_call items. if !tool_calls.is_empty() || !accumulated_reasoning.is_empty() { @@ -981,19 +1184,23 @@ pub async fn stream_to_stdout( mod tests { use super::*; use crate::agent::AgentBuilder; + use crate::agent::prompt_request::hooks::{PromptHook, ToolCallHookAction}; use crate::client::ProviderClient; use crate::client::completion::CompletionClient; - use crate::completion::{CompletionRequest, Usage}; + use crate::completion::{CompletionRequest, PromptError, ToolDefinition, Usage}; use crate::message::{ AssistantContent, DocumentSourceKind, ImageMediaType, Message, ReasoningContent, - ToolResultContent, UserContent, + ToolChoice, ToolResultContent, UserContent, }; use crate::providers::anthropic; use crate::streaming::{StreamingPrompt, ToolCallDeltaContent}; use crate::test_utils::{ - AppendFailingMemory, FailingMemory, MockCompletionModel, MockResponse, MockStreamEvent, + AppendFailingMemory, FailingMemory, MockAddTool, MockCompletionModel, MockResponse, + MockStreamEvent, MockSubtractTool, MockToolError, }; + use crate::tool::Tool; use futures::StreamExt; + use serde::Deserialize; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -1214,13 +1421,129 @@ mod tests { Ok(()) } + fn history_contains_tool_call(history: &[Message], tool_name: &str) -> bool { + history.iter().any(|message| { + matches!( + message, + Message::Assistant { content, .. } + if content.iter().any(|item| matches!( + item, + AssistantContent::ToolCall(tool_call) + if tool_call.function.name == tool_name + )) + ) + }) + } + + #[derive(Clone)] + struct PanicOnUnknownToolHook; + + impl PromptHook for PanicOnUnknownToolHook { + fn on_tool_call_delta( + &self, + _tool_call_id: &str, + _internal_call_id: &str, + _tool_name: Option<&str>, + _tool_call_delta: &str, + ) -> impl std::future::Future + Send { + async { panic!("unknown tool call delta should fail before delta hooks run") } + } + + fn on_tool_call( + &self, + _tool_name: &str, + _tool_call_id: Option, + _internal_call_id: &str, + _args: &str, + ) -> impl std::future::Future + Send { + async { panic!("unknown tool call should fail before tool hooks run") } + } + + fn on_stream_completion_response_finish( + &self, + _prompt: &Message, + _response: &MockResponse, + ) -> impl std::future::Future + Send { + async { panic!("unknown tool call should fail before stream finish hooks run") } + } + } + + #[derive(Clone)] + struct CountingAddTool { + calls: Arc, + } + + #[derive(Clone)] + struct CountingSubtractTool { + calls: Arc, + } + + #[derive(Deserialize)] + struct CountingOperationArgs { + x: i32, + y: i32, + } + + fn arithmetic_tool_definition(name: &str, description: &str) -> ToolDefinition { + ToolDefinition { + name: name.to_string(), + description: description.to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first operand" + }, + "y": { + "type": "number", + "description": "The second operand" + } + }, + "required": ["x", "y"], + }), + } + } + + impl Tool for CountingAddTool { + const NAME: &'static str = "add"; + type Error = MockToolError; + type Args = CountingOperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + arithmetic_tool_definition(Self::NAME, "Add x and y together") + } + + async fn call(&self, args: Self::Args) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(args.x + args.y) + } + } + + impl Tool for CountingSubtractTool { + const NAME: &'static str = "subtract"; + type Error = MockToolError; + type Args = CountingOperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + arithmetic_tool_definition(Self::NAME, "Subtract y from x") + } + + async fn call(&self, args: Self::Args) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(args.x - args.y) + } + } + fn streaming_tool_then_text_model() -> MockCompletionModel { MockCompletionModel::from_stream_turns([ vec![ MockStreamEvent::tool_call( "tool_call_1", - "missing_tool", - serde_json::json!({"input": "value"}), + "add", + serde_json::json!({"x": 1, "y": 2}), ) .with_call_id("call_1"), MockStreamEvent::final_response_with_total_tokens(4), @@ -1366,8 +1689,8 @@ mod tests { MockStreamEvent::text("I need a tool. "), MockStreamEvent::tool_call( "tool_call_1", - "missing_tool", - serde_json::json!({"input": "value"}), + "add", + serde_json::json!({"x": 1, "y": 2}), ) .with_call_id("call_1"), MockStreamEvent::final_response_with_total_tokens(4), @@ -1491,7 +1814,7 @@ mod tests { async fn stream_prompt_continues_after_tool_call_turn() { let model = streaming_tool_then_text_model(); let recorded = model.clone(); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let empty_history: &[Message] = &[]; let mut stream = agent @@ -1554,98 +1877,1044 @@ mod tests { } #[tokio::test] - async fn stream_prompt_emits_tool_call_deltas_without_hook() { - let model = MockCompletionModel::from_stream_turns([[ - MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "calculator"), - MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), - MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"), - MockStreamEvent::final_response_with_total_tokens(3), - ]]); - let agent = AgentBuilder::new(model).build(); + async fn unknown_tool_call_fails_before_streaming_second_request() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "default_api", + serde_json::json!({"x": 1, "y": 2}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); - let mut stream = agent.stream_prompt("stream a tool call").await; - let mut deltas = Vec::new(); + let mut stream = agent + .stream_prompt("use the tool") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_tool_call = false; + let mut error = None; while let Some(item) = stream.next().await { match item { Ok(MultiTurnStreamItem::StreamAssistantItem( - StreamedAssistantContent::ToolCallDelta { - id, - internal_call_id, - content, - }, + StreamedAssistantContent::ToolCall { .. }, )) => { - deltas.push((id, internal_call_id, content)); + saw_tool_call = true; } - Ok(MultiTurnStreamItem::FinalResponse(_)) => break, Ok(_) => {} - Err(err) => panic!("unexpected streaming error: {err:?}"), + Err(err) => { + error = Some(err); + break; + } } } - assert_eq!( - deltas, - vec![ - ( - "tool_1".to_string(), - "internal_1".to_string(), - ToolCallDeltaContent::Name("calculator".to_string()) - ), - ( - "tool_1".to_string(), - "internal_1".to_string(), - ToolCallDeltaContent::Delta("{\"x\":".to_string()) - ), - ( - "tool_1".to_string(), - "internal_1".to_string(), - ToolCallDeltaContent::Delta("1}".to_string()) - ), - ] - ); + assert!(!saw_tool_call); + let error = error.expect("unknown model-emitted tool should fail"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); } #[tokio::test] - async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() { - let model = MockCompletionModel::from_stream_turns([[ - MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "calculator"), - MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), - MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"), - MockStreamEvent::final_response_with_total_tokens(3), - ]]); - let hook = RecordingToolCallDeltaHook::default(); - let agent = AgentBuilder::new(model).build(); + async fn completed_unknown_tool_call_after_text_fails_before_finish_hook_or_later_emit() { + let add_calls = Arc::new(AtomicU32::new(0)); + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::text("thinking "), + MockStreamEvent::tool_call( + "tool_call_1", + "default_api", + serde_json::json!({"x": 1, "y": 2}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(CountingAddTool { + calls: add_calls.clone(), + }) + .build(); let mut stream = agent - .stream_prompt("stream a tool call") - .with_hook(hook.clone()) + .stream_prompt("use the tool") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) .await; - let mut stream_deltas = Vec::new(); + let mut saw_text = false; + let mut saw_completion_call = false; + let mut saw_final_response = false; + let mut saw_tool_call = false; + let mut saw_tool_result = false; + let mut error = None; while let Some(item) = stream.next().await { match item { + Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(_))) => { + saw_text = true; + } + Ok(MultiTurnStreamItem::CompletionCall(_)) => { + saw_completion_call = true; + } + Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Final( + _, + ))) + | Ok(MultiTurnStreamItem::FinalResponse(_)) => { + saw_final_response = true; + } Ok(MultiTurnStreamItem::StreamAssistantItem( - StreamedAssistantContent::ToolCallDelta { - id, - internal_call_id, - content, - }, + StreamedAssistantContent::ToolCall { .. }, )) => { - stream_deltas.push((id, internal_call_id, content)); + saw_tool_call = true; + } + Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult { + .. + })) => { + saw_tool_result = true; } - Ok(MultiTurnStreamItem::FinalResponse(_)) => break, Ok(_) => {} - Err(err) => panic!("unexpected streaming error: {err:?}"), + Err(err) => { + error = Some(err); + break; + } } } - assert_eq!( - hook.observed(), - vec![ - ( + assert!(saw_text); + assert!(!saw_completion_call); + assert!(!saw_final_response); + assert!(!saw_tool_call); + assert!(!saw_tool_result); + assert_eq!(add_calls.load(Ordering::SeqCst), 0); + let error = error.expect("completed unknown tool call should fail immediately"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn mixed_streaming_tool_calls_fail_before_any_tool_execution() { + let add_calls = Arc::new(AtomicU32::new(0)); + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "add", + serde_json::json!({"x": 1, "y": 2}), + ) + .with_call_id("call_1"), + MockStreamEvent::tool_call( + "tool_call_2", + "default_api", + serde_json::json!({"x": 3, "y": 4}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(CountingAddTool { + calls: add_calls.clone(), + }) + .build(); + + let mut stream = agent + .stream_prompt("use tools") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_completion_call = false; + let mut saw_tool_call = false; + let mut saw_tool_result = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::CompletionCall(_)) => { + saw_completion_call = true; + } + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { .. }, + )) => { + saw_tool_call = true; + } + Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult { + .. + })) => { + saw_tool_result = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_completion_call); + assert!(!saw_tool_call); + assert!(!saw_tool_result); + assert_eq!(add_calls.load(Ordering::SeqCst), 0); + let error = error.expect("mixed unknown streamed tool call should fail"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn multiple_valid_streaming_tool_calls_execute_after_batch_validation() { + let add_calls = Arc::new(AtomicU32::new(0)); + let subtract_calls = Arc::new(AtomicU32::new(0)); + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "add", + serde_json::json!({"x": 1, "y": 2}), + ) + .with_call_id("call_1"), + MockStreamEvent::tool_call( + "tool_call_2", + "subtract", + serde_json::json!({"x": 8, "y": 3}), + ) + .with_call_id("call_2"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("done"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(CountingAddTool { + calls: add_calls.clone(), + }) + .tool(CountingSubtractTool { + calls: subtract_calls.clone(), + }) + .build(); + + let mut stream = agent.stream_prompt("use tools").multi_turn(3).await; + let mut tool_call_names = Vec::new(); + let mut tool_result_ids = Vec::new(); + let mut final_response_text = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { tool_call, .. }, + )) => { + tool_call_names.push(tool_call.function.name); + } + Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult { + tool_result, + .. + })) => { + tool_result_ids.push(tool_result.id); + } + Ok(MultiTurnStreamItem::FinalResponse(response)) => { + final_response_text = Some(response.response().to_owned()); + break; + } + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!( + tool_call_names, + vec!["add".to_string(), "subtract".to_string()] + ); + assert_eq!( + tool_result_ids, + vec!["tool_call_1".to_string(), "tool_call_2".to_string()] + ); + assert_eq!(add_calls.load(Ordering::SeqCst), 1); + assert_eq!(subtract_calls.load(Ordering::SeqCst), 1); + assert_eq!(final_response_text.as_deref(), Some("done")); + assert_eq!(recorded.request_count(), 2); + } + + #[tokio::test] + async fn disallowed_specific_tool_call_fails_before_streaming_second_request() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "subtract", + serde_json::json!({"x": 3, "y": 1}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool(MockSubtractTool) + .tool_choice(ToolChoice::Specific { + function_names: vec!["add".to_string()], + }) + .build(); + + let mut stream = agent + .stream_prompt("use the allowed tool") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_tool_call = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { .. }, + )) => { + saw_tool_call = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_tool_call); + let error = error.expect("disallowed model-emitted tool should fail"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "subtract"); + assert_eq!( + available_tools, + vec!["add".to_string(), "subtract".to_string()] + ); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "subtract")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn mixed_specific_tool_calls_fail_before_any_tool_execution() { + let add_calls = Arc::new(AtomicU32::new(0)); + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "add", + serde_json::json!({"x": 1, "y": 2}), + ), + MockStreamEvent::tool_call( + "tool_call_2", + "subtract", + serde_json::json!({"x": 3, "y": 1}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(CountingAddTool { + calls: add_calls.clone(), + }) + .tool(MockSubtractTool) + .tool_choice(ToolChoice::Specific { + function_names: vec!["add".to_string()], + }) + .build(); + + let mut stream = agent + .stream_prompt("use the allowed tool") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_tool_call = false; + let mut saw_tool_result = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { .. }, + )) => { + saw_tool_call = true; + } + Ok(MultiTurnStreamItem::StreamUserItem(StreamedUserContent::ToolResult { + .. + })) => { + saw_tool_result = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_tool_call); + assert!(!saw_tool_result); + assert_eq!(add_calls.load(Ordering::SeqCst), 0); + let error = error.expect("mixed disallowed streamed tool call should fail"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "subtract"); + assert_eq!( + available_tools, + vec!["add".to_string(), "subtract".to_string()] + ); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "subtract")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_choice_none_rejects_streaming_tool_call() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call( + "tool_call_1", + "add", + serde_json::json!({"x": 1, "y": 2}), + ), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::None) + .build(); + + let mut stream = agent + .stream_prompt("do not use tools") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_tool_call = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCall { .. }, + )) => { + saw_tool_call = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_tool_call); + let error = error.expect("ToolChoice::None should reject returned tool calls"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "add"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert!(allowed_tools.is_empty()); + assert!(history_contains_tool_call(&chat_history, "add")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_choice_none_rejects_streaming_tool_call_name_delta_before_hook_or_emit() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::None) + .build(); + + let mut stream = agent + .stream_prompt("do not use tools") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_delta = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { .. }, + )) => { + saw_delta = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_delta); + let error = error.expect("ToolChoice::None should reject returned tool-call deltas"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "add"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert!(allowed_tools.is_empty()); + assert!(history_contains_tool_call(&chat_history, "add")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn unknown_tool_call_name_delta_fails_before_streaming_delta_hook_or_emit() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent + .stream_prompt("stream a bad tool call") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_delta = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { .. }, + )) => { + saw_delta = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_delta); + let error = error.expect("unknown tool-call name delta should fail"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_call_args_delta_before_unknown_name_fails_before_hook_or_emit() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"), + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "default_api"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent + .stream_prompt("stream a bad tool call") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_delta = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { .. }, + )) => { + saw_delta = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_delta); + let error = error.expect("unknown tool-call name should reject buffered args"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "default_api"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert_eq!(allowed_tools, vec!["add".to_string()]); + assert!(history_contains_tool_call(&chat_history, "default_api")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_call_args_delta_before_valid_name_buffers_then_emits_in_safe_order() { + let model = MockCompletionModel::from_stream_turns([[ + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"), + MockStreamEvent::final_response_with_total_tokens(3), + ]]); + let hook = RecordingToolCallDeltaHook::default(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent + .stream_prompt("stream a tool call") + .with_hook(hook.clone()) + .await; + let mut stream_deltas = Vec::new(); + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { + id, + internal_call_id, + content, + }, + )) => { + stream_deltas.push((id, internal_call_id, content)); + } + Ok(MultiTurnStreamItem::FinalResponse(_)) => break, + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!( + hook.observed(), + vec![ + ( + "tool_1".to_string(), + "internal_1".to_string(), + Some("add".to_string()), + String::new() + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + None, + "{\"x\":".to_string() + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + None, + "1}".to_string() + ), + ] + ); + assert_eq!( + stream_deltas, + vec![ + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Name("add".to_string()) + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Delta("{\"x\":".to_string()) + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Delta("1}".to_string()) + ), + ] + ); + } + + #[tokio::test] + async fn tool_call_args_delta_without_name_errors_at_stream_end() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent + .stream_prompt("stream an incomplete tool call") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_delta = false; + let mut saw_completion_call = false; + let mut saw_final_response = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { .. }, + )) => { + saw_delta = true; + } + Ok(MultiTurnStreamItem::CompletionCall(_)) => { + saw_completion_call = true; + } + Ok(MultiTurnStreamItem::FinalResponse(_)) => { + saw_final_response = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_delta); + assert!(!saw_completion_call); + assert!(!saw_final_response); + let error = error.expect("unterminated tool-call args delta should fail"); + match error { + StreamingError::Completion(CompletionError::ResponseError(message)) => { + assert!( + message.contains("streamed tool call arguments"), + "{message}" + ); + assert!(message.contains("tool_1"), "{message}"); + assert!(message.contains("internal_1"), "{message}"); + } + other => panic!("expected completion response error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn tool_choice_none_buffers_args_then_rejects_name_without_emit() { + let model = MockCompletionModel::from_stream_turns([ + vec![ + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"), + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), + MockStreamEvent::final_response_with_total_tokens(4), + ], + vec![ + MockStreamEvent::text("should not be requested"), + MockStreamEvent::final_response_with_total_tokens(6), + ], + ]); + let recorded = model.clone(); + let agent = AgentBuilder::new(model) + .tool(MockAddTool) + .tool_choice(ToolChoice::None) + .build(); + + let mut stream = agent + .stream_prompt("do not use tools") + .with_hook(PanicOnUnknownToolHook) + .multi_turn(3) + .await; + let mut saw_delta = false; + let mut error = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { .. }, + )) => { + saw_delta = true; + } + Ok(_) => {} + Err(err) => { + error = Some(err); + break; + } + } + } + + assert!(!saw_delta); + let error = error.expect("ToolChoice::None should reject buffered tool-call deltas"); + match error { + StreamingError::Prompt(err) => match *err { + PromptError::UnknownToolCall { + tool_name, + available_tools, + allowed_tools, + chat_history, + } => { + assert_eq!(tool_name, "add"); + assert_eq!(available_tools, vec!["add".to_string()]); + assert!(allowed_tools.is_empty()); + assert!(history_contains_tool_call(&chat_history, "add")); + } + other => panic!("expected UnknownToolCall, got {other:?}"), + }, + other => panic!("expected prompt streaming error, got {other:?}"), + } + assert_eq!(recorded.request_count(), 1); + } + + #[tokio::test] + async fn stream_prompt_emits_tool_call_deltas_without_hook() { + let model = MockCompletionModel::from_stream_turns([[ + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"), + MockStreamEvent::final_response_with_total_tokens(3), + ]]); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent.stream_prompt("stream a tool call").await; + let mut deltas = Vec::new(); + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { + id, + internal_call_id, + content, + }, + )) => { + deltas.push((id, internal_call_id, content)); + } + Ok(MultiTurnStreamItem::FinalResponse(_)) => break, + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!( + deltas, + vec![ + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Name("add".to_string()) + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Delta("{\"x\":".to_string()) + ), + ( + "tool_1".to_string(), + "internal_1".to_string(), + ToolCallDeltaContent::Delta("1}".to_string()) + ), + ] + ); + } + + #[tokio::test] + async fn stream_prompt_emits_tool_call_deltas_after_hook_continue() { + let model = MockCompletionModel::from_stream_turns([[ + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), + MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "1}"), + MockStreamEvent::final_response_with_total_tokens(3), + ]]); + let hook = RecordingToolCallDeltaHook::default(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); + + let mut stream = agent + .stream_prompt("stream a tool call") + .with_hook(hook.clone()) + .await; + let mut stream_deltas = Vec::new(); + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem( + StreamedAssistantContent::ToolCallDelta { + id, + internal_call_id, + content, + }, + )) => { + stream_deltas.push((id, internal_call_id, content)); + } + Ok(MultiTurnStreamItem::FinalResponse(_)) => break, + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!( + hook.observed(), + vec![ + ( "tool_1".to_string(), "internal_1".to_string(), - Some("calculator".to_string()), + Some("add".to_string()), String::new() ), ( @@ -1668,7 +2937,7 @@ mod tests { ( "tool_1".to_string(), "internal_1".to_string(), - ToolCallDeltaContent::Name("calculator".to_string()) + ToolCallDeltaContent::Name("add".to_string()) ), ( "tool_1".to_string(), @@ -1687,12 +2956,12 @@ mod tests { #[tokio::test] async fn stream_prompt_tool_call_deltas_hook_termination_prevents_delta_emit() { let model = MockCompletionModel::from_stream_turns([[ - MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "calculator"), + MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "add"), MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":"), MockStreamEvent::final_response_with_total_tokens(3), ]]); let hook = TerminatingToolCallDeltaHook::default(); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let mut stream = agent .stream_prompt("stream a tool call") @@ -1725,7 +2994,7 @@ mod tests { vec![( "tool_1".to_string(), "internal_1".to_string(), - Some("calculator".to_string()), + Some("add".to_string()), String::new() )] ); @@ -1747,8 +3016,8 @@ mod tests { vec![ MockStreamEvent::tool_call( "tool_call_1", - "missing_tool", - serde_json::json!({"input": "value"}), + "add", + serde_json::json!({"x": 1, "y": 2}), ) .with_call_id("call_1"), MockStreamEvent::final_response(first_call_usage), @@ -1758,7 +3027,7 @@ mod tests { MockStreamEvent::final_response(second_call_usage), ], ]); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let empty_history: &[Message] = &[]; let mut stream = agent @@ -1859,8 +3128,8 @@ mod tests { vec![ MockStreamEvent::tool_call( "tool_call_1", - "missing_tool", - serde_json::json!({"input": "value"}), + "add", + serde_json::json!({"x": 1, "y": 2}), ) .with_call_id("call_1"), ], @@ -1869,7 +3138,7 @@ mod tests { MockStreamEvent::final_response(second_call_usage), ], ]); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let empty_history: &[Message] = &[]; let mut stream = agent @@ -2003,7 +3272,7 @@ mod tests { async fn tool_follow_up_history_preserves_structured_text_metadata() { let model = streaming_cited_text_then_tool_model(); let recorded = model.clone(); - let agent = AgentBuilder::new(model).build(); + let agent = AgentBuilder::new(model).tool(MockAddTool).build(); let empty_history: &[Message] = &[]; let mut stream = agent diff --git a/crates/rig-core/src/completion/request.rs b/crates/rig-core/src/completion/request.rs index 7a970df94..8e2b62656 100644 --- a/crates/rig-core/src/completion/request.rs +++ b/crates/rig-core/src/completion/request.rs @@ -116,6 +116,18 @@ pub enum PromptError { chat_history: Vec, reason: String, }, + + /// The model emitted a structured tool call for a tool Rig did not allow + /// for the current turn. + #[error( + "UnknownToolCall: model attempted to call unknown or disallowed tool `{tool_name}`. Available tools: {available_tools:?}. Allowed tools for this turn: {allowed_tools:?}" + )] + UnknownToolCall { + tool_name: String, + available_tools: Vec, + allowed_tools: Vec, + chat_history: Box>, + }, } /// Surface [`crate::memory::ConversationMemory`] failures through the existing diff --git a/crates/rig-core/src/providers/gemini/completion.rs b/crates/rig-core/src/providers/gemini/completion.rs index 2052ebf5f..e835522e8 100644 --- a/crates/rig-core/src/providers/gemini/completion.rs +++ b/crates/rig-core/src/providers/gemini/completion.rs @@ -37,7 +37,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest, GetTokenUsage}, }; use gemini_api_types::{ - Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, + Content, FinishReason, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part, PartKind, Role, Tool, }; use serde_json::{Map, Value}; @@ -409,6 +409,25 @@ impl TryFrom> for Tool { } } +pub(crate) fn function_call_finish_reason_error( + reason: &FinishReason, + finish_message: Option<&str>, +) -> Option { + match reason { + FinishReason::MalformedFunctionCall + | FinishReason::UnexpectedToolCall + | FinishReason::MissingThoughtSignature + | FinishReason::TooManyToolCalls + | FinishReason::MalformedResponse => { + let message = finish_message.unwrap_or("no finish message provided"); + Some(CompletionError::ResponseError(format!( + "Gemini stopped with finish_reason={reason:?}: {message}" + ))) + } + _ => None, + } +} + impl TryFrom for completion::CompletionResponse { type Error = CompletionError; @@ -417,6 +436,13 @@ impl TryFrom for completion::CompletionResponse::try_from( + response, + ) + .expect_err("tool protocol finish reason should fail"); + + assert!(matches!( + err, + CompletionError::ResponseError(message) + if message.contains(&reason_name) + && message.contains(finish_message) + )); + } + } + #[test] fn test_completion_response_usage_preserves_cached_and_reasoning_tokens() { let response = GenerateContentResponse { diff --git a/crates/rig-core/src/providers/gemini/streaming.rs b/crates/rig-core/src/providers/gemini/streaming.rs index 2d3e9e34b..505d624d3 100644 --- a/crates/rig-core/src/providers/gemini/streaming.rs +++ b/crates/rig-core/src/providers/gemini/streaming.rs @@ -8,7 +8,8 @@ use super::completion::gemini_api_types::{ ContentCandidate, FinishReason, ModalityTokenCount, Part, PartKind, TrafficType, }; use super::completion::{ - CompletionModel, create_request_body, resolve_request_model, streaming_endpoint, + CompletionModel, create_request_body, function_call_finish_reason_error, resolve_request_model, + streaming_endpoint, }; use crate::completion::message::ReasoningContent; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; @@ -75,6 +76,8 @@ pub struct StreamingCompletionResponse { #[serde(skip_serializing_if = "Option::is_none")] pub finish_reason: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub finish_message: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub model_version: Option, } @@ -84,6 +87,11 @@ impl GetTokenUsage for StreamingCompletionResponse { } } +fn tool_protocol_finish_reason_error(choice: &ContentCandidate) -> Option { + let reason = choice.finish_reason.as_ref()?; + function_call_finish_reason_error(reason, choice.finish_message.as_deref()) +} + impl CompletionModel where T: HttpClientExt + Clone + 'static, @@ -138,7 +146,9 @@ where let stream = stream! { let mut final_usage = None; let mut final_finish_reason: Option = None; + let mut final_finish_message: Option = None; let mut final_model_version: Option = None; + let mut stream_failed = false; while let Some(event_result) = event_source.next().await { match event_result { Ok(Event::Open) => { @@ -155,7 +165,9 @@ where Ok(d) => d, Err(error) => { tracing::error!(?error, message = message.data, "Failed to parse SSE message"); - continue; + stream_failed = true; + yield Err(CompletionError::JsonError(error)); + break; } }; @@ -183,6 +195,15 @@ where if let Some(fr) = &choice.finish_reason { final_finish_reason = Some(fr.clone()); } + if let Some(message) = &choice.finish_message { + final_finish_message = Some(message.clone()); + } + + if let Some(err) = tool_protocol_finish_reason_error(&choice) { + stream_failed = true; + yield Err(err); + break; + } let Some(content) = choice.content else { tracing::debug!(finish_reason = ?final_finish_reason, "Streaming candidate missing content"); @@ -260,6 +281,7 @@ where } Err(error) => { tracing::error!(?error, "SSE error"); + stream_failed = true; yield Err(CompletionError::ProviderError(error.to_string())); break; } @@ -269,11 +291,14 @@ where // Ensure event source is closed when stream ends event_source.close(); - yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage_metadata: final_usage.unwrap_or_default(), - finish_reason: final_finish_reason, - model_version: final_model_version, - })); + if !stream_failed { + yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage_metadata: final_usage.unwrap_or_default(), + finish_reason: final_finish_reason, + finish_message: final_finish_message, + model_version: final_model_version, + })); + } }.instrument(span); Ok(streaming::StreamingCompletionResponse::stream(Box::pin( @@ -330,6 +355,61 @@ mod tests { } } + #[test] + fn test_streaming_tool_protocol_finish_reason_returns_response_error() { + for (finish_reason, reason_name, finish_message) in [ + ( + "MALFORMED_FUNCTION_CALL", + "MalformedFunctionCall", + "malformed function call: default_api", + ), + ( + "UNEXPECTED_TOOL_CALL", + "UnexpectedToolCall", + "unexpected tool call: default_api", + ), + ( + "MISSING_THOUGHT_SIGNATURE", + "MissingThoughtSignature", + "missing thought signature for tool call", + ), + ( + "TOO_MANY_TOOL_CALLS", + "TooManyToolCalls", + "too many tool calls in response", + ), + ( + "MALFORMED_RESPONSE", + "MalformedResponse", + "malformed response from provider", + ), + ] { + let json_data = json!({ + "candidates": [{ + "finishReason": finish_reason, + "finishMessage": finish_message, + "index": 0 + }] + }); + + let response: StreamGenerateContentResponse = + serde_json::from_value(json_data).unwrap(); + let candidate = response + .candidates + .first() + .expect("expected terminal candidate"); + let err = tool_protocol_finish_reason_error(candidate) + .expect("tool protocol finish reason should be an error"); + + assert!(matches!( + err, + CompletionError::ResponseError(message) + if message.contains(reason_name) + && message.contains(finish_message) + )); + } + } + #[test] fn test_deserialize_stream_response_with_usage_only_chunk() { let json_data = json!({ @@ -639,6 +719,7 @@ mod tests { let response = StreamingCompletionResponse { usage_metadata: PartialUsage::default(), finish_reason: Some(FinishReason::Stop), + finish_message: None, model_version: Some("gemini-2.5-pro-preview-05-06".to_string()), }; @@ -677,6 +758,7 @@ mod tests { traffic_type: None, }, finish_reason: Some(FinishReason::Stop), + finish_message: None, model_version: Some("gemini-2.0-flash-001".to_string()), }; diff --git a/tests/cassettes/gemini/tool_choice/none_nonstreaming_no_tools.yaml b/tests/cassettes/gemini/tool_choice/none_nonstreaming_no_tools.yaml new file mode 100644 index 000000000..81b594652 --- /dev/null +++ b/tests/cassettes/gemini/tool_choice/none_nonstreaming_no_tools.yaml @@ -0,0 +1,18 @@ +when: + path: /v1beta/models/gemini-2.5-flash:generateContent + method: POST + query_param: + - name: key + value: '[REDACTED]' + header: + - name: accept + value: '*/*' + - name: content-type + value: application/json + body: '{"contents":[{"parts":[{"text":"Calculate 20 + 22 directly in text. Do not call tools.","thought":false}],"role":"user"}],"generationConfig":null,"safetySettings":null,"systemInstruction":{"parts":[{"text":"You are a deterministic calculator test. Answer directly in text.","thought":false}],"role":"model"},"toolConfig":{"functionCallingConfig":{"mode":"NONE"}},"tools":[{"codeExecution":null,"functionDeclarations":[{"description":"Add x and y together","name":"add","parameters":{"properties":{"x":{"description":"The first number to add","type":"number"},"y":{"description":"The second number to add","type":"number"}},"required":["x","y"],"type":"object"}},{"description":"Subtract y from x (i.e.: x - y)","name":"subtract","parameters":{"properties":{"x":{"description":"The number to subtract from","type":"number"},"y":{"description":"The number to subtract","type":"number"}},"required":["x","y"],"type":"object"}}]}]}' +then: + status: 200 + header: + - name: content-type + value: application/json; charset=UTF-8 + body: '{"candidates":[{"content":{"parts":[{"text":"42"}],"role":"model"},"finishReason":"STOP","index":0}],"modelVersion":"gemini-2.5-flash","responseId":"id_REDACTED_1","usageMetadata":{"candidatesTokenCount":2,"promptTokenCount":146,"promptTokensDetails":[{"modality":"TEXT","tokenCount":146}],"serviceTier":"standard","thoughtsTokenCount":46,"totalTokenCount":194}}' diff --git a/tests/cassettes/gemini/tool_choice/none_streaming_no_tools.yaml b/tests/cassettes/gemini/tool_choice/none_streaming_no_tools.yaml new file mode 100644 index 000000000..106e87a97 --- /dev/null +++ b/tests/cassettes/gemini/tool_choice/none_streaming_no_tools.yaml @@ -0,0 +1,20 @@ +when: + path: /v1beta/models/gemini-2.5-flash:streamGenerateContent + method: POST + query_param: + - name: alt + value: sse + - name: key + value: '[REDACTED]' + header: + - name: accept + value: text/event-stream + - name: content-type + value: application/json + body: '{"contents":[{"parts":[{"text":"Calculate 20 + 22 directly in text. Do not call tools.","thought":false}],"role":"user"}],"generationConfig":null,"safetySettings":null,"systemInstruction":{"parts":[{"text":"You are a deterministic calculator test. Answer directly in text.","thought":false}],"role":"model"},"toolConfig":{"functionCallingConfig":{"mode":"NONE"}},"tools":[{"codeExecution":null,"functionDeclarations":[{"description":"Add x and y together","name":"add","parameters":{"properties":{"x":{"description":"The first number to add","type":"number"},"y":{"description":"The second number to add","type":"number"}},"required":["x","y"],"type":"object"}},{"description":"Subtract y from x (i.e.: x - y)","name":"subtract","parameters":{"properties":{"x":{"description":"The number to subtract from","type":"number"},"y":{"description":"The number to subtract","type":"number"}},"required":["x","y"],"type":"object"}}]}]}' +then: + status: 200 + header: + - name: content-type + value: text/event-stream + body: "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"42\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"modelVersion\":\"gemini-2.5-flash\",\"responseId\":\"id_REDACTED_1\",\"usageMetadata\":{\"candidatesTokenCount\":2,\"promptTokenCount\":146,\"promptTokensDetails\":[{\"modality\":\"TEXT\",\"tokenCount\":146}],\"serviceTier\":\"standard\",\"thoughtsTokenCount\":29,\"totalTokenCount\":177}}\r\n\r\n" diff --git a/tests/cassettes/gemini/tool_choice/specific_add_raw_nonstreaming.yaml b/tests/cassettes/gemini/tool_choice/specific_add_raw_nonstreaming.yaml new file mode 100644 index 000000000..f2f757fc9 --- /dev/null +++ b/tests/cassettes/gemini/tool_choice/specific_add_raw_nonstreaming.yaml @@ -0,0 +1,18 @@ +when: + path: /v1beta/models/gemini-2.5-flash:generateContent + method: POST + query_param: + - name: key + value: '[REDACTED]' + header: + - name: accept + value: '*/*' + - name: content-type + value: application/json + body: '{"contents":[{"parts":[{"text":"Use the add tool to calculate 20 + 22. Do not use subtraction.","thought":false}],"role":"user"}],"generationConfig":null,"safetySettings":null,"systemInstruction":null,"toolConfig":{"functionCallingConfig":{"allowed_function_names":["add"],"mode":"ANY"}},"tools":[{"codeExecution":null,"functionDeclarations":[{"description":"Add x and y together","name":"add","parameters":{"properties":{"x":{"description":"The first number to add","type":"number"},"y":{"description":"The second number to add","type":"number"}},"required":["x","y"],"type":"object"}},{"description":"Subtract y from x (i.e.: x - y)","name":"subtract","parameters":{"properties":{"x":{"description":"The number to subtract from","type":"number"},"y":{"description":"The number to subtract","type":"number"}},"required":["x","y"],"type":"object"}}]}]}' +then: + status: 200 + header: + - name: content-type + value: application/json; charset=UTF-8 + body: '{"candidates":[{"content":{"parts":[{"functionCall":{"args":{"x":20,"y":22},"name":"add"},"thoughtSignature":"signature_REDACTED_1"}],"role":"model"},"finishMessage":"Model generated function call(s).","finishReason":"STOP","index":0}],"modelVersion":"gemini-2.5-flash","responseId":"id_REDACTED_1","usageMetadata":{"candidatesTokenCount":20,"promptTokenCount":135,"promptTokensDetails":[{"modality":"TEXT","tokenCount":135}],"serviceTier":"standard","thoughtsTokenCount":67,"totalTokenCount":222}}' diff --git a/tests/cassettes/gemini/tool_choice/specific_add_raw_streaming.yaml b/tests/cassettes/gemini/tool_choice/specific_add_raw_streaming.yaml new file mode 100644 index 000000000..336975ec9 --- /dev/null +++ b/tests/cassettes/gemini/tool_choice/specific_add_raw_streaming.yaml @@ -0,0 +1,20 @@ +when: + path: /v1beta/models/gemini-2.5-flash:streamGenerateContent + method: POST + query_param: + - name: alt + value: sse + - name: key + value: '[REDACTED]' + header: + - name: accept + value: text/event-stream + - name: content-type + value: application/json + body: '{"contents":[{"parts":[{"text":"Use the add tool to calculate 20 + 22. Do not use subtraction.","thought":false}],"role":"user"}],"generationConfig":null,"safetySettings":null,"systemInstruction":null,"toolConfig":{"functionCallingConfig":{"allowed_function_names":["add"],"mode":"ANY"}},"tools":[{"codeExecution":null,"functionDeclarations":[{"description":"Add x and y together","name":"add","parameters":{"properties":{"x":{"description":"The first number to add","type":"number"},"y":{"description":"The second number to add","type":"number"}},"required":["x","y"],"type":"object"}},{"description":"Subtract y from x (i.e.: x - y)","name":"subtract","parameters":{"properties":{"x":{"description":"The number to subtract from","type":"number"},"y":{"description":"The number to subtract","type":"number"}},"required":["x","y"],"type":"object"}}]}]}' +then: + status: 200 + header: + - name: content-type + value: text/event-stream + body: "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"args\":{\"x\":20,\"y\":22},\"name\":\"add\"},\"thoughtSignature\":\"signature_REDACTED_1\"}],\"role\":\"model\"},\"finishMessage\":\"Model generated function call(s).\",\"finishReason\":\"STOP\",\"index\":0}],\"modelVersion\":\"gemini-2.5-flash\",\"responseId\":\"id_REDACTED_1\",\"usageMetadata\":{\"candidatesTokenCount\":20,\"promptTokenCount\":135,\"promptTokensDetails\":[{\"modality\":\"TEXT\",\"tokenCount\":135}],\"serviceTier\":\"standard\",\"thoughtsTokenCount\":73,\"totalTokenCount\":228}}\r\n\r\n" diff --git a/tests/core/prompt_response_messages.rs b/tests/core/prompt_response_messages.rs index 6b947a16b..389103bb7 100644 --- a/tests/core/prompt_response_messages.rs +++ b/tests/core/prompt_response_messages.rs @@ -4,7 +4,7 @@ use rig::agent::AgentBuilder; use rig::completion::{Chat, Message, Prompt, Usage}; use rig::message::{AssistantContent, UserContent}; -use rig::test_utils::{MockCompletionModel, MockTurn}; +use rig::test_utils::{MockAddTool, MockCompletionModel, MockTurn}; // --------------------------------------------------------------------------- // Mock model infrastructure @@ -30,21 +30,17 @@ fn simple_text_model(turns: usize) -> MockCompletionModel { fn tool_then_text_model() -> MockCompletionModel { MockCompletionModel::new([ - MockTurn::tool_call( - "tc_1", - "calculator", - serde_json::json!({"op": "add", "a": 2, "b": 3}), - ) - .with_usage(Usage { - input_tokens: 15, - output_tokens: 8, - total_tokens: 23, - cached_input_tokens: 0, - cache_creation_input_tokens: 0, - tool_use_prompt_tokens: 0, - reasoning_tokens: 0, - }) - .with_message_id("msg_tool"), + MockTurn::tool_call("tc_1", "add", serde_json::json!({"x": 2, "y": 3})) + .with_usage(Usage { + input_tokens: 15, + output_tokens: 8, + total_tokens: 23, + cached_input_tokens: 0, + cache_creation_input_tokens: 0, + tool_use_prompt_tokens: 0, + reasoning_tokens: 0, + }) + .with_message_id("msg_tool"), MockTurn::text("The answer is 5") .with_usage(Usage { input_tokens: 20, @@ -60,7 +56,7 @@ fn tool_then_text_model() -> MockCompletionModel { } fn always_tool_call_turn() -> MockTurn { - MockTurn::tool_call("tc_loop", "infinite_tool", serde_json::json!({"x": 1})) + MockTurn::tool_call("tc_loop", "add", serde_json::json!({"x": 1, "y": 1})) } // --------------------------------------------------------------------------- @@ -176,7 +172,9 @@ async fn standard_with_history_works() { /// full conversation: User → Assistant(tool_call) → User(tool_result) → Assistant(text). #[tokio::test] async fn multi_turn_messages_include_tool_calls() { - let agent = AgentBuilder::new(tool_then_text_model()).build(); + let agent = AgentBuilder::new(tool_then_text_model()) + .tool(MockAddTool) + .build(); let resp = agent .prompt("What is 2 + 3?") @@ -191,8 +189,8 @@ async fn multi_turn_messages_include_tool_calls() { // Expected sequence: // [0] User: "What is 2 + 3?" - // [1] Assistant: ToolCall(calculator) - // [2] User: ToolResult (error since calculator tool isn't registered, but that's fine) + // [1] Assistant: ToolCall(add) + // [2] User: ToolResult // [3] Assistant: "The answer is 5" assert_eq!(messages.len(), 4, "expected 4 messages, got: {messages:#?}"); @@ -281,6 +279,7 @@ async fn max_turns_error_still_contains_history() { let agent = AgentBuilder::new(MockCompletionModel::new( (0..10).map(|_| always_tool_call_turn()), )) + .tool(MockAddTool) .build(); let result = agent @@ -311,7 +310,9 @@ async fn max_turns_error_still_contains_history() { /// be populated (this is the core feature: no need for &mut borrow). #[tokio::test] async fn extended_details_works_without_with_history() { - let agent = AgentBuilder::new(tool_then_text_model()).build(); + let agent = AgentBuilder::new(tool_then_text_model()) + .tool(MockAddTool) + .build(); // Note: NO .with_history() call — this is the new use case let resp = agent @@ -380,7 +381,9 @@ async fn chat_appends_prompt_and_assistant_to_history() { /// Test 11: `Chat::chat` appends every message produced by a tool roundtrip. #[tokio::test] async fn chat_appends_tool_roundtrip_to_history() { - let agent = AgentBuilder::new(tool_then_text_model()).build(); + let agent = AgentBuilder::new(tool_then_text_model()) + .tool(MockAddTool) + .build(); let mut history = Vec::::new(); let output = agent diff --git a/tests/providers/gemini/cassette/tool_choice.rs b/tests/providers/gemini/cassette/tool_choice.rs new file mode 100644 index 000000000..28a9e320d --- /dev/null +++ b/tests/providers/gemini/cassette/tool_choice.rs @@ -0,0 +1,229 @@ +//! Gemini tool-choice cassette coverage. + +use rig::client::CompletionClient; +use rig::completion::{AssistantContent, Chat, CompletionModel, Message}; +use rig::message::ToolChoice; +use rig::providers::gemini; +use rig::streaming::StreamingPrompt; +use rig::tool::Tool; + +use crate::support::{ + Adder, Subtract, assert_mentions_expected_number, collect_raw_stream_observation, + collect_stream_observation, +}; + +fn specific_add_choice() -> ToolChoice { + ToolChoice::Specific { + function_names: vec![Adder::NAME.to_string()], + } +} + +fn assert_history_tool_calls(history: &[Message], expected: &[&str], forbidden: &[&str]) { + let tool_names = history + .iter() + .filter_map(|message| match message { + Message::Assistant { content, .. } => Some(content), + _ => None, + }) + .flat_map(|content| content.iter()) + .filter_map(|content| match content { + AssistantContent::ToolCall(tool_call) => Some(tool_call.function.name.as_str()), + _ => None, + }) + .collect::>(); + + for expected_tool in expected { + assert!( + tool_names.iter().any(|name| name == expected_tool), + "expected tool call {expected_tool}, saw {tool_names:?}" + ); + } + + for forbidden_tool in forbidden { + assert!( + !tool_names.iter().any(|name| name == forbidden_tool), + "did not expect tool call {forbidden_tool}, saw {tool_names:?}" + ); + } +} + +#[tokio::test] +async fn specific_add_raw_streaming_allows_only_add() { + super::super::support::with_gemini_cassette( + "tool_choice/specific_add_raw_streaming", + |client| async move { + let model = client.completion_model(gemini::completion::GEMINI_2_5_FLASH); + let request = model + .completion_request( + "Use the add tool to calculate 20 + 22. Do not use subtraction.", + ) + .temperature(0.0) + .tool(Adder.definition(String::new()).await) + .tool(Subtract.definition(String::new()).await) + .tool_choice(specific_add_choice()) + .build(); + let stream = model.stream(request).await.expect("stream should start"); + let observation = collect_raw_stream_observation(stream).await; + + assert!( + observation.errors.is_empty(), + "stream should not emit errors: {:?}", + observation.errors + ); + assert!( + observation + .tool_calls + .iter() + .any(|tool_call| tool_call.function.name == Adder::NAME), + "expected add tool call, saw {:?}", + observation.tool_calls + ); + assert!( + !observation + .tool_calls + .iter() + .any(|tool_call| tool_call.function.name == Subtract::NAME), + "did not expect subtract tool call, saw {:?}", + observation.tool_calls + ); + let add_call = observation + .tool_calls + .iter() + .find(|tool_call| tool_call.function.name == Adder::NAME) + .expect("expected add tool call"); + assert_eq!( + add_call.function.arguments, + serde_json::json!({ "x": 20, "y": 22 }) + ); + }, + ) + .await; +} + +#[tokio::test] +async fn specific_add_raw_nonstreaming_allows_only_add() { + super::super::support::with_gemini_cassette( + "tool_choice/specific_add_raw_nonstreaming", + |client| async move { + let model = client.completion_model(gemini::completion::GEMINI_2_5_FLASH); + let response = model + .completion_request( + "Use the add tool to calculate 20 + 22. Do not use subtraction.", + ) + .temperature(0.0) + .tool(Adder.definition(String::new()).await) + .tool(Subtract.definition(String::new()).await) + .tool_choice(specific_add_choice()) + .send() + .await + .expect("specific add raw completion should succeed"); + + let tool_calls = response + .choice + .iter() + .filter_map(|content| match content { + AssistantContent::ToolCall(tool_call) => Some(tool_call), + _ => None, + }) + .collect::>(); + + assert!( + tool_calls + .iter() + .any(|tool_call| tool_call.function.name == Adder::NAME), + "expected add tool call, saw {tool_calls:?}" + ); + assert!( + !tool_calls + .iter() + .any(|tool_call| tool_call.function.name == Subtract::NAME), + "did not expect subtract tool call, saw {tool_calls:?}" + ); + let add_call = tool_calls + .iter() + .find(|tool_call| tool_call.function.name == Adder::NAME) + .expect("expected add tool call"); + assert_eq!( + add_call.function.arguments, + serde_json::json!({ "x": 20, "y": 22 }) + ); + }, + ) + .await; +} + +#[tokio::test] +async fn none_streaming_does_not_emit_tool_calls() { + super::super::support::with_gemini_cassette( + "tool_choice/none_streaming_no_tools", + |client| async move { + let agent = client + .agent(gemini::completion::GEMINI_2_5_FLASH) + .preamble("You are a deterministic calculator test. Answer directly in text.") + .temperature(0.0) + .tool(Adder) + .tool(Subtract) + .tool_choice(ToolChoice::None) + .build(); + + let mut stream = agent + .stream_prompt("Calculate 20 + 22 directly in text. Do not call tools.") + .await; + let observation = collect_stream_observation(&mut stream).await; + + assert!( + observation.errors.is_empty(), + "stream should not emit errors: {:?}", + observation.errors + ); + assert!( + observation.got_final_response, + "stream should emit a final response" + ); + assert!( + observation.tool_calls.is_empty(), + "expected no tool calls, saw {:?}", + observation.tool_calls + ); + assert_eq!(observation.tool_results, 0, "expected no tool results"); + assert_mentions_expected_number( + observation + .final_response_text + .as_deref() + .expect("stream should produce a final response"), + 42, + ); + }, + ) + .await; +} + +#[tokio::test] +async fn none_nonstreaming_does_not_emit_tool_calls() { + super::super::support::with_gemini_cassette( + "tool_choice/none_nonstreaming_no_tools", + |client| async move { + let agent = client + .agent(gemini::completion::GEMINI_2_5_FLASH) + .preamble("You are a deterministic calculator test. Answer directly in text.") + .temperature(0.0) + .tool(Adder) + .tool(Subtract) + .tool_choice(ToolChoice::None) + .build(); + + let mut chat_history = Vec::::new(); + let response = agent + .chat( + "Calculate 20 + 22 directly in text. Do not call tools.", + &mut chat_history, + ) + .await + .expect("ToolChoice::None prompt should succeed"); + + assert_mentions_expected_number(&response, 42); + assert_history_tool_calls(&chat_history, &[], &[Adder::NAME, Subtract::NAME]); + }, + ) + .await; +} diff --git a/tests/providers/gemini/mod.rs b/tests/providers/gemini/mod.rs index 4a131140f..4a6c184d6 100644 --- a/tests/providers/gemini/mod.rs +++ b/tests/providers/gemini/mod.rs @@ -14,6 +14,7 @@ mod cassette { mod streaming_multimodal_tool_results; mod streaming_tools; mod structured_output; + mod tool_choice; mod transcription; }