Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 48 additions & 20 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub struct AnthropicProvider {
model: ModelConfig,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl AnthropicProvider {
Expand All @@ -80,6 +81,7 @@ impl AnthropicProvider {
model,
supports_streaming: true,
name: ANTHROPIC_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand Down Expand Up @@ -119,11 +121,18 @@ impl AnthropicProvider {
));
}

let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
model,
supports_streaming,
name: config.name.clone(),
custom_models,
})
}

Expand All @@ -139,6 +148,31 @@ impl AnthropicProvider {

headers
}

/// Query the Anthropic `v1/models` endpoint and return the sorted model ids.
///
/// # Errors
/// Maps any non-200 status through `map_http_error_to_provider_error`, and
/// returns `ProviderError::RequestFailed` when the payload lacks a `data` array.
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
    let response = self.api_client.request(None, "v1/models").api_get().await?;

    match response.status {
        StatusCode::OK => {}
        status => return Err(map_http_error_to_provider_error(status, response.payload)),
    }

    let payload = response.payload.unwrap_or_default();
    let entries = match payload.get("data").and_then(|v| v.as_array()) {
        Some(arr) => arr,
        None => {
            return Err(ProviderError::RequestFailed(
                "Missing 'data' array in Anthropic models response".to_string(),
            ))
        }
    };

    // Keep only entries that carry a string `id`; ignore malformed rows.
    let mut models = entries
        .iter()
        .filter_map(|entry| Some(entry.get("id")?.as_str()?.to_string()))
        .collect::<Vec<String>>();
    models.sort();
    Ok(models)
}
}

impl ProviderDef for AnthropicProvider {
Expand Down Expand Up @@ -194,28 +228,22 @@ impl Provider for AnthropicProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let response = self.api_client.request(None, "v1/models").api_get().await?;

if response.status != StatusCode::OK {
return Err(map_http_error_to_provider_error(
response.status,
response.payload,
));
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) if e.is_endpoint_not_found() => {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
Err(e) => return Err(e),
}
}

let json = response.payload.unwrap_or_default();
let arr = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::RequestFailed(
"Missing 'data' array in Anthropic models response".to_string(),
)
})?;

let mut models: Vec<String> = arr
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
self.fetch_models_from_api().await
}

async fn stream(
Expand Down
8 changes: 8 additions & 0 deletions crates/goose/src/providers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub enum ProviderError {
#[error("Unsupported operation: {0}")]
NotImplemented(String),

#[error("Endpoint not found (404): {0}")]
EndpointNotFound(String),

#[error("Credits exhausted: {details}")]
CreditsExhausted {
details: String,
Expand All @@ -53,9 +56,14 @@ impl ProviderError {
ProviderError::ExecutionError(_) => "execution",
ProviderError::UsageError(_) => "usage",
ProviderError::NotImplemented(_) => "not_implemented",
ProviderError::EndpointNotFound(_) => "endpoint_not_found",
ProviderError::CreditsExhausted { .. } => "credits_exhausted",
}
}

/// Returns `true` when this error is the 404 `EndpointNotFound` variant,
/// letting callers fall back (e.g. to a predefined model list) on a
/// missing endpoint without matching on the variant themselves.
pub fn is_endpoint_not_found(&self) -> bool {
    matches!(self, ProviderError::EndpointNotFound(_))
}
}

fn is_network_error(err: &reqwest::Error) -> bool {
Expand Down
76 changes: 53 additions & 23 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub struct OpenAiProvider {
custom_headers: Option<HashMap<String, String>>,
supports_streaming: bool,
name: String,
custom_models: Option<Vec<String>>,
}

impl OpenAiProvider {
Expand Down Expand Up @@ -126,6 +127,7 @@ impl OpenAiProvider {
custom_headers,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
})
}

Expand All @@ -140,6 +142,7 @@ impl OpenAiProvider {
custom_headers: None,
supports_streaming: true,
name: OPEN_AI_PROVIDER_NAME.to_string(),
custom_models: None,
}
}

Expand Down Expand Up @@ -199,6 +202,12 @@ impl OpenAiProvider {
api_client = api_client.with_headers(header_map)?;
}

let custom_models = if !config.models.is_empty() {
Some(config.models.iter().map(|m| m.name.clone()).collect())
} else {
None
};

Ok(Self {
api_client,
base_path,
Expand All @@ -208,6 +217,7 @@ impl OpenAiProvider {
custom_headers: config.headers,
supports_streaming: config.supports_streaming.unwrap_or(true),
name: config.name.clone(),
custom_models,
})
}

