-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcustom_tool.rs
More file actions
92 lines (80 loc) · 2.62 KB
/
Copy pathcustom_tool.rs
File metadata and controls
92 lines (80 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use openai_agents::{
Agent, AgentsError, InputItem, Model, ModelProvider, ModelRequest, ModelResponse, OutputItem,
Result as AgentsResult, Runner, Usage, custom_tool,
};
use serde_json::{Value, json};
/// Scripted [`Model`] used by this example in place of a real model backend.
///
/// The first `generate` call requests the `raw_editor` custom tool. Every later call
/// replies with `edited=<output>`, where `<output>` is the first
/// `custom_tool_call_output` found in the request input (or `missing custom output`).
#[derive(Clone, Default)]
struct RawEditModel {
    /// Number of `generate` calls made so far; clones share it through the `Arc`.
    calls: Arc<Mutex<usize>>,
}

#[async_trait]
impl Model for RawEditModel {
    async fn generate(&self, request: ModelRequest) -> AgentsResult<ModelResponse> {
        // Bump the shared counter and copy the new value out so the guard is dropped
        // right away instead of living for the rest of the function.
        let call_number = {
            let mut counter = self.calls.lock().expect("raw edit model lock");
            *counter += 1;
            *counter
        };

        let output = match call_number {
            // First turn: ask the runner to invoke the custom tool with raw text.
            1 => vec![OutputItem::CustomToolCall {
                call_id: "call-raw-edit".to_owned(),
                tool_name: "raw_editor".to_owned(),
                input: "hello from raw input".to_owned(),
            }],
            // Later turns: report whatever the tool sent back.
            _ => {
                let tool_result = request.input.iter().find_map(custom_tool_output);
                let edited =
                    tool_result.unwrap_or_else(|| "missing custom output".to_owned());
                vec![OutputItem::Text {
                    text: format!("edited={edited}"),
                }]
            }
        };

        Ok(ModelResponse {
            model: request.model,
            output,
            usage: Usage::default(),
            response_id: Some(format!("resp-raw-edit-{call_number}")),
            request_id: None,
        })
    }
}
/// Extracts the `output` string from a `custom_tool_call_output` JSON input item.
///
/// Returns `None` in any of these cases:
/// - `item` is not an [`InputItem::Json`];
/// - its `"type"` field is not the string `"custom_tool_call_output"`;
/// - its `"output"` field is missing or is not a JSON string.
fn custom_tool_output(item: &InputItem) -> Option<String> {
    let InputItem::Json { value } = item else {
        return None;
    };
    // Check the discriminator first so unrelated JSON items are skipped early.
    if value.get("type").and_then(Value::as_str) != Some("custom_tool_call_output") {
        return None;
    }
    value
        .get("output")
        .and_then(Value::as_str)
        .map(ToOwned::to_owned)
}
/// [`ModelProvider`] that serves the same scripted [`RawEditModel`] for every request.
#[derive(Clone, Default)]
struct RawEditProvider {
    /// Model instance handed out (by reference count) from every `resolve` call.
    model: Arc<RawEditModel>,
}

impl ModelProvider for RawEditProvider {
    /// Ignores the requested model name and returns the shared [`RawEditModel`].
    fn resolve(&self, _model: Option<&str>) -> Arc<dyn Model> {
        // The turbofish pins `T` so the clone is a plain refcount bump; the result is
        // then unsized to `Arc<dyn Model>` at the return site.
        Arc::<RawEditModel>::clone(&self.model)
    }
}
/// Runs the raw-editor example against the scripted [`RawEditProvider`].
///
/// Registers a `raw_editor` custom tool that upper-cases its raw text input. It then runs
/// an agent that has that tool and prints the run's final output. When the run has no
/// final output, it prints the output type's default value.
///
/// # Errors
///
/// Propagates any [`AgentsError`] returned by the runner.
#[tokio::main]
async fn main() -> Result<(), AgentsError> {
    // Free-form (text-format) custom tool: receives raw text, returns it upper-cased.
    let raw_editor = custom_tool("raw_editor", "Edit raw text.", |_context, input| async move {
        Ok::<_, AgentsError>(input.to_uppercase())
    })
    .with_format(json!({"type": "text"}));

    let agent = Agent::builder("Raw editor")
        .instructions("Use the raw editor for unstructured text edits.")
        .custom_tool(raw_editor)
        .build();

    let provider = Arc::new(RawEditProvider::default());
    let prompt = "Uppercase this draft.";
    let result = Runner::new()
        .with_model_provider(provider)
        .run(&agent, prompt)
        .await?;

    let final_output = result.final_output.unwrap_or_default();
    println!("{final_output}");
    Ok(())
}