Skip to content

Commit 27e197e

Browse files
octogonzDouwe Osinga
andauthored
fix(providers): fall back to configured models when models endpoint fetch fails (#7530)
Signed-off-by: Pete Gonzalez <4673363+octogonz@users.noreply.github.com> Signed-off-by: Douwe Osinga <douwe@squareup.com> Co-authored-by: Pete Gonzalez <octogonz@users.noreply.github.com> Co-authored-by: Douwe Osinga <douwe@squareup.com>
1 parent 4a02379 commit 27e197e

File tree

4 files changed

+139
-43
lines changed

4 files changed

+139
-43
lines changed

crates/goose/src/providers/anthropic.rs

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub struct AnthropicProvider {
5555
model: ModelConfig,
5656
supports_streaming: bool,
5757
name: String,
58+
custom_models: Option<Vec<String>>,
5859
}
5960

6061
impl AnthropicProvider {
@@ -80,6 +81,7 @@ impl AnthropicProvider {
8081
model,
8182
supports_streaming: true,
8283
name: ANTHROPIC_PROVIDER_NAME.to_string(),
84+
custom_models: None,
8385
})
8486
}
8587

@@ -119,11 +121,18 @@ impl AnthropicProvider {
119121
));
120122
}
121123

124+
let custom_models = if !config.models.is_empty() {
125+
Some(config.models.iter().map(|m| m.name.clone()).collect())
126+
} else {
127+
None
128+
};
129+
122130
Ok(Self {
123131
api_client,
124132
model,
125133
supports_streaming,
126134
name: config.name.clone(),
135+
custom_models,
127136
})
128137
}
129138

@@ -139,6 +148,42 @@ impl AnthropicProvider {
139148

140149
headers
141150
}
151+
152+
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
153+
let response = self.api_client.request(None, "v1/models").api_get().await?;
154+
155+
if response.status == StatusCode::NOT_FOUND {
156+
let msg = response
157+
.payload
158+
.as_ref()
159+
.and_then(|p| p.get("error").and_then(|e| e.get("message")))
160+
.and_then(|m| m.as_str())
161+
.unwrap_or("models endpoint not found")
162+
.to_string();
163+
return Err(ProviderError::EndpointNotFound(msg));
164+
}
165+
166+
if response.status != StatusCode::OK {
167+
return Err(map_http_error_to_provider_error(
168+
response.status,
169+
response.payload,
170+
));
171+
}
172+
173+
let json = response.payload.unwrap_or_default();
174+
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
175+
ProviderError::RequestFailed(
176+
"Missing 'data' array in Anthropic models response".to_string(),
177+
)
178+
})?;
179+
180+
let mut models: Vec<String> = arr
181+
.iter()
182+
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
183+
.collect();
184+
models.sort();
185+
Ok(models)
186+
}
142187
}
143188

144189
impl ProviderDef for AnthropicProvider {
@@ -194,28 +239,22 @@ impl Provider for AnthropicProvider {
194239
}
195240

196241
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
197-
let response = self.api_client.request(None, "v1/models").api_get().await?;
198-
199-
if response.status != StatusCode::OK {
200-
return Err(map_http_error_to_provider_error(
201-
response.status,
202-
response.payload,
203-
));
242+
if let Some(custom_models) = &self.custom_models {
243+
match self.fetch_models_from_api().await {
244+
Ok(models) => return Ok(models),
245+
Err(e) if e.is_endpoint_not_found() => {
246+
tracing::debug!(
247+
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
248+
self.name,
249+
e
250+
);
251+
return Ok(custom_models.clone());
252+
}
253+
Err(e) => return Err(e),
254+
}
204255
}
205256

206-
let json = response.payload.unwrap_or_default();
207-
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
208-
ProviderError::RequestFailed(
209-
"Missing 'data' array in Anthropic models response".to_string(),
210-
)
211-
})?;
212-
213-
let mut models: Vec<String> = arr
214-
.iter()
215-
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
216-
.collect();
217-
models.sort();
218-
Ok(models)
257+
self.fetch_models_from_api().await
219258
}
220259

221260
async fn stream(

crates/goose/src/providers/errors.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub enum ProviderError {
3434
#[error("Unsupported operation: {0}")]
3535
NotImplemented(String),
3636

37+
#[error("Endpoint not found (404): {0}")]
38+
EndpointNotFound(String),
39+
3740
#[error("Credits exhausted: {details}")]
3841
CreditsExhausted {
3942
details: String,
@@ -53,9 +56,14 @@ impl ProviderError {
5356
ProviderError::ExecutionError(_) => "execution",
5457
ProviderError::UsageError(_) => "usage",
5558
ProviderError::NotImplemented(_) => "not_implemented",
59+
ProviderError::EndpointNotFound(_) => "endpoint_not_found",
5660
ProviderError::CreditsExhausted { .. } => "credits_exhausted",
5761
}
5862
}
63+
64+
pub fn is_endpoint_not_found(&self) -> bool {
65+
matches!(self, ProviderError::EndpointNotFound(_))
66+
}
5967
}
6068

6169
fn is_network_error(err: &reqwest::Error) -> bool {

crates/goose/src/providers/openai.rs

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub struct OpenAiProvider {
6565
custom_headers: Option<HashMap<String, String>>,
6666
supports_streaming: bool,
6767
name: String,
68+
custom_models: Option<Vec<String>>,
6869
skip_canonical_filtering: bool,
6970
}
7071

@@ -127,6 +128,7 @@ impl OpenAiProvider {
127128
custom_headers,
128129
supports_streaming: true,
129130
name: OPEN_AI_PROVIDER_NAME.to_string(),
131+
custom_models: None,
130132
skip_canonical_filtering: false,
131133
})
132134
}
@@ -142,6 +144,7 @@ impl OpenAiProvider {
142144
custom_headers: None,
143145
supports_streaming: true,
144146
name: OPEN_AI_PROVIDER_NAME.to_string(),
147+
custom_models: None,
145148
skip_canonical_filtering: false,
146149
}
147150
}
@@ -212,6 +215,12 @@ impl OpenAiProvider {
212215
api_client = api_client.with_headers(header_map)?;
213216
}
214217