Expand Down Expand Up @@ -300,6 +310,34 @@ impl OpenAiProvider {
fallback.to_string()
}
}

/// Fetch the sorted list of model ids from the provider's models endpoint.
///
/// # Errors
/// `ProviderError::Authentication` when the body carries an in-band `error`
/// object, `ProviderError::UsageError` when the `data` array is missing, or
/// whatever `handle_response_openai_compat` surfaces for HTTP failures.
async fn fetch_models_from_api(&self) -> Result<Vec<String>, ProviderError> {
    let models_path = Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
    let raw_response = self
        .api_client
        .request(None, &models_path)
        .response_get()
        .await?;
    let json = handle_response_openai_compat(raw_response).await?;

    // Some OpenAI-compatible backends report failures in-band via an
    // `error` object on an otherwise successful response.
    if let Some(error) = json.get("error") {
        let message = error
            .get("message")
            .and_then(|v| v.as_str())
            .unwrap_or("unknown error");
        return Err(ProviderError::Authentication(message.to_string()));
    }

    let entries = json
        .get("data")
        .and_then(|value| value.as_array())
        .ok_or_else(|| {
            ProviderError::UsageError("Missing data field in JSON response".into())
        })?;

    // Keep only entries that carry a string `id`; ignore malformed rows.
    let mut models: Vec<String> = entries
        .iter()
        .filter_map(|entry| Some(entry.get("id")?.as_str()?.to_string()))
        .collect();
    models.sort();
    Ok(models)
}
}

impl ProviderDef for OpenAiProvider {
Expand Down Expand Up @@ -366,31 +404,22 @@ impl Provider for OpenAiProvider {
}

async fn fetch_supported_models(&self) -> Result<Vec<String>, ProviderError> {
let models_path =
Self::map_base_path(&self.base_path, "models", OPEN_AI_DEFAULT_MODELS_PATH);
let response = self
.api_client
.request(None, &models_path)
.response_get()
.await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
if let Some(custom_models) = &self.custom_models {
match self.fetch_models_from_api().await {
Ok(models) => return Ok(models),
Err(e) if e.is_endpoint_not_found() => {
tracing::debug!(
"Models endpoint not implemented for provider '{}' ({}), using predefined list",
self.name,
e
);
return Ok(custom_models.clone());
}
Err(e) => return Err(e),
}
}

let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let mut models: Vec<String> = data
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
.collect();
models.sort();
Ok(models)
self.fetch_models_from_api().await
}

fn supports_embeddings(&self) -> bool {
Expand Down Expand Up @@ -617,6 +646,7 @@ mod tests {
custom_headers: None,
supports_streaming: true,
name: name.to_string(),
custom_models: None,
}
}

Expand Down
17 changes: 14 additions & 3 deletions crates/goose/src/providers/openai_compatible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,7 @@ pub fn map_http_error_to_provider_error(
status,
extract_message()
)),
StatusCode::NOT_FOUND => {
ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message()))
}
StatusCode::NOT_FOUND => ProviderError::EndpointNotFound(extract_message()),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep inference 404s as RequestFailed for model hints

Mapping every 404 to EndpointNotFound here regresses the model-recommendation flow for failed inference calls. Agent::enhance_model_error in agents/reply_parts.rs only augments ProviderError::RequestFailed with “available models” suggestions, so a /chat/completions response like “model ... not found” now skips that path and surfaces a less actionable generic error. This is a user-visible regression from the previous RequestFailed mapping for 404 responses.

Useful? React with 👍 / 👎.

StatusCode::PAYMENT_REQUIRED => ProviderError::CreditsExhausted {
details: extract_message(),
top_up_url: None,
Expand Down Expand Up @@ -299,6 +297,18 @@ mod tests {
"ServerError"
; "500 server error"
)]
#[test_case(
StatusCode::NOT_FOUND,
None,
"EndpointNotFound"
; "404 endpoint not found"
)]
#[test_case(
StatusCode::NOT_FOUND,
Some(json!({"error": {"message": "model (404) not available"}})),
"EndpointNotFound"
; "404 with payload containing 404 substring"
)]
fn http_status_maps_to_expected_error(
status: StatusCode,
payload: Option<Value>,
Expand All @@ -312,6 +322,7 @@ mod tests {
"Authentication" => "auth",
"ContextLengthExceeded" => "context_length",
"ServerError" => "server",
"EndpointNotFound" => "endpoint_not_found",
other => panic!("Unknown variant: {other}"),
};
assert_eq!(
Expand Down
Loading