diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs
index 21338a732b18..a446797d87c9 100644
--- a/crates/goose/src/providers/anthropic.rs
+++ b/crates/goose/src/providers/anthropic.rs
@@ -55,6 +55,7 @@ pub struct AnthropicProvider {
     model: ModelConfig,
     supports_streaming: bool,
     name: String,
+    custom_models: Option<Vec<String>>,
 }

 impl AnthropicProvider {
@@ -80,6 +81,7 @@ impl AnthropicProvider {
             model,
             supports_streaming: true,
             name: ANTHROPIC_PROVIDER_NAME.to_string(),
+            custom_models: None,
         })
     }

@@ -119,11 +121,18 @@ impl AnthropicProvider {
             ));
         }

+        let custom_models = if !config.models.is_empty() {
+            Some(config.models.iter().map(|m| m.name.clone()).collect())
+        } else {
+            None
+        };
+
         Ok(Self {
             api_client,
             model,
             supports_streaming,
             name: config.name.clone(),
+            custom_models,
         })
     }

@@ -139,6 +148,42 @@ impl AnthropicProvider {
         headers
     }
+
+    async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
+        let response = self.api_client.request(None, "v1/models").api_get().await?;
+
+        if response.status == StatusCode::NOT_FOUND {
+            let msg = response
+                .payload
+                .as_ref()
+                .and_then(|p| p.get("error").and_then(|e| e.get("message")))
+                .and_then(|m| m.as_str())
+                .unwrap_or("models endpoint not found")
+                .to_string();
+            return Err(ProviderError::EndpointNotFound(msg));
+        }
+
+        if response.status != StatusCode::OK {
+            return Err(map_http_error_to_provider_error(
+                response.status,
+                response.payload,
+            ));
+        }
+
+        let json = response.payload.unwrap_or_default();
+        let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
+            ProviderError::RequestFailed(
+                "Missing 'data' array in Anthropic models response".to_string(),
+            )
+        })?;
+
+        let mut models: Vec<String> = arr
+            .iter()
+            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
+            .collect();
+        models.sort();
+        Ok(models)
+    }
 }

 impl ProviderDef for AnthropicProvider {
@@ -194,28 +239,22 @@ impl Provider for AnthropicProvider {
     }

     async fn
fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
-        let response = self.api_client.request(None, "v1/models").api_get().await?;
-
-        if response.status != StatusCode::OK {
-            return Err(map_http_error_to_provider_error(
-                response.status,
-                response.payload,
-            ));
+        if let Some(custom_models) = &self.custom_models {
+            match self.fetch_models_from_api().await {
+                Ok(models) => return Ok(models),
+                Err(e) if e.is_endpoint_not_found() => {
+                    tracing::debug!(
+                        "Models endpoint not implemented for provider '{}' ({}), using predefined list",
+                        self.name,
+                        e
+                    );
+                    return Ok(custom_models.clone());
+                }
+                Err(e) => return Err(e),
+            }
         }

-        let json = response.payload.unwrap_or_default();
-        let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
-            ProviderError::RequestFailed(
-                "Missing 'data' array in Anthropic models response".to_string(),
-            )
-        })?;
-
-        let mut models: Vec<String> = arr
-            .iter()
-            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
-            .collect();
-        models.sort();
-        Ok(models)
+        self.fetch_models_from_api().await
     }

     async fn stream(
diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs
index 214a14837d15..46a2aa4ca051 100644
--- a/crates/goose/src/providers/errors.rs
+++ b/crates/goose/src/providers/errors.rs
@@ -34,6 +34,9 @@ pub enum ProviderError {
     #[error("Unsupported operation: {0}")]
     NotImplemented(String),

+    #[error("Endpoint not found (404): {0}")]
+    EndpointNotFound(String),
+
     #[error("Credits exhausted: {details}")]
     CreditsExhausted {
         details: String,
@@ -53,9 +56,14 @@ impl ProviderError {
             ProviderError::ExecutionError(_) => "execution",
             ProviderError::UsageError(_) => "usage",
             ProviderError::NotImplemented(_) => "not_implemented",
+            ProviderError::EndpointNotFound(_) => "endpoint_not_found",
             ProviderError::CreditsExhausted { ..
} => "credits_exhausted",
         }
     }
+
+    pub fn is_endpoint_not_found(&self) -> bool {
+        matches!(self, ProviderError::EndpointNotFound(_))
+    }
 }

 fn is_network_error(err: &reqwest::Error) -> bool {
diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs
index 5668d0df6a0a..d7d26177f0c2 100644
--- a/crates/goose/src/providers/openai.rs
+++ b/crates/goose/src/providers/openai.rs
@@ -65,6 +65,7 @@ pub struct OpenAiProvider {
     custom_headers: Option<HashMap<String, String>>,
     supports_streaming: bool,
     name: String,
+    custom_models: Option<Vec<String>>,
     skip_canonical_filtering: bool,
 }

@@ -127,6 +128,7 @@ impl OpenAiProvider {
             custom_headers,
             supports_streaming: true,
             name: OPEN_AI_PROVIDER_NAME.to_string(),
+            custom_models: None,
             skip_canonical_filtering: false,
         })
     }
