@@ -60,13 +60,12 @@ impl_provider_default!(OpenAiProvider);
6060
6161impl OpenAiProvider {
6262 pub fn from_env ( model : ModelConfig ) -> Result < Self > {
63- let model = model. with_fast ( OPEN_AI_DEFAULT_FAST_MODEL . to_string ( ) ) ;
64-
6563 let config = crate :: config:: Config :: global ( ) ;
6664 let api_key: String = config. get_secret ( "OPENAI_API_KEY" ) ?;
6765 let host: String = config
6866 . get_param ( "OPENAI_HOST" )
6967 . unwrap_or_else ( |_| "https://api.openai.com" . to_string ( ) ) ;
68+
7069 let base_path: String = config
7170 . get_param ( "OPENAI_BASE_PATH" )
7271 . unwrap_or_else ( |_| "v1/chat/completions" . to_string ( ) ) ;
@@ -80,8 +79,11 @@ impl OpenAiProvider {
8079 let timeout_secs: u64 = config. get_param ( "OPENAI_TIMEOUT" ) . unwrap_or ( 600 ) ;
8180
8281 let auth = AuthMethod :: BearerToken ( api_key) ;
83- let mut api_client =
84- ApiClient :: with_timeout ( host, auth, std:: time:: Duration :: from_secs ( timeout_secs) ) ?;
82+ let mut api_client = ApiClient :: with_timeout (
83+ host. clone ( ) ,
84+ auth,
85+ std:: time:: Duration :: from_secs ( timeout_secs) ,
86+ ) ?;
8587
8688 if let Some ( org) = & organization {
8789 api_client = api_client. with_header ( "OpenAI-Organization" , org) ?;
@@ -101,15 +103,44 @@ impl OpenAiProvider {
101103 api_client = api_client. with_headers ( header_map) ?;
102104 }
103105
104- Ok ( Self {
106+ let mut provider = Self {
105107 api_client,
106108 base_path,
107109 organization,
108110 project,
109- model,
111+ model : model . clone ( ) ,
110112 custom_headers,
111113 supports_streaming : true ,
112- } )
114+ } ;
115+
116+ let model_with_fast = tokio:: task:: block_in_place ( || {
117+ tokio:: runtime:: Handle :: current ( ) . block_on ( async {
118+ if let Ok ( Some ( models) ) = provider. fetch_supported_models ( ) . await {
119+ if models. contains ( & OPEN_AI_DEFAULT_FAST_MODEL . to_string ( ) ) {
120+ tracing:: debug!(
121+ "Found {} in OpenAI workspace, setting as fast model" ,
122+ OPEN_AI_DEFAULT_FAST_MODEL
123+ ) ;
124+ provider
125+ . model
126+ . clone ( )
127+ . with_fast ( OPEN_AI_DEFAULT_FAST_MODEL . to_string ( ) )
128+ } else {
129+ tracing:: debug!(
130+ "{} not found in OpenAI workspace, not setting fast model" ,
131+ OPEN_AI_DEFAULT_FAST_MODEL
132+ ) ;
133+ provider. model . clone ( )
134+ }
135+ } else {
136+ tracing:: debug!( "Could not fetch OpenAI models, not setting fast model" ) ;
137+ provider. model . clone ( )
138+ }
139+ } )
140+ } ) ;
141+
142+ provider. model = model_with_fast;
143+ Ok ( provider)
113144 }
114145
115146 pub fn from_custom_config ( model : ModelConfig , config : CustomProviderConfig ) -> Result < Self > {
0 commit comments