Skip to content

Commit 66d58df

Browse files
Merge pull request #33 from gold-silver-copper/codex/invalid-tool-call-hooks-v2
Add invalid tool call recovery hooks
2 parents 8273897 + c35bd6b commit 66d58df

5 files changed

Lines changed: 3102 additions & 157 deletions

File tree

crates/rig-core/src/agent/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ mod tool;
112112
pub use crate::message::Text;
113113
pub use builder::{AgentBuilder, NoToolConfig, WithBuilderTools, WithToolServerHandle};
114114
pub use completion::Agent;
115-
pub use prompt_request::hooks::{HookAction, PromptHook, ToolCallHookAction};
115+
pub use prompt_request::hooks::{
116+
HookAction, InvalidToolCallContext, InvalidToolCallHook, InvalidToolCallHookAction, PromptHook,
117+
ToolCallHookAction,
118+
};
116119
pub use prompt_request::streaming::{
117120
FinalResponse, MultiTurnStreamItem, StreamingError, StreamingPromptRequest, StreamingResult,
118121
stream_to_stdout,

crates/rig-core/src/agent/prompt_request/hooks.rs

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,98 @@
44
55
use crate::{
66
completion::CompletionModel,
7-
message::Message,
7+
message::{Message, ToolChoice},
88
wasm_compat::{WasmCompatSend, WasmCompatSync},
99
};
1010

11+
/// Context passed to [`InvalidToolCallHook`] when the model emits a tool call
12+
/// that Rig would reject before normal tool-call hooks or execution.
13+
#[derive(Debug, Clone)]
14+
pub struct InvalidToolCallContext {
15+
/// Tool name emitted by the model.
16+
pub tool_name: String,
17+
/// Provider-supplied tool call ID, when available.
18+
pub tool_call_id: Option<String>,
19+
/// Internal Rig call ID, when available.
20+
pub internal_call_id: Option<String>,
21+
/// JSON arguments emitted for the tool call, when available.
22+
pub args: Option<String>,
23+
/// Executable Rig tools advertised to the provider for this turn.
24+
pub available_tools: Vec<String>,
25+
/// Tools allowed by the active [`ToolChoice`] for this turn.
26+
pub allowed_tools: Vec<String>,
27+
/// Active tool choice for this turn.
28+
pub tool_choice: Option<ToolChoice>,
29+
/// Diagnostic chat history including the rejected model output when available.
30+
pub chat_history: Vec<Message>,
31+
/// Whether the rejected call came from the streaming path.
32+
pub is_streaming: bool,
33+
}
34+
35+
/// Trait for recovering from model-emitted tool calls that Rig rejects during
36+
/// validation.
37+
///
38+
/// The default behavior remains fail-fast. Attach an implementation with
39+
/// `.with_invalid_tool_call_hook(...)` to opt into recovery.
40+
pub trait InvalidToolCallHook<M>: Clone + WasmCompatSend + WasmCompatSync
41+
where
42+
M: CompletionModel,
43+
{
44+
/// Called when a model-emitted tool call is unknown or disallowed by the
45+
/// current request's tool choice.
46+
fn on_invalid_tool_call(
47+
&self,
48+
_context: &InvalidToolCallContext,
49+
) -> impl Future<Output = InvalidToolCallHookAction> + WasmCompatSend {
50+
async { InvalidToolCallHookAction::fail() }
51+
}
52+
}
53+
54+
impl<M> InvalidToolCallHook<M> for () where M: CompletionModel {}
55+
56+
/// Recovery action for invalid tool-call hooks.
57+
#[derive(Debug, Clone, PartialEq, Eq)]
58+
pub enum InvalidToolCallHookAction {
59+
/// Preserve Rig's default fail-fast behavior.
60+
Fail,
61+
/// Retry the model turn with corrective feedback.
62+
Retry { feedback: String },
63+
/// Rewrite only the emitted tool name. The repaired name is revalidated
64+
/// against registered tools and the current `ToolChoice` before use.
65+
Repair { tool_name: String },
66+
/// Treat an invalid structured tool call as skipped by returning synthetic
67+
/// feedback as its tool result. This does not execute the invalid tool.
68+
Skip { reason: String },
69+
}
70+
71+
impl InvalidToolCallHookAction {
72+
/// Preserve Rig's default fail-fast behavior.
73+
pub fn fail() -> Self {
74+
Self::Fail
75+
}
76+
77+
/// Retry the model turn with corrective feedback.
78+
pub fn retry(feedback: impl Into<String>) -> Self {
79+
Self::Retry {
80+
feedback: feedback.into(),
81+
}
82+
}
83+
84+
/// Repair the emitted tool name.
85+
pub fn repair(tool_name: impl Into<String>) -> Self {
86+
Self::Repair {
87+
tool_name: tool_name.into(),
88+
}
89+
}
90+
91+
/// Skip the invalid call with a synthetic tool result.
92+
pub fn skip(reason: impl Into<String>) -> Self {
93+
Self::Skip {
94+
reason: reason.into(),
95+
}
96+
}
97+
}
98+
1199
/// Trait for per-request hooks to observe tool call events.
12100
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
13101
where

0 commit comments

Comments
 (0)