@@ -142,6 +144,7 @@ impl OpenAiProvider {
             custom_headers: None,
             supports_streaming: true,
             name: OPEN_AI_PROVIDER_NAME.to_string(),
+            custom_models: None,
             skip_canonical_filtering: false,
         }
     }
@@ -212,6 +215,12 @@ impl OpenAiProvider {
             api_client = api_client.with_headers(header_map)?;
         }

+        let custom_models = if !config.models.is_empty() {
+            Some(config.models.iter().map(|m| m.name.clone()).collect())
+        } else {
+            None
+        };
+
         Ok(Self {
             api_client,
             base_path,
@@ -221,6 +230,7 @@ impl OpenAiProvider {
             custom_headers: config.headers,
             supports_streaming: config.supports_streaming.unwrap_or(true),
             name: config.name.clone(),
+            custom_models,
             skip_canonical_filtering: config.skip_canonical_filtering,
         })
     }
@@ -314,6 +324,40 @@ impl OpenAiProvider {
             fallback.to_string()
         }
     }
+
+    async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
+        let models_path =
+            Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
+        let response = self
+            .api_client
+            .request(None, &models_path)
+            .response_get()
+            .await?;
+
+        if response.status() == StatusCode::NOT_FOUND {
+            let body = response.text().await.unwrap_or_default();
+            return Err(ProviderError::EndpointNotFound(body));
+        }
+
+        let json =
handle_response_openai_compat(response).await?;
+        if let Some(err_obj) = json.get("error") {
+            let msg = err_obj
+                .get("message")
+                .and_then(|v| v.as_str())
+                .unwrap_or("unknown error");
+            return Err(ProviderError::Authentication(msg.to_string()));
+        }
+
+        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
+            ProviderError::UsageError("Missing data field in JSON response".into())
+        })?;
+        let mut models: Vec<String> = data
+            .iter()
+            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
+            .collect();
+        models.sort();
+        Ok(models)
+    }
 }

 impl ProviderDef for OpenAiProvider {
@@ -384,31 +428,22 @@ impl Provider for OpenAiProvider {
     }

     async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
-        let models_path =
-            Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
-        let response = self
-            .api_client
-            .request(None, &models_path)
-            .response_get()
-            .await?;
-        let json = handle_response_openai_compat(response).await?;
-        if let Some(err_obj) = json.get("error") {
-            let msg = err_obj
-                .get("message")
-                .and_then(|v| v.as_str())
-                .unwrap_or("unknown error");
-            return Err(ProviderError::Authentication(msg.to_string()));
+        if let Some(custom_models) = &self.custom_models {
+            match self.fetch_models_from_api().await {
+                Ok(models) => return Ok(models),
+                Err(e) if e.is_endpoint_not_found() => {
+                    tracing::debug!(
+                        "Models endpoint not implemented for provider '{}' ({}), using predefined list",
+                        self.name,
+                        e
+                    );
+                    return Ok(custom_models.clone());
+                }
+                Err(e) => return Err(e),
+            }
         }

-        let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
-            ProviderError::UsageError("Missing data field in JSON response".into())
-        })?;
-        let mut models: Vec<String> = data
-            .iter()
-            .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
-            .collect();
-        models.sort();
-        Ok(models)
+        self.fetch_models_from_api().await
     }

     fn supports_embeddings(&self) -> bool {
@@ -635,6 +670,7 @@ mod tests {
             custom_headers: None,
             supports_streaming: true,
             name: name.to_string(),
+            custom_models: None,
             skip_canonical_filtering: false,
         }
     }
diff --git a/crates/goose/src/providers/openai_compatible.rs b/crates/goose/src/providers/openai_compatible.rs
index 313dfdfbf6c3..225072843848 100644
--- a/crates/goose/src/providers/openai_compatible.rs
+++ b/crates/goose/src/providers/openai_compatible.rs
@@ -299,6 +299,18 @@ mod tests {
         "ServerError"
         ; "500 server error"
     )]
+    #[test_case(
+        StatusCode::NOT_FOUND,
+        None,
+        "RequestFailed"
+        ; "404 not found"
+    )]
+    #[test_case(
+        StatusCode::NOT_FOUND,
+        Some(json!({"error": {"message": "model not available"}})),
+        "RequestFailed"
+        ; "404 with error payload"
+    )]
     fn http_status_maps_to_expected_error(
         status: StatusCode,
         payload: Option<Value>,
@@ -312,6 +324,7 @@ mod tests {
             "Authentication" => "auth",
             "ContextLengthExceeded" => "context_length",
             "ServerError" => "server",
+            "RequestFailed" => "request",
             other => panic!("Unknown variant: {other}"),
         };
         assert_eq!(