Skip to content

Commit 670ee75

Browse files
authored
Merge pull request #158 from hiteshjoshi/main
Optional auth header and with api keys | works with azure
2 parents f9c8d44 + 849664a commit 670ee75

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ features = ["connect"]
4040
[dependencies.futures-util]
4141
version = "0.3.31"
4242
features = ["sink", "std"]
43+
44+
[dependencies.url]
45+
version = "2.5.4"

src/v1/api.rs

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
4343
use reqwest::{Client, Method, Response};
4444
use serde::Serialize;
4545
use serde_json::Value;
46+
use url::Url;
4647

4748
use std::error::Error;
4849
use std::fs::{create_dir_all, File};
@@ -62,9 +63,10 @@ pub struct OpenAIClientBuilder {
6263
headers: Option<HeaderMap>,
6364
}
6465

66+
#[derive(Debug)]
6567
pub struct OpenAIClient {
6668
api_endpoint: String,
67-
api_key: String,
69+
api_key: Option<String>,
6870
organization: Option<String>,
6971
proxy: Option<String>,
7072
timeout: Option<u64>,
@@ -111,14 +113,13 @@ impl OpenAIClientBuilder {
111113
}
112114

113115
pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
114-
let api_key = self.api_key.ok_or("API key is required")?;
115116
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
116117
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
117118
});
118119

119120
Ok(OpenAIClient {
120121
api_endpoint,
121-
api_key,
122+
api_key: self.api_key,
122123
organization: self.organization,
123124
proxy: self.proxy,
124125
timeout: self.timeout,
@@ -133,7 +134,10 @@ impl OpenAIClient {
133134
}
134135

135136
async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
136-
let url = format!("{}/{}", self.api_endpoint, path);
137+
let url = self
138+
.build_url_with_preserved_query(path)
139+
.unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
140+
137141
let client = Client::builder();
138142

139143
#[cfg(feature = "rustls")]
@@ -153,9 +157,11 @@ impl OpenAIClient {
153157

154158
let client = client.build().unwrap();
155159

156-
let mut request = client
157-
.request(method, url)
158-
.header("Authorization", format!("Bearer {}", self.api_key));
160+
let mut request = client.request(method, url);
161+
162+
if let Some(api_key) = &self.api_key {
163+
request = request.header("Authorization", format!("Bearer {}", api_key));
164+
}
159165

160166
if let Some(organization) = &self.organization {
161167
request = request.header("openai-organization", organization);
@@ -775,7 +781,22 @@ impl OpenAIClient {
775781
let url = Self::query_params(limit, None, after, None, "batches".to_string());
776782
self.get(&url).await
777783
}
784+
fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
785+
let (base, query_opt) = match self.api_endpoint.split_once('?') {
786+
Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
787+
None => (self.api_endpoint.trim_end_matches('/'), None),
788+
};
778789

790+
let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
791+
let mut url = Url::parse(&full_path)?;
792+
793+
if let Some(query) = query_opt {
794+
for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
795+
url.query_pairs_mut().append_pair(&k, &v);
796+
}
797+
}
798+
Ok(url.to_string())
799+
}
779800
fn query_params(
780801
limit: Option<i64>,
781802
order: Option<String>,

0 commit comments

Comments
 (0)