Commit 675c828

fix: remap max_completion_tokens to max_tokens for OpenAI-compatible providers
When using OpenAI-compatible providers like Mistral via the declarative provider system, the OpenAI-specific parameter max_completion_tokens (used for O-series models) is not recognized and causes 422 errors.

This adds a sanitize_request_for_compat method to OpenAiProvider that remaps max_completion_tokens to max_tokens for any non-native OpenAI provider, ensuring compatibility with Mistral and other OpenAI-compatible APIs.

Closes #7762

Signed-off-by: fre <anonwurcod@proton.me>
1 parent 12eac72 commit 675c828
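
Before the diff, a minimal, self-contained sketch (not part of this commit) of the remapping that the new sanitize_request_for_compat method performs. The free function remap_for_compat and the main() harness are hypothetical names used only for illustration; only serde_json is assumed.

// Illustrative only: a standalone version of the remap, not the commit's code.
use serde_json::{json, Value};

// Hypothetical free function mirroring the core of sanitize_request_for_compat.
fn remap_for_compat(mut payload: Value) -> Value {
    if let Some(obj) = payload.as_object_mut() {
        // Move the OpenAI-only key to the widely supported one, without
        // overwriting a max_tokens value that is already present.
        if let Some(value) = obj.remove("max_completion_tokens") {
            obj.entry("max_tokens").or_insert(value);
        }
    }
    payload
}

fn main() {
    // Request shaped for OpenAI O-series models; Mistral's API rejects
    // max_completion_tokens with HTTP 422.
    let before = json!({
        "model": "mistral-medium-latest",
        "messages": [],
        "max_completion_tokens": 16384
    });

    let after = remap_for_compat(before);
    assert_eq!(after["max_tokens"], json!(16384));
    assert!(after.get("max_completion_tokens").is_none());
}

The entry("max_tokens").or_insert(value) pattern is what keeps an explicitly set max_tokens intact when both keys are present, which the second test in the diff below exercises.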

1 file changed

crates/goose/src/providers/openai.rs

Lines changed: 91 additions & 1 deletion
@@ -251,6 +251,20 @@ impl OpenAiProvider {
         Self::is_responses_model(model_name)
     }
 
+    fn sanitize_request_for_compat(&self, mut payload: serde_json::Value) -> serde_json::Value {
+        if self.name == OPEN_AI_PROVIDER_NAME {
+            return payload;
+        }
+
+        if let Some(obj) = payload.as_object_mut() {
+            if let Some(value) = obj.remove("max_completion_tokens") {
+                obj.entry("max_tokens").or_insert(value);
+            }
+        }
+
+        payload
+    }
+
     fn map_base_path(base_path: &str, target: &str, fallback: &str) -> String {
         let normalized = Self::normalize_base_path(base_path);
         if normalized.ends_with(target) || normalized.contains(&format!("/{target}")) {
@@ -457,6 +471,7 @@ impl Provider for OpenAiProvider {
             &ImageFormat::OpenAi,
             self.supports_streaming,
         )?;
+        let payload = self.sanitize_request_for_compat(payload);
         let mut log = RequestLog::start(model_config, &payload)?;
 
         let response = self
@@ -568,7 +583,82 @@ impl EmbeddingCapable for OpenAiProvider {
 
 #[cfg(test)]
 mod tests {
-    use super::OpenAiProvider;
+    use super::*;
+    use serde_json::json;
+
+    fn make_provider(name: &str) -> OpenAiProvider {
+        OpenAiProvider {
+            api_client: ApiClient::new("http://localhost".to_string(), AuthMethod::NoAuth).unwrap(),
+            base_path: "v1/chat/completions".to_string(),
+            organization: None,
+            project: None,
+            model: ModelConfig::new_or_fail("test-model"),
+            custom_headers: None,
+            supports_streaming: true,
+            name: name.to_string(),
+        }
+    }
+
+    #[test]
+    fn sanitize_remaps_max_completion_tokens_for_compat_provider() {
+        let provider = make_provider("mistral");
+        let payload = json!({
+            "model": "mistral-medium-latest",
+            "messages": [],
+            "max_completion_tokens": 16384
+        });
+
+        let result = provider.sanitize_request_for_compat(payload);
+        let obj = result.as_object().unwrap();
+
+        assert!(!obj.contains_key("max_completion_tokens"));
+        assert_eq!(obj.get("max_tokens").unwrap(), &json!(16384));
+    }
+
+    #[test]
+    fn sanitize_preserves_existing_max_tokens_for_compat_provider() {
+        let provider = make_provider("mistral");
+        let payload = json!({
+            "model": "mistral-medium-latest",
+            "messages": [],
+            "max_tokens": 4096,
+            "max_completion_tokens": 16384
+        });
+
+        let result = provider.sanitize_request_for_compat(payload);
+        let obj = result.as_object().unwrap();
+
+        assert!(!obj.contains_key("max_completion_tokens"));
+        assert_eq!(obj.get("max_tokens").unwrap(), &json!(4096));
+    }
+
+    #[test]
+    fn sanitize_noop_for_native_openai_provider() {
+        let provider = make_provider("openai");
+        let payload = json!({
+            "model": "o3",
+            "messages": [],
+            "max_completion_tokens": 16384
+        });
+
+        let result = provider.sanitize_request_for_compat(payload);
+        let obj = result.as_object().unwrap();
+
+        assert!(obj.contains_key("max_completion_tokens"));
+        assert!(!obj.contains_key("max_tokens"));
+    }
+
+    #[test]
+    fn sanitize_no_token_params() {
+        let provider = make_provider("groq");
+        let payload = json!({
+            "model": "llama-3.3-70b-versatile",
+            "messages": []
+        });
+
+        let result = provider.sanitize_request_for_compat(payload.clone());
+        assert_eq!(result, payload);
+    }
 
     #[test]
     fn gpt_5_2_codex_uses_responses_when_base_path_is_default() {
