Skip to content

Commit cc56c11

Browse files
committed
handle query properly
1 parent 5955d14 commit cc56c11

3 files changed

Lines changed: 46 additions & 31 deletions

File tree

async-openai/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
5151
bytes = "1.9.0"
5252
eventsource-stream = "0.2.3"
5353
serde_urlencoded = "0.7.1"
54+
url = "2.5"
5455
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
5556
hmac = { version = "0.12", optional = true, default-features = false}
5657
sha2 = { version = "0.10", optional = true, default-features = false }

async-openai/src/client.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ impl<C: Config> Client<C> {
217217
request_builder = request_builder.headers(headers.clone());
218218
}
219219

220-
if let Some(query) = request_options.query() {
221-
request_builder = request_builder.query(query);
220+
if !request_options.query().is_empty() {
221+
request_builder = request_builder.query(request_options.query());
222222
}
223223

224224
Ok(request_builder.build()?)
@@ -247,8 +247,8 @@ impl<C: Config> Client<C> {
247247
request_builder = request_builder.headers(headers.clone());
248248
}
249249

250-
if let Some(query) = request_options.query() {
251-
request_builder = request_builder.query(query);
250+
if !request_options.query().is_empty() {
251+
request_builder = request_builder.query(request_options.query());
252252
}
253253

254254
Ok(request_builder.build()?)
@@ -274,8 +274,8 @@ impl<C: Config> Client<C> {
274274
request_builder = request_builder.headers(headers.clone());
275275
}
276276

277-
if let Some(query) = request_options.query() {
278-
request_builder = request_builder.query(query);
277+
if !request_options.query().is_empty() {
278+
request_builder = request_builder.query(request_options.query());
279279
}
280280

281281
Ok(request_builder.build()?)
@@ -306,8 +306,8 @@ impl<C: Config> Client<C> {
306306
request_builder = request_builder.headers(headers.clone());
307307
}
308308

309-
if let Some(query) = request_options.query() {
310-
request_builder = request_builder.query(query);
309+
if !request_options.query().is_empty() {
310+
request_builder = request_builder.query(request_options.query());
311311
}
312312

313313
Ok(request_builder.build()?)
@@ -339,8 +339,8 @@ impl<C: Config> Client<C> {
339339
request_builder = request_builder.headers(headers.clone());
340340
}
341341

342-
if let Some(query) = request_options.query() {
343-
request_builder = request_builder.query(query);
342+
if !request_options.query().is_empty() {
343+
request_builder = request_builder.query(request_options.query());
344344
}
345345

346346
Ok(request_builder.build()?)
@@ -372,8 +372,8 @@ impl<C: Config> Client<C> {
372372
request_builder = request_builder.headers(headers.clone());
373373
}
374374

375-
if let Some(query) = request_options.query() {
376-
request_builder = request_builder.query(query);
375+
if !request_options.query().is_empty() {
376+
request_builder = request_builder.query(request_options.query());
377377
}
378378

379379
Ok(request_builder.build()?)
@@ -406,8 +406,8 @@ impl<C: Config> Client<C> {
406406
request_builder = request_builder.headers(headers.clone());
407407
}
408408

409-
if let Some(query) = request_options.query() {
410-
request_builder = request_builder.query(query);
409+
if !request_options.query().is_empty() {
410+
request_builder = request_builder.query(request_options.query());
411411
}
412412

413413
Ok(request_builder.build()?)
@@ -440,8 +440,8 @@ impl<C: Config> Client<C> {
440440
request_builder = request_builder.headers(headers.clone());
441441
}
442442

443-
if let Some(query) = request_options.query() {
444-
request_builder = request_builder.query(query);
443+
if !request_options.query().is_empty() {
444+
request_builder = request_builder.query(request_options.query());
445445
}
446446

447447
let response = request_builder.send().await.map_err(OpenAIError::Reqwest)?;
@@ -591,8 +591,8 @@ impl<C: Config> Client<C> {
591591
request_builder = request_builder.headers(headers.clone());
592592
}
593593

594-
if let Some(query) = request_options.query() {
595-
request_builder = request_builder.query(query);
594+
if !request_options.query().is_empty() {
595+
request_builder = request_builder.query(request_options.query());
596596
}
597597

598598
let event_source = request_builder.eventsource().unwrap();
@@ -622,8 +622,8 @@ impl<C: Config> Client<C> {
622622
request_builder = request_builder.headers(headers.clone());
623623
}
624624

625-
if let Some(query) = request_options.query() {
626-
request_builder = request_builder.query(query);
625+
if !request_options.query().is_empty() {
626+
request_builder = request_builder.query(request_options.query());
627627
}
628628

629629
let event_source = request_builder.eventsource().unwrap();

async-openai/src/request_options.rs

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
use reqwest::header::HeaderMap;
22
use serde::Serialize;
3+
use url::Url;
34

4-
use crate::error::OpenAIError;
5+
use crate::{config::OPENAI_API_BASE, error::OpenAIError};
56

67
#[derive(Clone, Debug, Default)]
78
pub struct RequestOptions {
8-
query: Option<String>,
9+
query: Vec<(String, String)>,
910
headers: Option<HeaderMap>,
1011
}
1112

1213
impl RequestOptions {
1314
pub(crate) fn new() -> Self {
1415
Self {
15-
query: None,
16+
query: Vec::new(),
1617
headers: None,
1718
}
1819
}
@@ -49,18 +50,31 @@ impl RequestOptions {
4950
&mut self,
5051
query: &Q,
5152
) -> Result<(), OpenAIError> {
52-
let new_query = serde_urlencoded::to_string(query)
53-
.map_err(|e| OpenAIError::InvalidArgument(format!("Invalid query: {}", e)))?;
54-
if let Some(existing_query) = &self.query {
55-
self.query = Some(format!("{}&{}", existing_query, new_query));
56-
} else {
57-
self.query = Some(new_query);
53+
// Use serde_urlencoded::Serializer directly to handle any serializable type
54+
// similar to how reqwest does it. We create a temporary URL to use query_pairs_mut()
55+
// which allows us to handle any serializable type, not just top-level maps/structs.
56+
let mut url = Url::parse(OPENAI_API_BASE)
57+
.map_err(|e| OpenAIError::InvalidArgument(format!("Failed to create URL: {}", e)))?;
58+
59+
{
60+
let mut pairs = url.query_pairs_mut();
61+
let serializer = serde_urlencoded::Serializer::new(&mut pairs);
62+
63+
query
64+
.serialize(serializer)
65+
.map_err(|e| OpenAIError::InvalidArgument(format!("Invalid query: {}", e)))?;
5866
}
67+
68+
// Extract query pairs from the URL and append to our vec
69+
for (key, value) in url.query_pairs() {
70+
self.query.push((key.to_string(), value.to_string()));
71+
}
72+
5973
Ok(())
6074
}
6175

62-
pub(crate) fn query(&self) -> Option<&str> {
63-
self.query.as_deref()
76+
pub(crate) fn query(&self) -> &[(String, String)] {
77+
&self.query
6478
}
6579

6680
pub(crate) fn headers(&self) -> Option<&HeaderMap> {

0 commit comments

Comments
 (0)