@@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
66use rand_distr:: Distribution ;
77use rayon:: iter:: split;
88use rayon:: prelude:: * ;
9+ use reqwest:: Url ;
910use reqwest_eventsource:: { Error , Event , EventSource } ;
1011use serde:: { Deserialize , Serialize } ;
1112use std:: cmp:: Ordering ;
@@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
5859#[ derive( Debug , Clone ) ]
5960pub struct OpenAITextGenerationBackend {
6061 pub api_key : String ,
61- pub base_url : String ,
62+ pub base_url : Url ,
6263 pub model_name : String ,
6364 pub client : reqwest:: Client ,
6465 pub tokenizer : Arc < Tokenizer > ,
@@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
101102impl OpenAITextGenerationBackend {
102103 pub fn try_new (
103104 api_key : String ,
104- base_url : String ,
105+ base_url : Url ,
105106 model_name : String ,
106107 tokenizer : Arc < Tokenizer > ,
107108 timeout : time:: Duration ,
@@ -128,7 +129,9 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
128129 request : Arc < TextGenerationRequest > ,
129130 sender : Sender < TextGenerationAggregatedResponse > ,
130131 ) {
131- let url = format ! ( "{base_url}/v1/chat/completions" , base_url = self . base_url) ;
132+ let mut url = self . base_url . clone ( ) ;
133+ url. set_path ( "/v1/chat/completions" ) ;
134+ // let url = format!("{base_url}", base_url = self.base_url);
132135 let mut aggregated_response = TextGenerationAggregatedResponse :: new ( request. clone ( ) ) ;
133136 let messages = vec ! [ OpenAITextGenerationMessage {
134137 role: "user" . to_string( ) ,
@@ -829,7 +832,7 @@ mod tests {
829832 w. write_all ( b"data: [DONE]\n \n " )
830833 } )
831834 . create_async ( ) . await ;
832- let url = s. url ( ) ;
835+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
833836 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
834837 let backend = OpenAITextGenerationBackend :: try_new (
835838 "" . to_string ( ) ,
@@ -890,7 +893,7 @@ mod tests {
890893 w. write_all ( b"data: [DONE]\n \n " )
891894 } )
892895 . create_async ( ) . await ;
893- let url = s. url ( ) ;
896+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
894897 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
895898 let backend = OpenAITextGenerationBackend :: try_new (
896899 "" . to_string ( ) ,
@@ -975,7 +978,7 @@ mod tests {
975978 . with_chunked_body ( |w| w. write_all ( b"data: {\" error\" : \" Internal server error\" }\n \n " ) )
976979 . create_async ( )
977980 . await ;
978- let url = s. url ( ) ;
981+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
979982 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
980983 let backend = OpenAITextGenerationBackend :: try_new (
981984 "" . to_string ( ) ,
@@ -1021,7 +1024,7 @@ mod tests {
10211024 . with_chunked_body ( |w| w. write_all ( b"this is wrong\n \n " ) )
10221025 . create_async ( )
10231026 . await ;
1024- let url = s. url ( ) ;
1027+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
10251028 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
10261029 let backend = OpenAITextGenerationBackend :: try_new (
10271030 "" . to_string ( ) ,
@@ -1067,7 +1070,7 @@ mod tests {
10671070 . with_chunked_body ( |w| w. write_all ( b"data: {\" foo\" : \" bar\" }\n \n " ) )
10681071 . create_async ( )
10691072 . await ;
1070- let url = s. url ( ) ;
1073+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
10711074 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
10721075 let backend = OpenAITextGenerationBackend :: try_new (
10731076 "" . to_string ( ) ,
@@ -1117,7 +1120,7 @@ mod tests {
11171120 w. write_all ( b"data: [DONE]\n \n " )
11181121 } )
11191122 . create_async ( ) . await ;
1120- let url = s. url ( ) ;
1123+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
11211124 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
11221125 let backend = OpenAITextGenerationBackend :: try_new (
11231126 "" . to_string ( ) ,
0 commit comments