Skip to content
Closed
75 changes: 47 additions & 28 deletions crates/rig-core/src/agent/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,34 @@ where
self.default_conversation_id = Some(id.into());
self
}

/// Set the default hook for the agent.
///
/// This hook will be used for all prompt requests unless overridden
/// via `.with_hook()` on the request.
pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, ToolState>
where
P2: PromptHook<M>,
{
AgentBuilder {
name: self.name,
description: self.description,
model: self.model,
preamble: self.preamble,
static_context: self.static_context,
additional_params: self.additional_params,
max_tokens: self.max_tokens,
dynamic_context: self.dynamic_context,
temperature: self.temperature,
tool_choice: self.tool_choice,
default_max_turns: self.default_max_turns,
tool_state: self.tool_state,
hook: Some(hook),
output_schema: self.output_schema,
memory: self.memory,
default_conversation_id: self.default_conversation_id,
}
}
}

impl<M> AgentBuilder<M, (), NoToolConfig>
Expand Down Expand Up @@ -482,34 +510,6 @@ where
}
}

/// Set the default hook for the agent.
///
/// This hook will be used for all prompt requests unless overridden
/// via `.with_hook()` on the request.
pub fn hook<P2>(self, hook: P2) -> AgentBuilder<M, P2, NoToolConfig>
where
P2: PromptHook<M>,
{
AgentBuilder {
name: self.name,
description: self.description,
model: self.model,
preamble: self.preamble,
static_context: self.static_context,
additional_params: self.additional_params,
max_tokens: self.max_tokens,
dynamic_context: self.dynamic_context,
temperature: self.temperature,
tool_choice: self.tool_choice,
default_max_turns: self.default_max_turns,
tool_state: self.tool_state,
hook: Some(hook),
output_schema: self.output_schema,
memory: self.memory,
default_conversation_id: self.default_conversation_id,
}
}

/// Build the agent with no tools configured.
///
/// An empty `ToolServer` will be created for the agent.
Expand Down Expand Up @@ -651,3 +651,22 @@ where
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{MockAddTool, MockCompletionModel};

#[derive(Clone)]
struct BuilderHook;

impl PromptHook<MockCompletionModel> for BuilderHook {}

#[test]
fn hook_can_be_set_after_tool_configuration() {
let _agent = AgentBuilder::new(MockCompletionModel::text("ok"))
.tool(MockAddTool)
.hook(BuilderHook)
.build();
}
}
203 changes: 194 additions & 9 deletions crates/rig-core/src/agent/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use crate::{
vector_store::{VectorStoreError, request::VectorSearchRequest},
wasm_compat::WasmCompatSend,
};
use std::{collections::HashMap, sync::Arc};
use std::{
collections::{BTreeSet, HashMap},
sync::Arc,
};

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

Expand All @@ -22,8 +25,53 @@ pub type DynamicContextStore = Arc<
)>,
>;

/// A prepared completion request plus the executable Rig tool names advertised
/// to the provider for this turn.
pub(crate) struct PreparedCompletionRequest<M: CompletionModel> {
pub(crate) builder: CompletionRequestBuilder<M>,
pub(crate) executable_tool_names: BTreeSet<String>,
pub(crate) allowed_tool_names: BTreeSet<String>,
}

pub(crate) fn allowed_tool_names_for_choice(
executable_tool_names: &BTreeSet<String>,
tool_choice: Option<&ToolChoice>,
) -> Result<BTreeSet<String>, CompletionError> {
let allowed = match tool_choice {
None | Some(ToolChoice::Auto | ToolChoice::Required) => executable_tool_names.clone(),
Some(ToolChoice::None) => BTreeSet::new(),
Some(ToolChoice::Specific { function_names }) => {
if function_names.is_empty() {
return Err(CompletionError::RequestError(
"ToolChoice::Specific requires at least one function name".into(),
));
}

let requested = function_names.iter().cloned().collect::<BTreeSet<String>>();
let missing = requested
.difference(executable_tool_names)
.cloned()
.collect::<Vec<_>>();

if !missing.is_empty() {
return Err(CompletionError::RequestError(
format!(
"ToolChoice::Specific requested unknown tool names: {missing:?}. Available tools: {:?}",
executable_tool_names.iter().collect::<Vec<_>>()
)
.into(),
));
}

requested
}
};

Ok(allowed)
}

/// Helper function to build a completion request from agent components.
/// This is used by both `Agent::completion()` and `PromptRequest::send()`.
/// This is used by `Agent::completion()` to preserve the public completion API.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_completion_request<M: CompletionModel>(
model: &Arc<M>,
Expand All @@ -39,6 +87,41 @@ pub(crate) async fn build_completion_request<M: CompletionModel>(
dynamic_context: &DynamicContextStore,
output_schema: Option<&schemars::Schema>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
Ok(build_prepared_completion_request(
model,
prompt,
chat_history,
preamble,
static_context,
temperature,
max_tokens,
additional_params,
tool_choice,
tool_server_handle,
dynamic_context,
output_schema,
)
.await?
.builder)
}

