From 139fb54e4f533949c3526d082d4205d0b995cd4d Mon Sep 17 00:00:00 2001 From: Istvan Benedek Date: Fri, 10 Apr 2026 21:05:09 +0100 Subject: [PATCH] fix(openai): normalize tool schemas for provider requests Signed-off-by: Istvan Benedek --- crates/goose/src/providers/chatgpt_codex.rs | 56 +- crates/goose/src/providers/formats/openai.rs | 950 +++++++++++++++++- .../src/providers/formats/openai_responses.rs | 77 +- 3 files changed, 1054 insertions(+), 29 deletions(-) diff --git a/crates/goose/src/providers/chatgpt_codex.rs b/crates/goose/src/providers/chatgpt_codex.rs index 6f4ab1ca3f9a..c54ac152c18f 100644 --- a/crates/goose/src/providers/chatgpt_codex.rs +++ b/crates/goose/src/providers/chatgpt_codex.rs @@ -4,6 +4,7 @@ use crate::model::ModelConfig; use crate::providers::api_client::AuthProvider; use crate::providers::base::{ConfigKey, MessageStream, Provider, ProviderDef, ProviderMetadata}; use crate::providers::errors::ProviderError; +use crate::providers::formats::openai::validate_tool_schemas; use crate::providers::formats::openai_responses::responses_api_to_streaming_message; use crate::providers::openai_compatible::handle_status_openai_compat; use crate::providers::retry::ProviderRetry; @@ -268,7 +269,7 @@ fn create_codex_request( .ok_or_else(|| anyhow!("Codex payload must be a JSON object"))?; if !tools.is_empty() { - let tools_spec: Vec = tools + let mut tools_spec: Vec = tools .iter() .map(|tool| { json!({ @@ -280,6 +281,8 @@ fn create_codex_request( }) .collect(); + validate_tool_schemas(&mut tools_spec); + payload_obj.insert("tools".to_string(), json!(tools_spec)); payload_obj.insert("tool_choice".to_string(), json!("auto")); payload_obj.insert("parallel_tool_calls".to_string(), json!(true)); @@ -1015,7 +1018,7 @@ mod tests { use crate::conversation::message::Message; use goose_test_support::TEST_IMAGE_B64; use jsonwebtoken::{Algorithm, EncodingKey, Header}; - use rmcp::model::{CallToolRequestParams, CallToolResult, Content, ErrorCode, ErrorData}; + use rmcp::model::{CallToolRequestParams, CallToolResult, Content, ErrorCode, ErrorData, Tool}; use rmcp::object; use test_case::test_case; use wiremock::matchers::{body_string_contains, method, path}; @@ -1042,6 +1045,17 @@ mod tests { .unwrap_or_default() } + fn schema_contains_key(value: &Value, needle: &str) -> bool { + match value { + Value::Object(map) => { + map.contains_key(needle) + || map.values().any(|child| schema_contains_key(child, needle)) + } + Value::Array(items) => items.iter().any(|child| schema_contains_key(child, needle)), + _ => false, + } + } + #[test_case( vec![ Message::user().with_text("user text"), @@ -1312,4 +1326,42 @@ mod tests { let instructions = payload["instructions"].as_str().unwrap(); assert_eq!(instructions, "system prompt"); } + + #[test] + fn test_codex_request_sanitizes_tool_schema() { + let model = ModelConfig::new("gpt-5.4").unwrap(); + let tool = Tool::new( + "render_treemap", + "Render a treemap", + object!({ + "$defs": { + "TreemapNode": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "children": { + "type": ["array", "null"], + "items": { "$ref": "#/$defs/TreemapNode" } + } + }, + "required": ["name"], + "additionalProperties": false + } + }, + "$ref": "#/$defs/TreemapNode" + }), + ); + + let payload = create_codex_request(&model, "system prompt", &[], &[tool]).unwrap(); + let parameters = &payload["tools"][0]["parameters"]; + + assert!(!schema_contains_key(parameters, "$defs")); + assert!(!schema_contains_key(parameters, "$ref")); + assert!(!schema_contains_key(parameters, "anyOf")); + assert_eq!(parameters["type"], "object"); + assert_eq!( + parameters["properties"]["children"]["items"]["type"], + "object" + ); + } } diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs index d0029c5cafd8..8949736bf170 100644 --- a/crates/goose/src/providers/formats/openai.rs +++ b/crates/goose/src/providers/formats/openai.rs @@ -18,7 +18,7 @@ use rmcp::model::{ use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; type ToolCallData = HashMap< @@ -31,6 +31,247 @@ type ToolCallData = HashMap< ), >; +type JsonSchemaMap = serde_json::Map; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum SchemaNormalizationError { + UnsupportedLocalRef(String), + MissingDefinition(String), + RecursiveReference(String), +} + +#[derive(Default)] +struct SchemaNormalizer { + ref_stack: Vec, +} + +impl SchemaNormalizer { + fn normalize(&mut self, schema: Value) -> Value { + let definitions = JsonSchemaMap::new(); + self.sanitize_with_fallback(schema, &definitions) + } + + fn sanitize_with_fallback( + &mut self, + node: Value, + inherited_definitions: &JsonSchemaMap, + ) -> Value { + self.sanitize_schema_node(node, inherited_definitions) + .unwrap_or_else(Self::fallback_schema_for_error) + } + + fn sanitize_schema_node( + &mut self, + node: Value, + inherited_definitions: &JsonSchemaMap, + ) -> Result { + match node { + Value::Object(object) => self.sanitize_schema_object(object, inherited_definitions), + Value::Array(items) => Ok(Value::Array( + items + .into_iter() + .map(|item| self.sanitize_with_fallback(item, inherited_definitions)) + .collect(), + )), + other => Ok(other), + } + } + + fn sanitize_schema_object( + &mut self, + object: JsonSchemaMap, + inherited_definitions: &JsonSchemaMap, + ) -> Result { + let definitions = merge_schema_definitions(inherited_definitions, &object); + + if is_null_only_schema(&object) { + return Ok(Value::Null); + } + + if let Some(reference) = object.get("$ref").and_then(|value| value.as_str()) { + let mut resolved = self + .resolve_schema_ref(reference, &definitions) + .unwrap_or_else(Self::fallback_schema_for_error); + if let Some(description) = object.get("description").cloned() { + if let Some(resolved_obj) = resolved.as_object_mut() { + resolved_obj + .entry("description".to_string()) + .or_insert(description); + } + } + return Ok(resolved); + } + + if let Some(branches) = object + .get("anyOf") + .or_else(|| object.get("oneOf")) + .and_then(|value| value.as_array()) + { + let mut merged = merge_union_schemas( + branches + .iter() + .cloned() + .map(|branch| self.sanitize_with_fallback(branch, &definitions)) + .collect(), + ); + + if let Some(description) = object.get("description").cloned() { + if let Some(merged_obj) = merged.as_object_mut() { + merged_obj + .entry("description".to_string()) + .or_insert(description); + } + } + + return Ok(merged); + } + + if let Some(branches) = object.get("allOf").and_then(|value| value.as_array()) { + let mut merged = merge_all_of_schemas( + branches + .iter() + .cloned() + .map(|branch| self.sanitize_with_fallback(branch, &definitions)) + .collect(), + ); + + if let Some(description) = object.get("description").cloned() { + if let Some(merged_obj) = merged.as_object_mut() { + merged_obj + .entry("description".to_string()) + .or_insert(description); + } + } + + return Ok(merged); + } + + let mut sanitized = JsonSchemaMap::new(); + + if let Some(description) = object.get("description").cloned() { + sanitized.insert("description".to_string(), description); + } + + if let Some(type_value) = normalized_type_value_from_schema(&object) { + sanitized.insert("type".to_string(), type_value); + } + + if let Some(enum_values) = object.get("enum").cloned() { + sanitized.insert("enum".to_string(), enum_values); + } else if let Some(const_value) = object.get("const").cloned() { + sanitized.insert("enum".to_string(), Value::Array(vec![const_value])); + } + + copy_schema_keyword(&object, &mut sanitized, "format"); + copy_schema_keyword(&object, &mut sanitized, "minimum"); + copy_schema_keyword(&object, &mut sanitized, "maximum"); + copy_schema_keyword(&object, &mut sanitized, "exclusiveMinimum"); + copy_schema_keyword(&object, &mut sanitized, "exclusiveMaximum"); + copy_schema_keyword(&object, &mut sanitized, "multipleOf"); + copy_schema_keyword(&object, &mut sanitized, "minLength"); + copy_schema_keyword(&object, &mut sanitized, "maxLength"); + copy_schema_keyword(&object, &mut sanitized, "pattern"); + copy_schema_keyword(&object, &mut sanitized, "minItems"); + copy_schema_keyword(&object, &mut sanitized, "maxItems"); + + if schema_allows_type(&object, "object") { + let properties = object + .get("properties") + .and_then(|value| value.as_object()) + .map(|properties| { + properties + .iter() + .map(|(name, value)| { + ( + name.clone(), + self.sanitize_with_fallback(value.clone(), &definitions), + ) + }) + .collect() + }) + .unwrap_or_default(); + + let required = object + .get("required") + .and_then(|value| value.as_array()) + .map(|required| { + required + .iter() + .filter_map(|value| value.as_str().map(|s| Value::String(s.to_string()))) + .collect() + }) + .unwrap_or_default(); + + sanitized.insert("properties".to_string(), Value::Object(properties)); + sanitized.insert("required".to_string(), Value::Array(required)); + + if let Some(additional_properties) = object.get("additionalProperties") { + match additional_properties { + Value::Bool(value) => { + sanitized.insert("additionalProperties".to_string(), Value::Bool(*value)); + } + Value::Object(value) => { + sanitized.insert( + "additionalProperties".to_string(), + self.sanitize_with_fallback(Value::Object(value.clone()), &definitions), + ); + } + _ => {} + } + } + } + + if schema_allows_type(&object, "array") { + if let Some(items) = object.get("items").cloned() { + sanitized.insert( + "items".to_string(), + self.sanitize_with_fallback(items, &definitions), + ); + } + } + + Ok(Value::Object(sanitized)) + } + + fn resolve_schema_ref( + &mut self, + reference: &str, + definitions: &JsonSchemaMap, + ) -> Result { + let name = reference + .strip_prefix("#/$defs/") + .or_else(|| reference.strip_prefix("#/definitions/")) + .ok_or_else(|| SchemaNormalizationError::UnsupportedLocalRef(reference.to_string()))?; + + if self.ref_stack.iter().any(|seen| seen == name) { + return Err(SchemaNormalizationError::RecursiveReference( + reference.to_string(), + )); + } + + let target = definitions + .get(name) + .cloned() + .ok_or_else(|| SchemaNormalizationError::MissingDefinition(reference.to_string()))?; + + self.ref_stack.push(name.to_string()); + let resolved = self.sanitize_schema_node(target, definitions); + self.ref_stack.pop(); + resolved + } + + fn fallback_schema_for_error(error: SchemaNormalizationError) -> Value { + match error { + // OpenAI rejects recursive refs and some valid local-ref structures. + // When we cannot preserve them safely, fall back to a shallow object + // schema rather than failing the entire request. + SchemaNormalizationError::UnsupportedLocalRef(_) + | SchemaNormalizationError::MissingDefinition(_) + | SchemaNormalizationError::RecursiveReference(_) => generic_object_schema(), + } + } +} + #[derive(Serialize, Deserialize, Debug, Default)] struct DeltaToolCallFunction { name: Option, @@ -649,49 +890,497 @@ fn extract_usage_with_output_tokens(chunk: &StreamingChunk) -> Option Option<&mut Value> { + if tool.get("parameters").is_some() { + tool.get_mut("parameters") + } else { + tool.get_mut("function") + .and_then(|function| function.get_mut("parameters")) + } +} + +fn normalize_json_schema_for_openai(schema: &mut Value) { + // OpenAI rejects several valid JSON Schema constructs (`$ref`, `$defs`, + // `anyOf`/`oneOf`, recursive unions), so collapse schemas into the subset + // it accepts before serializing tool definitions. The transform preserves + // shape where possible and explicitly falls back to a generic object schema + // only at unsupported reference boundaries. + let original = std::mem::take(schema); + let mut normalizer = SchemaNormalizer::default(); + *schema = normalizer.normalize(original); + ensure_valid_json_schema(schema); +} + +fn merge_schema_definitions( + inherited_definitions: &JsonSchemaMap, + schema: &JsonSchemaMap, +) -> JsonSchemaMap { + let mut definitions = inherited_definitions.clone(); + + if let Some(local_defs) = schema.get("$defs").and_then(|value| value.as_object()) { + definitions.extend(local_defs.clone()); + } + + if let Some(local_defs) = schema + .get("definitions") + .and_then(|value| value.as_object()) + { + definitions.extend(local_defs.clone()); + } + + definitions +} + +fn merge_union_schemas(variants: Vec) -> Value { + let variants: Vec = variants + .into_iter() + .filter(|variant| !variant.is_null()) + .collect(); + if variants.is_empty() { + return generic_object_schema(); + } + if variants.len() == 1 { + return variants + .into_iter() + .next() + .unwrap_or_else(generic_object_schema); + } + + let mut merged = serde_json::Map::new(); + + if let Some(description) = variants.iter().find_map(schema_description) { + merged.insert( + "description".to_string(), + Value::String(description.to_string()), + ); + } + + let merged_types = collect_variant_types(&variants); + if !merged_types.is_empty() { + merged.insert("type".to_string(), json_type_value(merged_types.clone())); + } + + if merged_types.iter().any(|ty| ty == "object") { + let object_variants: Vec<&serde_json::Map> = variants + .iter() + .filter_map(|variant| variant.as_object()) + .filter(|variant| schema_allows_type(variant, "object")) + .collect(); + + let mut properties = serde_json::Map::new(); + let mut required_sets = Vec::new(); + let mut all_disallow_additional = !object_variants.is_empty(); + + for variant in object_variants { + if let Some(variant_properties) = variant + .get("properties") + .and_then(|value| value.as_object()) + { + for (name, property_schema) in variant_properties { + if let Some(existing) = properties.get(name).cloned() { + properties.insert( + name.clone(), + merge_union_schemas(vec![existing, property_schema.clone()]), + ); + } else { + properties.insert(name.clone(), property_schema.clone()); + } + } + } + + let required = variant + .get("required") + .and_then(|value| value.as_array()) + .map(|required| { + required + .iter() + .filter_map(|value| value.as_str().map(ToString::to_string)) + .collect::>() + }) + .unwrap_or_default(); + required_sets.push(required); + + if variant.get("additionalProperties") != Some(&Value::Bool(false)) { + all_disallow_additional = false; + } + } + + merged.insert("properties".to_string(), Value::Object(properties)); + merged.insert( + "required".to_string(), + Value::Array( + merge_required_sets(required_sets) + .into_iter() + .map(Value::String) + .collect(), + ), + ); + if all_disallow_additional { + merged.insert("additionalProperties".to_string(), Value::Bool(false)); + } + } + + if merged_types.iter().any(|ty| ty == "array") { + let array_variants: Vec<&serde_json::Map> = variants + .iter() + .filter_map(|variant| variant.as_object()) + .filter(|variant| schema_allows_type(variant, "array")) + .collect(); + + let item_schemas: Vec = array_variants + .iter() + .filter_map(|variant| variant.get("items").cloned()) + .collect(); + if !item_schemas.is_empty() { + merged.insert("items".to_string(), merge_union_schemas(item_schemas)); + } + + if let Some(min_items) = array_variants + .iter() + .filter_map(|variant| variant.get("minItems").and_then(Value::as_u64)) + .min() + { + merged.insert("minItems".to_string(), json!(min_items)); + } + + if let Some(max_items) = array_variants + .iter() + .filter_map(|variant| variant.get("maxItems").and_then(Value::as_u64)) + .max() + { + merged.insert("maxItems".to_string(), json!(max_items)); + } + } + + if let Some(enum_values) = merge_scalar_enums(&variants) { + merged.insert("enum".to_string(), Value::Array(enum_values)); + } + + Value::Object(merged) +} + +fn merge_all_of_schemas(variants: Vec) -> Value { + let variants: Vec = variants + .into_iter() + .filter(|variant| !variant.is_null()) + .collect(); + if variants.is_empty() { + return generic_object_schema(); + } + if variants.len() == 1 { + return variants + .into_iter() + .next() + .unwrap_or_else(generic_object_schema); + } + + let mut merged = serde_json::Map::new(); + + if let Some(description) = variants.iter().find_map(schema_description) { + merged.insert( + "description".to_string(), + Value::String(description.to_string()), + ); + } + + let type_sets: Vec> = variants + .iter() + .map(schema_type_set) + .filter(|types| !types.is_empty()) + .collect(); + + let merged_types = if type_sets.is_empty() { + Vec::new() + } else { + let mut iter = type_sets.into_iter(); + let mut intersection = iter.next().unwrap_or_default(); + for types in iter { + intersection = intersection.intersection(&types).cloned().collect(); + } + if intersection.is_empty() { + collect_variant_types(&variants) + } else { + intersection.into_iter().collect() + } + }; + + if !merged_types.is_empty() { + merged.insert("type".to_string(), json_type_value(merged_types.clone())); + } + + if merged_types.iter().any(|ty| ty == "object") { + let object_variants: Vec<&serde_json::Map> = variants + .iter() + .filter_map(|variant| variant.as_object()) + .filter(|variant| schema_allows_type(variant, "object")) + .collect(); + + let mut properties = serde_json::Map::new(); + let mut required = BTreeSet::new(); + let mut disallow_additional = false; + + for variant in object_variants { + if let Some(variant_properties) = variant + .get("properties") + .and_then(|value| value.as_object()) + { + for (name, property_schema) in variant_properties { + if let Some(existing) = properties.get(name).cloned() { + properties.insert( + name.clone(), + merge_all_of_schemas(vec![existing, property_schema.clone()]), + ); + } else { + properties.insert(name.clone(), property_schema.clone()); + } + } + } + + if let Some(variant_required) = + variant.get("required").and_then(|value| value.as_array()) + { + required.extend( + variant_required + .iter() + .filter_map(|value| value.as_str().map(ToString::to_string)), + ); + } + + if variant.get("additionalProperties") == Some(&Value::Bool(false)) { + disallow_additional = true; + } + } + + merged.insert("properties".to_string(), Value::Object(properties)); + merged.insert( + "required".to_string(), + Value::Array(required.into_iter().map(Value::String).collect()), + ); + if disallow_additional { + merged.insert("additionalProperties".to_string(), Value::Bool(false)); + } + } + + if merged_types.iter().any(|ty| ty == "array") { + let array_variants: Vec<&serde_json::Map> = variants + .iter() + .filter_map(|variant| variant.as_object()) + .filter(|variant| schema_allows_type(variant, "array")) + .collect(); + + let item_schemas: Vec = array_variants + .iter() + .filter_map(|variant| variant.get("items").cloned()) + .collect(); + if !item_schemas.is_empty() { + merged.insert("items".to_string(), merge_all_of_schemas(item_schemas)); + } + + if let Some(min_items) = array_variants + .iter() + .filter_map(|variant| variant.get("minItems").and_then(Value::as_u64)) + .max() + { + merged.insert("minItems".to_string(), json!(min_items)); + } + + if let Some(max_items) = array_variants + .iter() + .filter_map(|variant| variant.get("maxItems").and_then(Value::as_u64)) + .min() + { + merged.insert("maxItems".to_string(), json!(max_items)); + } + } + + Value::Object(merged) +} + +fn merge_required_sets(required_sets: Vec>) -> Vec { + let mut iter = required_sets.into_iter(); + let Some(first) = iter.next() else { + return Vec::new(); + }; + + let merged = iter.fold(first, |current, required| { + current.intersection(&required).cloned().collect() + }); + + merged.into_iter().collect() +} + +fn merge_scalar_enums(variants: &[Value]) -> Option> { + let enum_values: Vec = variants + .iter() + .filter_map(|variant| variant.get("enum").and_then(|value| value.as_array())) + .flatten() + .cloned() + .collect(); + + if enum_values.is_empty() { + None + } else { + Some(dedup_json_values(enum_values)) + } +} + +fn dedup_json_values(values: Vec) -> Vec { + let mut deduped = Vec::new(); + for value in values { + if !deduped.iter().any(|existing| existing == &value) { + deduped.push(value); + } + } + deduped +} + +fn schema_type_set(schema: &Value) -> BTreeSet { + schema + .as_object() + .map(collect_schema_types) + .unwrap_or_default() + .into_iter() + .collect() +} + +fn collect_variant_types(variants: &[Value]) -> Vec { + let mut types = Vec::new(); + for variant in variants { + if let Some(object) = variant.as_object() { + for ty in collect_schema_types(object) { + push_unique_type(&mut types, &ty); + } + } + } + types +} + +fn collect_schema_types(schema: &serde_json::Map) -> Vec { + let mut types = Vec::new(); + + match schema.get("type") { + Some(Value::String(ty)) if ty != "null" => push_unique_type(&mut types, ty), + Some(Value::Array(type_values)) => { + for ty in type_values.iter().filter_map(|value| value.as_str()) { + if ty != "null" { + push_unique_type(&mut types, ty); } } } + _ => { + if schema.contains_key("properties") + || schema.contains_key("required") + || schema.contains_key("additionalProperties") + { + push_unique_type(&mut types, "object"); + } + if schema.contains_key("items") { + push_unique_type(&mut types, "array"); + } + } } + + types +} + +fn normalized_type_value_from_schema(schema: &serde_json::Map) -> Option { + let types = collect_schema_types(schema); + if types.is_empty() { + None + } else { + Some(json_type_value(types)) + } +} + +fn json_type_value(types: Vec) -> Value { + if types.len() == 1 { + Value::String(types[0].clone()) + } else { + Value::Array(types.into_iter().map(Value::String).collect()) + } +} + +fn push_unique_type(types: &mut Vec, ty: &str) { + if !types.iter().any(|existing| existing == ty) { + types.push(ty.to_string()); + } +} + +fn schema_allows_type(schema: &serde_json::Map, expected: &str) -> bool { + collect_schema_types(schema) + .into_iter() + .any(|ty| ty == expected) +} + +fn schema_description(schema: &Value) -> Option<&str> { + schema + .as_object() + .and_then(|object| object.get("description")) + .and_then(|value| value.as_str()) +} + +fn is_null_only_schema(schema: &serde_json::Map) -> bool { + match schema.get("type") { + Some(Value::String(ty)) => ty == "null", + Some(Value::Array(types)) => { + !types.is_empty() && types.iter().all(|value| value.as_str() == Some("null")) + } + _ => false, + } +} + +fn copy_schema_keyword( + source: &serde_json::Map, + target: &mut serde_json::Map, + key: &str, +) { + if let Some(value) = source.get(key).cloned() { + target.insert(key.to_string(), value); + } +} + +fn generic_object_schema() -> Value { + json!({ + "type": "object", + "properties": {}, + "required": [] + }) } /// Ensures that the given JSON value follows the expected JSON Schema structure. fn ensure_valid_json_schema(schema: &mut Value) { if let Some(params_obj) = schema.as_object_mut() { - // Check if this is meant to be an object type schema - let is_object_type = params_obj - .get("type") - .and_then(|t| t.as_str()) - .is_none_or(|t| t == "object"); // Default to true if no type is specified - - // Only apply full schema validation to object types - if is_object_type { - // Ensure required fields exist with default values + if schema_allows_type(params_obj, "object") { params_obj.entry("properties").or_insert_with(|| json!({})); params_obj.entry("required").or_insert_with(|| json!([])); - params_obj.entry("type").or_insert_with(|| json!("object")); - // Recursively validate properties if it exists if let Some(properties) = params_obj.get_mut("properties") { if let Some(properties_obj) = properties.as_object_mut() { - for (_key, prop) in properties_obj.iter_mut() { - if prop.is_object() - && prop.get("type").and_then(|t| t.as_str()) == Some("object") - { - ensure_valid_json_schema(prop); - } + for property_schema in properties_obj.values_mut() { + ensure_valid_json_schema(property_schema); } } } } + + if schema_allows_type(params_obj, "array") { + if let Some(items) = params_obj.get_mut("items") { + ensure_valid_json_schema(items); + } + } } } @@ -1033,13 +1722,24 @@ pub fn create_request( mod tests { use super::*; use crate::conversation::message::Message; - use rmcp::model::CallToolResult; + use rmcp::model::{CallToolResult, Tool}; use rmcp::object; use serde_json::json; use test_case::test_case; use tokio::pin; use tokio_stream::{self, StreamExt}; + fn schema_contains_key(value: &Value, needle: &str) -> bool { + match value { + Value::Object(map) => { + map.contains_key(needle) + || map.values().any(|child| schema_contains_key(child, needle)) + } + Value::Array(items) => items.iter().any(|child| schema_contains_key(child, needle)), + _ => false, + } + } + #[test] fn test_validate_tool_schemas() { // Test case 1: Empty parameters object @@ -1116,6 +1816,139 @@ mod tests { assert_eq!(tools[0], original_schema); } + #[test] + fn test_validate_tool_schemas_flattens_refs_and_any_of() { + let mut tools = vec![json!({ + "type": "function", + "function": { + "name": "render_donut", + "description": "Render a donut chart", + "parameters": { + "$defs": { + "DonutValueItem": { + "anyOf": [ + { "type": "number" }, + { + "type": "object", + "properties": { + "label": { "type": "string" }, + "value": { "type": "number" } + }, + "required": ["label", "value"], + "additionalProperties": false + } + ] + }, + "SingleDonutChart": { + "type": "object", + "properties": { + "values": { + "type": "array", + "items": { "$ref": "#/$defs/DonutValueItem" } + }, + "labels": { + "type": ["array", "null"], + "items": { "type": "string" } + } + }, + "required": ["values"], + "additionalProperties": false + } + }, + "type": "object", + "properties": { + "data": { + "anyOf": [ + { "$ref": "#/$defs/SingleDonutChart" }, + { + "type": "array", + "items": { "$ref": "#/$defs/SingleDonutChart" }, + "minItems": 1 + } + ] + } + }, + "required": ["data"], + "additionalProperties": false + } + } + })]; + + validate_tool_schemas(&mut tools); + + let parameters = &tools[0]["function"]["parameters"]; + assert!(!schema_contains_key(parameters, "$defs")); + assert!(!schema_contains_key(parameters, "definitions")); + assert!(!schema_contains_key(parameters, "$ref")); + assert!(!schema_contains_key(parameters, "anyOf")); + assert_eq!( + parameters["properties"]["data"]["type"], + json!(["object", "array"]) + ); + assert_eq!(parameters["properties"]["data"]["items"]["type"], "object"); + assert_eq!( + parameters["properties"]["data"]["items"]["properties"]["values"]["items"]["type"], + json!(["number", "object"]) + ); + } + + #[test] + fn test_validate_tool_schemas_breaks_recursive_refs() { + let mut tools = vec![json!({ + "type": "function", + "parameters": { + "$defs": { + "TreemapNode": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "value": { "type": ["number", "null"] }, + "children": { + "type": ["array", "null"], + "items": { "$ref": "#/$defs/TreemapNode" } + } + }, + "required": ["name"], + "additionalProperties": false + } + }, + "$ref": "#/$defs/TreemapNode" + } + })]; + + validate_tool_schemas(&mut tools); + + let parameters = &tools[0]["parameters"]; + assert!(!schema_contains_key(parameters, "$defs")); + assert!(!schema_contains_key(parameters, "$ref")); + assert_eq!(parameters["type"], "object"); + assert_eq!(parameters["properties"]["children"]["type"], "array"); + assert_eq!( + parameters["properties"]["children"]["items"]["type"], + "object" + ); + assert!(parameters["properties"]["children"]["items"]["properties"] + .as_object() + .is_some()); + } + + #[test] + fn test_validate_tool_schemas_falls_back_for_missing_ref() { + let mut tools = vec![json!({ + "type": "function", + "parameters": { + "$ref": "#/$defs/DoesNotExist" + } + })]; + + validate_tool_schemas(&mut tools); + + let parameters = &tools[0]["parameters"]; + assert_eq!(parameters["type"], "object"); + assert_eq!(parameters["properties"], json!({})); + assert_eq!(parameters["required"], json!([])); + } + const OPENAI_TOOL_USE_RESPONSE: &str = r#"{ "choices": [{ "role": "assistant", @@ -1791,6 +2624,73 @@ mod tests { Ok(()) } + #[test] + fn test_create_request_sanitizes_tool_schema() -> anyhow::Result<()> { + let model_config = ModelConfig { + model_name: "gpt-4o".to_string(), + context_limit: Some(4096), + temperature: None, + max_tokens: Some(1024), + toolshim: false, + toolshim_model: None, + fast_model_config: None, + request_params: None, + reasoning: None, + }; + + let tool = Tool::new( + "render_donut", + "Render a donut chart", + object!({ + "$defs": { + "DonutValueItem": { + "anyOf": [ + { "type": "number" }, + { + "type": "object", + "properties": { + "label": { "type": "string" }, + "value": { "type": "number" } + }, + "required": ["label", "value"], + "additionalProperties": false + } + ] + } + }, + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { "$ref": "#/$defs/DonutValueItem" } + } + }, + "required": ["data"], + "additionalProperties": false + }), + ); + + let request = create_request( + &model_config, + "system", + &[], + &[tool], + &ImageFormat::OpenAi, + false, + )?; + let parameters = &request["tools"][0]["function"]["parameters"]; + + assert!(!schema_contains_key(parameters, "$defs")); + assert!(!schema_contains_key(parameters, "$ref")); + assert!(!schema_contains_key(parameters, "anyOf")); + assert_eq!( + parameters["properties"]["data"]["items"]["type"], + json!(["number", "object"]) + ); + + Ok(()) + } + struct StreamingUsageTestResult { usage_count: usize, usage: Option, diff --git a/crates/goose/src/providers/formats/openai_responses.rs b/crates/goose/src/providers/formats/openai_responses.rs index 68e0f785f549..85a469b6963e 100644 --- a/crates/goose/src/providers/formats/openai_responses.rs +++ b/crates/goose/src/providers/formats/openai_responses.rs @@ -2,6 +2,7 @@ use crate::conversation::message::{Message, MessageContent}; use crate::mcp_utils::extract_text_from_resource; use crate::model::ModelConfig; use crate::providers::base::{ProviderUsage, Usage}; +use crate::providers::formats::openai::validate_tool_schemas; use crate::providers::utils::extract_reasoning_effort; use anyhow::{anyhow, Error}; use async_stream::try_stream; @@ -487,7 +488,7 @@ pub fn create_responses_request( } if !tools.is_empty() { - let tools_spec: Vec = tools + let mut tools_spec: Vec = tools .iter() .map(|tool| { json!({ @@ -499,6 +500,8 @@ pub fn create_responses_request( }) .collect(); + validate_tool_schemas(&mut tools_spec); + payload .as_object_mut() .unwrap() @@ -808,9 +811,20 @@ mod tests { use crate::conversation::message::MessageContent; use crate::model::ModelConfig; use futures::StreamExt; - use rmcp::model::CallToolRequestParams; + use rmcp::model::{CallToolRequestParams, Tool}; use rmcp::object; + fn schema_contains_key(value: &Value, needle: &str) -> bool { + match value { + Value::Object(map) => { + map.contains_key(needle) + || map.values().any(|child| schema_contains_key(child, needle)) + } + Value::Array(items) => items.iter().any(|child| schema_contains_key(child, needle)), + _ => false, + } + } + #[tokio::test] async fn test_responses_stream_ignores_keepalive_event() -> anyhow::Result<()> { let lines = vec![ @@ -1083,4 +1097,63 @@ mod tests { assert_eq!(info.effort.as_deref(), Some("high")); assert_eq!(info.summary.as_deref(), Some("Thought deeply")); } + + #[test] + fn test_responses_request_sanitizes_tool_schema() { + let model_config = ModelConfig { + model_name: "gpt-5.2-codex".to_string(), + context_limit: None, + temperature: None, + max_tokens: None, + toolshim: false, + toolshim_model: None, + fast_model_config: None, + request_params: None, + reasoning: None, + }; + + let tool = Tool::new( + "show_chart", + "Show a chart", + object!({ + "$defs": { + "ChartPoint": { + "type": "object", + "properties": { + "x": { "type": "number" }, + "y": { "type": "number" } + }, + "required": ["x", "y"], + "additionalProperties": false + } + }, + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "anyOf": [ + { "type": "number" }, + { "$ref": "#/$defs/ChartPoint" } + ] + } + } + }, + "required": ["data"], + "additionalProperties": false + }), + ); + + let payload = + create_responses_request(&model_config, "system prompt", &[], &[tool]).unwrap(); + let parameters = &payload["tools"][0]["parameters"]; + + assert!(!schema_contains_key(parameters, "$defs")); + assert!(!schema_contains_key(parameters, "$ref")); + assert!(!schema_contains_key(parameters, "anyOf")); + assert_eq!( + parameters["properties"]["data"]["items"]["type"], + json!(["number", "object"]) + ); + } }