Skip to content

Commit 059cec9

Browse files
committed
fix: openai: add check for fast model availability
OpenAI-compatible providers might not support gpt-4o-mini. Signed-off-by: Shawn Wang <shawn111@gmail.com>
1 parent 5a954d4 commit 059cec9

File tree

1 file changed

+38
-7
lines changed

1 file changed

+38
-7
lines changed

crates/goose/src/providers/openai.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,12 @@ impl_provider_default!(OpenAiProvider);
6060

6161
impl OpenAiProvider {
6262
pub fn from_env(model: ModelConfig) -> Result<Self> {
63-
let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string());
64-
6563
let config = crate::config::Config::global();
6664
let api_key: String = config.get_secret("OPENAI_API_KEY")?;
6765
let host: String = config
6866
.get_param("OPENAI_HOST")
6967
.unwrap_or_else(|_| "https://api.openai.com".to_string());
68+
7069
let base_path: String = config
7170
.get_param("OPENAI_BASE_PATH")
7271
.unwrap_or_else(|_| "v1/chat/completions".to_string());
@@ -80,8 +79,11 @@ impl OpenAiProvider {
8079
let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);
8180

8281
let auth = AuthMethod::BearerToken(api_key);
83-
let mut api_client =
84-
ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
82+
let mut api_client = ApiClient::with_timeout(
83+
host.clone(),
84+
auth,
85+
std::time::Duration::from_secs(timeout_secs),
86+
)?;
8587

8688
if let Some(org) = &organization {
8789
api_client = api_client.with_header("OpenAI-Organization", org)?;
@@ -101,15 +103,44 @@ impl OpenAiProvider {
101103
api_client = api_client.with_headers(header_map)?;
102104
}
103105

104-
Ok(Self {
106+
let mut provider = Self {
105107
api_client,
106108
base_path,
107109
organization,
108110
project,
109-
model,
111+
model: model.clone(),
110112
custom_headers,
111113
supports_streaming: true,
112-
})
114+
};
115+
116+
let model_with_fast = tokio::task::block_in_place(|| {
117+
tokio::runtime::Handle::current().block_on(async {
118+
if let Ok(Some(models)) = provider.fetch_supported_models().await {
119+
if models.contains(&OPEN_AI_DEFAULT_FAST_MODEL.to_string()) {
120+
tracing::debug!(
121+
"Found {} in OpenAI workspace, setting as fast model",
122+
OPEN_AI_DEFAULT_FAST_MODEL
123+
);
124+
provider
125+
.model
126+
.clone()
127+
.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string())
128+
} else {
129+
tracing::debug!(
130+
"{} not found in OpenAI workspace, not setting fast model",
131+
OPEN_AI_DEFAULT_FAST_MODEL
132+
);
133+
provider.model.clone()
134+
}
135+
} else {
136+
tracing::debug!("Could not fetch OpenAI models, not setting fast model");
137+
provider.model.clone()
138+
}
139+
})
140+
});
141+
142+
provider.model = model_with_fast;
143+
Ok(provider)
113144
}
114145

115146
pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result<Self> {

0 commit comments

Comments
 (0)