Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 56 additions & 20 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub struct AnthropicProvider {
model: ModelConfig,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl AnthropicProvider {
Expand All @@ -79,6 +80,7 @@ impl AnthropicProvider {
model,
supports_streaming: true,
name: ANTHROPIC_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand Down Expand Up @@ -118,11 +120,19 @@ impl AnthropicProvider {
));
}

// Extract custom models from config if available
let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
model,
supports_streaming,
name: config.name.clone(),
custom_models,
})
}

Expand All @@ -138,6 +148,31 @@ impl AnthropicProvider {

headers
}

async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
let response = self.api_client.request(None, "v1/models").api_get().await?;

if response.status != StatusCode::OK {
return Err(map_http_error_to_provider_error(
response.status,
response.payload,
));
}

let json = response.payload.unwrap_or_default();
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in Anthropic models response".to_string(),
)
})?;

let mut models: Vec<String> = arr
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
}
}

impl ProviderDef for AnthropicProvider {
Expand Down Expand Up @@ -188,28 +223,29 @@ impl Provider for AnthropicProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let response = self.api_client.request(None, "v1/models").api_get().await?;

if response.status != StatusCode::OK {
return Err(map_http_error_to_provider_error(
response.status,
response.payload,
));
// If custom models are defined, try API first but fallback to them only if endpoint doesn't exist
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) => {
// Only fall back for endpoint-not-implemented errors (404, connection failures)
// Auth errors, rate limits, and server errors should propagate
if e.is_endpoint_not_implemented() {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
// Otherwise, propagate the error to preserve diagnostics
return Err(e);
}
}
}

let json = response.payload.unwrap_or_default();
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in Anthropic models response".to_string(),
)
})?;

let mut models: Vec<String> = arr
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
// No custom models defined, must succeed with API call
self.fetch_models_from_api().await
}

async fn stream(
Expand Down
31 changes: 31 additions & 0 deletions crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,37 @@ impl ProviderError {
ProviderError::CreditsExhausted { .. } => "credits_exhausted",
}
}

/// Returns true if this error indicates the models endpoint is not implemented
/// and we should fall back to configured custom models.
///
/// Only certain errors should trigger fallback - specifically 404 (endpoint not found)
/// or connection failures that suggest the endpoint doesn't exist.
///
/// Critical errors that should NOT fallback:
/// - Authentication failures (401/403) - indicates misconfigured credentials
/// - Rate limits (429) - indicates service is functioning but rate-limited
/// - Server errors (5xx) - indicates service is having issues
/// - Context length exceeded - indicates API is working
pub fn is_endpoint_not_implemented(&self) -> bool {
match self {
// 404 or connection failures suggest endpoint doesn't exist - safe to fallback
ProviderError::RequestFailed(msg) => {
msg.contains("404")
|| msg.contains("not found")
|| msg.contains("connection failed")
|| msg.contains("failed to connect")
Comment thread
octogonz marked this conversation as resolved.
Outdated
Comment thread
octogonz marked this conversation as resolved.
Outdated
}
// Auth, rate limit, server, context errors - do NOT fallback
ProviderError::Authentication(_) => false,
ProviderError::RateLimitExceeded { .. } => false,
ProviderError::ServerError(_) => false,
ProviderError::ContextLengthExceeded(_) => false,
ProviderError::CreditsExhausted { .. } => false,
// Other errors may indicate the endpoint exists - safer to not fallback
_ => false,
}
}
}

impl From<anyhow::Error> for ProviderError {
Expand Down
83 changes: 60 additions & 23 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub struct OpenAiProvider {
custom_headers: Option<HashMap<String, String>>,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl OpenAiProvider {
Expand Down Expand Up @@ -126,6 +127,7 @@ impl OpenAiProvider {
custom_headers,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand All @@ -140,6 +142,7 @@ impl OpenAiProvider {
custom_headers: None,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
}
}

Expand Down Expand Up @@ -195,6 +198,13 @@ impl OpenAiProvider {
api_client = api_client.with_headers(header_map)?;
}

// Extract custom models from config if available
let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
base_path,
Expand All @@ -204,6 +214,7 @@ impl OpenAiProvider {
custom_headers: config.headers,
supports_streaming: config.supports_streaming.unwrap_or(true),
name: config.name.clone(),
custom_models,
})
}

Expand Down Expand Up @@ -267,6 +278,34 @@ impl OpenAiProvider {
fallback.to_string()
}
}

async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
let models_path =
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
let response = self
.api_client
.request(None, &models_path)
.response_get()
.await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
}

let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let mut models: Vec<String> = data
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
}
}

impl ProviderDef for OpenAiProvider {
Expand Down Expand Up @@ -327,31 +366,29 @@ impl Provider for OpenAiProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let models_path =
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
let response = self
.api_client
.request(None, &models_path)
.response_get()
.await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
// If custom models are defined, try API first but fallback to them only if endpoint doesn't exist
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) => {
// Only fall back for endpoint-not-implemented errors (404, connection failures)
// Auth errors, rate limits, and server errors should propagate
if e.is_endpoint_not_implemented() {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
// Otherwise, propagate the error to preserve diagnostics
return Err(e);
}
}
}

let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let mut models: Vec<String> = data
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
// No custom models defined, must succeed with API call
self.fetch_models_from_api().await
}

fn supports_embeddings(&self) -> bool {
Expand Down