218+
let custom_models = if !config.models.is_empty() {
219+
Some(config.models.iter().map(|m| m.name.clone()).collect())
220+
} else {
221+
None
222+
};
223+
215224
Ok(Self {
216225
api_client,
217226
base_path,
@@ -221,6 +230,7 @@ impl OpenAiProvider {
221230
custom_headers: config.headers,
222231
supports_streaming: config.supports_streaming.unwrap_or(true),
223232
name: config.name.clone(),
233+
custom_models,
224234
skip_canonical_filtering: config.skip_canonical_filtering,
225235
})
226236
}
@@ -314,6 +324,40 @@ impl OpenAiProvider {
314324
fallback.to_string()
315325
}
316326
}
327+
328+
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
329+
let models_path =
330+
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
331+
let response = self
332+
.api_client
333+
.request(None, &models_path)
334+
.response_get()
335+
.await?;
336+
337+
if response.status() == StatusCode::NOT_FOUND {
338+
let body = response.text().await.unwrap_or_default();
339+
return Err(ProviderError::EndpointNotFound(body));
340+
}
341+
342+
let json = handle_response_openai_compat(response).await?;
343+
if let Some(err_obj) = json.get("error") {
344+
let msg = err_obj
345+
.get("message")
346+
.and_then(|v| v.as_str())
347+
.unwrap_or("unknown error");
348+
return Err(ProviderError::Authentication(msg.to_string()));
349+
}
350+
351+
let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
352+
ProviderError::UsageError("Missing data field in JSON response".into())
353+
})?;
354+
let mut models: Vec<String> = data
355+
.iter()
356+
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
357+
.collect();
358+
models.sort();
359+
Ok(models)
360+
}
317361
}
318362

319363
impl ProviderDef for OpenAiProvider {
@@ -384,31 +428,22 @@ impl Provider for OpenAiProvider {
384428
}
385429

386430
async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
387-
let models_path =
388-
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
389-
let response = self
390-
.api_client
391-
.request(None, &models_path)
392-
.response_get()
393-
.await?;
394-
let json = handle_response_openai_compat(response).await?;
395-
if let Some(err_obj) = json.get("error") {
396-
let msg = err_obj
397-
.get("message")
398-
.and_then(|v| v.as_str())
399-
.unwrap_or("unknown error");
400-
return Err(ProviderError::Authentication(msg.to_string()));
431+
if let Some(custom_models) = &self.custom_models {
432+
match self.fetch_models_from_api().await {
433+
Ok(models) => return Ok(models),
434+
Err(e) if e.is_endpoint_not_found() => {
435+
tracing::debug!(
436+
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
437+
self.name,
438+
e
439+
);
440+
return Ok(custom_models.clone());
441+
}
442+
Err(e) => return Err(e),
443+
}
401444
}
402445

403-
let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
404-
ProviderError::UsageError("Missing data field in JSON response".into())
405-
})?;
406-
let mut models: Vec<String> = data
407-
.iter()
408-
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
409-
.collect();
410-
models.sort();
411-
Ok(models)
446+
self.fetch_models_from_api().await
412447
}
413448

414449
fn supports_embeddings(&self) -> bool {
@@ -635,6 +670,7 @@ mod tests {
635670
custom_headers: None,
636671
supports_streaming: true,
637672
name: name.to_string(),
673+
custom_models: None,
638674
skip_canonical_filtering: false,
639675
}
640676
}

crates/goose/src/providers/openai_compatible.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,18 @@ mod tests {
299299
"ServerError"
300300
; "500 server error"
301301
)]
302+
#[test_case(
303+
StatusCode::NOT_FOUND,
304+
None,
305+
"RequestFailed"
306+
; "404 not found"
307+
)]
308+
#[test_case(
309+
StatusCode::NOT_FOUND,
310+
Some(json!({"error": {"message": "model not available"}})),
311+
"RequestFailed"
312+
; "404 with error payload"
313+
)]
302314
fn http_status_maps_to_expected_error(
303315
status: StatusCode,
304316
payload: Option<Value>,
@@ -312,6 +324,7 @@ mod tests {
312324
"Authentication" => "auth",
313325
"ContextLengthExceeded" => "context_length",
314326
"ServerError" => "server",
327+
"RequestFailed" => "request",
315328
other => panic!("Unknown variant: {other}"),
316329
};
317330
assert_eq!(

0 commit comments

Comments
 (0)