Skip to content

Commit 560f203

Browse files
committed
✨ add web search using openai agents
1 parent 9ff84c3 commit 560f203

File tree

3 files changed

+139
-2
lines changed

3 files changed

+139
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ clipboard = "0.5.0"
3939
reqwest = { version = "0.12.1", features = ["json"] }
4040
serde = { version = "1.0", features = ["derive"] }
4141
once_cell = "1.19.0"
42-
speakstream = { version = "0.1.2", path = "../speakstream" }
42+
speakstream = "0.1.2"
4343
windows = { version = "0.52.0", features = [
4444
"Win32_System_Com",
4545
"Win32_System_Com_StructuredStorage",

src/main.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use futures::stream::StreamExt; // For `.next()` on FuturesOrdered.
2525
use std::thread;
2626
use tempfile::Builder;
2727
mod record;
28+
mod web_search;
2829
use crate::default_device_sink::{
2930
default_device_name as get_default_output_device,
3031
list_output_devices as list_audio_output_devices, set_output_device, DefaultDeviceSink,
@@ -58,6 +59,7 @@ use tracing::{debug, error, info, instrument, warn};
5859
use tracing_appender::rolling::{RollingFileAppender, Rotation};
5960

6061
use speakstream::ss::SpeakStream;
62+
use web_search::search_web;
6163

6264
#[derive(Debug, Subcommand)]
6365
pub enum SubCommands {
@@ -434,6 +436,22 @@ fn call_fn(
434436
Err(err) => Some(format!("Failed to create runtime: {}", err)),
435437
},
436438

439+
"search_web" => match tokio::runtime::Runtime::new() {
440+
Ok(rt) => {
441+
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
442+
let query = args["query"].as_str().unwrap_or("");
443+
let api_key = match std::env::var("OPENAI_API_KEY") {
444+
Ok(k) => k,
445+
Err(_) => return Some("OPENAI_API_KEY not set".to_string()),
446+
};
447+
match rt.block_on(search_web(&api_key, query)) {
448+
Ok(ans) => Some(ans),
449+
Err(err) => Some(format!("Web search failed: {}", err)),
450+
}
451+
}
452+
Err(err) => Some(format!("Failed to create runtime: {}", err)),
453+
},
454+
437455
"set_timer_at" => {
438456
let args: serde_json::Value = serde_json::from_str(fn_args).unwrap();
439457
let time_str = args["time"].as_str().unwrap();
@@ -1578,7 +1596,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
15781596
}))
15791597
.build().unwrap(),
15801598

1581-
ChatCompletionFunctionsArgs::default()
1599+
ChatCompletionFunctionsArgs::default()
15821600
.name("get_location")
15831601
.description("Returns an approximate location based on the machine's IP address.")
15841602
.parameters(json!({
@@ -1588,6 +1606,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
15881606
}))
15891607
.build().unwrap(),
15901608

1609+
ChatCompletionFunctionsArgs::default()
1610+
.name("search_web")
1611+
.description("Searches the web using OpenAI's browser tool.")
1612+
.parameters(json!({
1613+
"type": "object",
1614+
"properties": { "query": { "type": "string" } },
1615+
"required": ["query"],
1616+
}))
1617+
.build().unwrap(),
1618+
15911619
ChatCompletionFunctionsArgs::default()
15921620
.name("list_output_devices")
15931621
.description("Lists available audio output devices. The default device is marked with '(Default)'.")

src/web_search.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use reqwest::Client;
2+
use serde_json::json;
3+
use std::{error::Error, time::Duration};
4+
use tokio::time::sleep;
5+
6+
pub async fn search_web(api_key: &str, query: &str) -> Result<String, Box<dyn Error>> {
7+
let client = Client::new();
8+
let base = "https://api.openai.com/v1";
9+
10+
let assistant_res: serde_json::Value = client
11+
.post(&format!("{}/assistants", base))
12+
.header("Authorization", format!("Bearer {}", api_key))
13+
.header("OpenAI-Beta", "assistants=v1")
14+
.json(&json!({
15+
"model": "gpt-4o",
16+
"instructions": "You are a web search assistant.",
17+
"tools": [{"type": "browser"}]
18+
}))
19+
.send()
20+
.await?
21+
.json()
22+
.await?;
23+
24+
let assistant_id = assistant_res["id"]
25+
.as_str()
26+
.ok_or("missing assistant id")?
27+
.to_string();
28+
29+
let thread_res: serde_json::Value = client
30+
.post(&format!("{}/threads", base))
31+
.header("Authorization", format!("Bearer {}", api_key))
32+
.header("OpenAI-Beta", "assistants=v1")
33+
.send()
34+
.await?
35+
.json()
36+
.await?;
37+
38+
let thread_id = thread_res["id"]
39+
.as_str()
40+
.ok_or("missing thread id")?
41+
.to_string();
42+
43+
client
44+
.post(&format!("{}/threads/{}/messages", base, thread_id))
45+
.header("Authorization", format!("Bearer {}", api_key))
46+
.header("OpenAI-Beta", "assistants=v1")
47+
.json(&json!({"role": "user", "content": query}))
48+
.send()
49+
.await?;
50+
51+
let run_res: serde_json::Value = client
52+
.post(&format!("{}/threads/{}/runs", base, thread_id))
53+
.header("Authorization", format!("Bearer {}", api_key))
54+
.header("OpenAI-Beta", "assistants=v1")
55+
.json(&json!({"assistant_id": assistant_id}))
56+
.send()
57+
.await?
58+
.json()
59+
.await?;
60+
61+
let run_id = run_res["id"].as_str().ok_or("missing run id")?.to_string();
62+
63+
loop {
64+
let run_status: serde_json::Value = client
65+
.get(&format!("{}/threads/{}/runs/{}", base, thread_id, run_id))
66+
.header("Authorization", format!("Bearer {}", api_key))
67+
.header("OpenAI-Beta", "assistants=v1")
68+
.send()
69+
.await?
70+
.json()
71+
.await?;
72+
73+
match run_status["status"].as_str() {
74+
Some("completed") => break,
75+
Some("failed") | Some("expired") | Some("cancelled") => return Err("run failed".into()),
76+
_ => sleep(Duration::from_secs(1)).await,
77+
}
78+
}
79+
80+
let messages: serde_json::Value = client
81+
.get(&format!("{}/threads/{}/messages", base, thread_id))
82+
.header("Authorization", format!("Bearer {}", api_key))
83+
.header("OpenAI-Beta", "assistants=v1")
84+
.send()
85+
.await?
86+
.json()
87+
.await?;
88+
89+
let answer = messages["data"][0]["content"][0]["text"]["value"]
90+
.as_str()
91+
.unwrap_or("")
92+
.to_string();
93+
94+
// cleanup
95+
let _ = client
96+
.delete(&format!("{}/assistants/{}", base, assistant_id))
97+
.header("Authorization", format!("Bearer {}", api_key))
98+
.header("OpenAI-Beta", "assistants=v1")
99+
.send()
100+
.await;
101+
let _ = client
102+
.delete(&format!("{}/threads/{}", base, thread_id))
103+
.header("Authorization", format!("Bearer {}", api_key))
104+
.header("OpenAI-Beta", "assistants=v1")
105+
.send()
106+
.await;
107+
108+
Ok(answer)
109+
}

0 commit comments

Comments
 (0)