/// Helper function to build a completion request from agent components while
/// preserving the executable Rig tool names sent to the provider.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn build_prepared_completion_request<M: CompletionModel>(
model: &Arc<M>,
prompt: Message,
chat_history: &[Message],
preamble: Option<&str>,
static_context: &[Document],
temperature: Option<f64>,
max_tokens: Option<u64>,
additional_params: Option<&serde_json::Value>,
tool_choice: Option<&ToolChoice>,
tool_server_handle: &ToolServerHandle,
dynamic_context: &DynamicContextStore,
output_schema: Option<&schemars::Schema>,
) -> Result<PreparedCompletionRequest<M>, CompletionError> {
// Find the latest message in the chat history that contains RAG text
let rag_text = prompt.rag_text();
let rag_text = rag_text.or_else(|| {
Expand Down Expand Up @@ -73,7 +156,7 @@ pub(crate) async fn build_completion_request<M: CompletionModel>(
};

// If the agent has RAG text, we need to fetch the dynamic context and tools
let result = match &rag_text {
let (builder, executable_tool_names) = match &rag_text {
Some(text) => {
// Map over the vector to create async tasks
let search_futures = dynamic_context.iter().map(|(num_sample, index)| {
Expand Down Expand Up @@ -123,21 +206,31 @@ pub(crate) async fn build_completion_request<M: CompletionModel>(
.map_err(|_| {
CompletionError::RequestError("Failed to get tool definitions".into())
})?;

completion_request
.documents(fetched_context)
.tools(tooldefs)
let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();

(
completion_request
.documents(fetched_context)
.tools(tooldefs),
executable_tool_names,
)
}
None => {
let tooldefs = tool_server_handle.get_tool_defs(None).await.map_err(|_| {
CompletionError::RequestError("Failed to get tool definitions".into())
})?;
let executable_tool_names = tooldefs.iter().map(|tool| tool.name.clone()).collect();

completion_request.tools(tooldefs)
(completion_request.tools(tooldefs), executable_tool_names)
}
};
let allowed_tool_names = allowed_tool_names_for_choice(&executable_tool_names, tool_choice)?;

Ok(result)
Ok(PreparedCompletionRequest {
builder,
executable_tool_names,
allowed_tool_names,
})
}

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
Expand Down Expand Up @@ -450,3 +543,95 @@ where
TypedPromptRequest::from_agent(*self, prompt)
}
}

#[cfg(test)]
mod tests {
use super::*;

fn tool_names(names: &[&str]) -> BTreeSet<String> {
names.iter().map(|name| (*name).to_string()).collect()
}

#[test]
fn allowed_tool_names_defaults_to_all_executable_tools() {
let executable = tool_names(&["add", "subtract"]);

assert_eq!(
allowed_tool_names_for_choice(&executable, None).unwrap(),
executable
);
}

#[test]
fn allowed_tool_names_auto_and_required_allow_all_executable_tools() {
let executable = tool_names(&["add", "subtract"]);

assert_eq!(
allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Auto)).unwrap(),
executable
);
assert_eq!(
allowed_tool_names_for_choice(&executable, Some(&ToolChoice::Required)).unwrap(),
executable
);
}

#[test]
fn allowed_tool_names_none_allows_no_tools() {
let executable = tool_names(&["add", "subtract"]);

assert!(
allowed_tool_names_for_choice(&executable, Some(&ToolChoice::None))
.unwrap()
.is_empty()
);
}

#[test]
fn allowed_tool_names_specific_allows_requested_executable_tools() {
let executable = tool_names(&["add", "subtract"]);
let choice = ToolChoice::Specific {
function_names: vec!["add".to_string()],
};

assert_eq!(
allowed_tool_names_for_choice(&executable, Some(&choice)).unwrap(),
tool_names(&["add"])
);
}

#[test]
fn allowed_tool_names_specific_rejects_missing_tools() {
let executable = tool_names(&["add"]);
let choice = ToolChoice::Specific {
function_names: vec!["missing".to_string()],
};

let err = allowed_tool_names_for_choice(&executable, Some(&choice))
.expect_err("missing specific tool should fail before provider request");

assert!(matches!(
err,
CompletionError::RequestError(err)
if err.to_string().contains("missing")
&& err.to_string().contains("add")
));
}

#[test]
fn allowed_tool_names_specific_rejects_empty_names() {
let executable = tool_names(&["add"]);
let choice = ToolChoice::Specific {
function_names: vec![],
};

let err = allowed_tool_names_for_choice(&executable, Some(&choice))
.expect_err("empty specific tool choice should fail before provider request");

assert!(matches!(
err,
CompletionError::RequestError(err)
if err.to_string().contains("requires at least one function name")
));
}
}
Loading
Loading