Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 56 additions & 20 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pub struct AnthropicProvider {
model: ModelConfig,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl AnthropicProvider {
Expand All @@ -79,6 +80,7 @@ impl AnthropicProvider {
model,
supports_streaming: true,
name: ANTHROPIC_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand Down Expand Up @@ -118,11 +120,19 @@ impl AnthropicProvider {
));
}

// Extract custom models from config if available
let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
model,
supports_streaming,
name: config.name.clone(),
custom_models,
})
}

Expand All @@ -138,6 +148,31 @@ impl AnthropicProvider {

headers
}

/// Query Anthropic's `v1/models` endpoint and return the available model
/// ids, sorted alphabetically.
///
/// Returns a `ProviderError` mapped from the HTTP status on any non-200
/// response, or `RequestFailed` when the payload lacks a `data` array.
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
    let response = self.api_client.request(None, "v1/models").api_get().await?;

    if response.status != StatusCode::OK {
        return Err(map_http_error_to_provider_error(
            response.status,
            response.payload,
        ));
    }

    let payload = response.payload.unwrap_or_default();
    let entries = match payload.get("data").and_then(|value| value.as_array()) {
        Some(entries) => entries,
        None => {
            return Err(ProviderError::RequestFailed(
                "Missing 'data' array in Anthropic models response".to_string(),
            ))
        }
    };

    // Collect every entry that carries a string "id"; skip malformed ones.
    let mut names = Vec::with_capacity(entries.len());
    for entry in entries {
        if let Some(id) = entry.get("id").and_then(|value| value.as_str()) {
            names.push(id.to_string());
        }
    }
    names.sort();
    Ok(names)
}
}

impl ProviderDef for AnthropicProvider {
Expand Down Expand Up @@ -188,28 +223,29 @@ impl Provider for AnthropicProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let response = self.api_client.request(None, "v1/models").api_get().await?;

if response.status != StatusCode::OK {
return Err(map_http_error_to_provider_error(
response.status,
response.payload,
));
// If custom models are defined, try API first but fallback to them only if endpoint doesn't exist
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) => {
// Only fall back for endpoint-not-implemented errors (404, connection failures)
// Auth errors, rate limits, and server errors should propagate
if e.is_endpoint_not_implemented() {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
// Otherwise, propagate the error to preserve diagnostics
return Err(e);
}
}
}

let json = response.payload.unwrap_or_default();
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in Anthropic models response".to_string(),
)
})?;

let mut models: Vec<String> = arr
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
// No custom models defined, must succeed with API call
self.fetch_models_from_api().await
}

async fn stream(
Expand Down
8 changes: 8 additions & 0 deletions crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ impl ProviderError {
ProviderError::CreditsExhausted { .. } => "credits_exhausted",
}
}

/// Returns true if this error indicates the models endpoint is not implemented (404).
/// Returns true if this error indicates the models endpoint is not
/// implemented by the provider (a 404 / "not found" style failure).
///
/// Callers use this to decide whether falling back to a predefined model
/// list is appropriate; auth, rate-limit, and server errors return false
/// so they keep propagating.
pub fn is_endpoint_not_implemented(&self) -> bool {
    match self {
        // Compare case-insensitively: HTTP reason phrases are typically
        // spelled "Not Found", which a case-sensitive `contains("not found")`
        // would miss.
        ProviderError::RequestFailed(msg) => {
            let msg = msg.to_lowercase();
            msg.contains("404") || msg.contains("not found")
        }
        _ => false,
    }
}
}

impl From<anyhow::Error> for ProviderError {
Expand Down
83 changes: 60 additions & 23 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub struct OpenAiProvider {
custom_headers: Option<HashMap<String, String>>,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl OpenAiProvider {
Expand Down Expand Up @@ -126,6 +127,7 @@ impl OpenAiProvider {
custom_headers,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand All @@ -140,6 +142,7 @@ impl OpenAiProvider {
custom_headers: None,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
}
}

Expand Down Expand Up @@ -195,6 +198,13 @@ impl OpenAiProvider {
api_client = api_client.with_headers(header_map)?;
}

// Extract custom models from config if available
let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
base_path,
Expand All @@ -204,6 +214,7 @@ impl OpenAiProvider {
custom_headers: config.headers,
supports_streaming: config.supports_streaming.unwrap_or(true),
name: config.name.clone(),
custom_models,
})
}

Expand Down Expand Up @@ -267,6 +278,34 @@ impl OpenAiProvider {
fallback.to_string()
}
}

/// Query the provider's models endpoint (resolved relative to the configured
/// base path) and return the available model ids, sorted alphabetically.
///
/// OpenAI-compatible servers may report failures inside an `"error"` object
/// even on a 200 response; that is surfaced as an `Authentication` error.
/// A missing `data` field is reported as a `UsageError`.
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
    let path = Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
    let raw = self.api_client.request(None, &path).response_get().await?;
    let body = handle_response_openai_compat(raw).await?;

    if let Some(error) = body.get("error") {
        let message = error
            .get("message")
            .and_then(|value| value.as_str())
            .unwrap_or("unknown error");
        return Err(ProviderError::Authentication(message.to_string()));
    }

    let listing = body
        .get("data")
        .and_then(|value| value.as_array())
        .ok_or_else(|| ProviderError::UsageError("Missing data field in JSON response".into()))?;

    // Keep only entries with a string "id"; ignore malformed ones.
    let mut ids: Vec<String> = listing
        .iter()
        .filter_map(|entry| entry.get("id").and_then(|value| value.as_str()))
        .map(str::to_string)
        .collect();
    ids.sort();
    Ok(ids)
}
}

impl ProviderDef for OpenAiProvider {
Expand Down Expand Up @@ -327,31 +366,29 @@ impl Provider for OpenAiProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let models_path =
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
let response = self
.api_client
.request(None, &models_path)
.response_get()
.await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
// If custom models are defined, try API first but fallback to them only if endpoint doesn't exist
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) => {
// Only fall back for endpoint-not-implemented errors (404, connection failures)
// Auth errors, rate limits, and server errors should propagate
if e.is_endpoint_not_implemented() {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
// Otherwise, propagate the error to preserve diagnostics
return Err(e);
}
}
}

let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let mut models: Vec<String> = data
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
// No custom models defined, must succeed with API call
self.fetch_models_from_api().await
}

fn supports_embeddings(&self) -> bool {
Expand Down