@@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
43
43
use reqwest:: { Client , Method , Response } ;
44
44
use serde:: Serialize ;
45
45
use serde_json:: Value ;
46
+ use url:: Url ;
46
47
47
48
use std:: error:: Error ;
48
49
use std:: fs:: { create_dir_all, File } ;
@@ -62,9 +63,10 @@ pub struct OpenAIClientBuilder {
62
63
headers : Option < HeaderMap > ,
63
64
}
64
65
66
+ #[ derive( Debug ) ]
65
67
pub struct OpenAIClient {
66
68
api_endpoint : String ,
67
- api_key : String ,
69
+ api_key : Option < String > ,
68
70
organization : Option < String > ,
69
71
proxy : Option < String > ,
70
72
timeout : Option < u64 > ,
@@ -111,14 +113,13 @@ impl OpenAIClientBuilder {
111
113
}
112
114
113
115
pub fn build ( self ) -> Result < OpenAIClient , Box < dyn Error > > {
114
- let api_key = self . api_key . ok_or ( "API key is required" ) ?;
115
116
let api_endpoint = self . api_endpoint . unwrap_or_else ( || {
116
117
std:: env:: var ( "OPENAI_API_BASE" ) . unwrap_or_else ( |_| API_URL_V1 . to_owned ( ) )
117
118
} ) ;
118
119
119
120
Ok ( OpenAIClient {
120
121
api_endpoint,
121
- api_key,
122
+ api_key : self . api_key ,
122
123
organization : self . organization ,
123
124
proxy : self . proxy ,
124
125
timeout : self . timeout ,
@@ -133,7 +134,10 @@ impl OpenAIClient {
133
134
}
134
135
135
136
async fn build_request ( & self , method : Method , path : & str ) -> reqwest:: RequestBuilder {
136
- let url = format ! ( "{}/{}" , self . api_endpoint, path) ;
137
+ let url = self
138
+ . build_url_with_preserved_query ( path)
139
+ . unwrap_or_else ( |_| format ! ( "{}/{}" , self . api_endpoint, path) ) ;
140
+
137
141
let client = Client :: builder ( ) ;
138
142
139
143
#[ cfg( feature = "rustls" ) ]
@@ -153,9 +157,11 @@ impl OpenAIClient {
153
157
154
158
let client = client. build ( ) . unwrap ( ) ;
155
159
156
- let mut request = client
157
- . request ( method, url)
158
- . header ( "Authorization" , format ! ( "Bearer {}" , self . api_key) ) ;
160
+ let mut request = client. request ( method, url) ;
161
+
162
+ if let Some ( api_key) = & self . api_key {
163
+ request = request. header ( "Authorization" , format ! ( "Bearer {}" , api_key) ) ;
164
+ }
159
165
160
166
if let Some ( organization) = & self . organization {
161
167
request = request. header ( "openai-organization" , organization) ;
@@ -775,7 +781,22 @@ impl OpenAIClient {
775
781
let url = Self :: query_params ( limit, None , after, None , "batches" . to_string ( ) ) ;
776
782
self . get ( & url) . await
777
783
}
784
+ fn build_url_with_preserved_query ( & self , path : & str ) -> Result < String , url:: ParseError > {
785
+ let ( base, query_opt) = match self . api_endpoint . split_once ( '?' ) {
786
+ Some ( ( b, q) ) => ( b. trim_end_matches ( '/' ) , Some ( q) ) ,
787
+ None => ( self . api_endpoint . trim_end_matches ( '/' ) , None ) ,
788
+ } ;
778
789
790
+ let full_path = format ! ( "{}/{}" , base, path. trim_start_matches( '/' ) ) ;
791
+ let mut url = Url :: parse ( & full_path) ?;
792
+
793
+ if let Some ( query) = query_opt {
794
+ for ( k, v) in url:: form_urlencoded:: parse ( query. as_bytes ( ) ) {
795
+ url. query_pairs_mut ( ) . append_pair ( & k, & v) ;
796
+ }
797
+ }
798
+ Ok ( url. to_string ( ) )
799
+ }
779
800
fn query_params (
780
801
limit : Option < i64 > ,
781
802
order : Option < String > ,
0 commit comments