Skip to content

Commit 3016fdd

Browse files
authored
fix: remap max_completion_tokens to max_tokens for OpenAI-compatible providers (#7765)
Signed-off-by: fre <anonwurcod@proton.me>
1 parent 5b28f8f commit 3016fdd

1 file changed

Lines changed: 121 additions & 1 deletion

File tree

crates/goose/src/providers/openai.rs

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,34 @@ impl OpenAiProvider {
251251
Self::is_responses_model(model_name)
252252
}
253253

254+
/// Providers known to reject `max_completion_tokens` and require
255+
/// the legacy `max_tokens` field instead.
256+
const PROVIDERS_NEEDING_MAX_TOKENS_REMAP: &[&str] = &[
257+
"cerebras",
258+
"custom_deepseek",
259+
"groq",
260+
"inception",
261+
"kimi",
262+
"lmstudio",
263+
"mistral",
264+
"moonshot",
265+
"ovhcloud",
266+
];
267+
268+
fn sanitize_request_for_compat(&self, mut payload: serde_json::Value) -> serde_json::Value {
269+
if !Self::PROVIDERS_NEEDING_MAX_TOKENS_REMAP.contains(&self.name.as_str()) {
270+
return payload;
271+
}
272+
273+
if let Some(obj) = payload.as_object_mut() {
274+
if let Some(value) = obj.remove("max_completion_tokens") {
275+
obj.entry("max_tokens").or_insert(value);
276+
}
277+
}
278+
279+
payload
280+
}
281+
254282
fn map_base_path(base_path: &str, target: &str, fallback: &str) -> String {
255283
let normalized = Self::normalize_base_path(base_path);
256284
if normalized.ends_with(target) || normalized.contains(&format!("/{target}")) {
@@ -457,6 +485,7 @@ impl Provider for OpenAiProvider {
457485
&ImageFormat::OpenAi,
458486
self.supports_streaming,
459487
)?;
488+
let payload = self.sanitize_request_for_compat(payload);
460489
let mut log = RequestLog::start(model_config, &payload)?;
461490

462491
let response = self
@@ -568,7 +597,98 @@ impl EmbeddingCapable for OpenAiProvider {
568597

569598
#[cfg(test)]
570599
mod tests {
571-
use super::OpenAiProvider;
600+
use super::*;
601+
use serde_json::json;
602+
603+
fn make_provider(name: &str) -> OpenAiProvider {
604+
OpenAiProvider {
605+
api_client: ApiClient::new("http://localhost".to_string(), AuthMethod::NoAuth).unwrap(),
606+
base_path: "v1/chat/completions".to_string(),
607+
organization: None,
608+
project: None,
609+
model: ModelConfig::new_or_fail("test-model"),
610+
custom_headers: None,
611+
supports_streaming: true,
612+
name: name.to_string(),
613+
}
614+
}
615+
616+
#[test]
fn sanitize_remaps_max_completion_tokens_for_compat_provider() {
    // "mistral" is on the remap list, so the legacy field name must win.
    let sanitized = make_provider("mistral").sanitize_request_for_compat(json!({
        "model": "mistral-medium-latest",
        "messages": [],
        "max_completion_tokens": 16384
    }));

    let fields = sanitized.as_object().unwrap();
    assert_eq!(fields.get("max_tokens").unwrap(), &json!(16384));
    assert!(!fields.contains_key("max_completion_tokens"));
}
631+
632+
#[test]
fn sanitize_preserves_existing_max_tokens_for_compat_provider() {
    let provider = make_provider("mistral");
    let request = json!({
        "model": "mistral-medium-latest",
        "messages": [],
        "max_tokens": 4096,
        "max_completion_tokens": 16384
    });

    let sanitized = provider.sanitize_request_for_compat(request);
    let fields = sanitized.as_object().unwrap();

    // The caller-supplied limit wins; the duplicate modern field is dropped.
    assert_eq!(fields.get("max_tokens").unwrap(), &json!(4096));
    assert!(!fields.contains_key("max_completion_tokens"));
}
648+
649+
#[test]
fn sanitize_noop_for_native_openai_provider() {
    // The native "openai" provider accepts `max_completion_tokens`,
    // so the payload must pass through untouched.
    let sanitized = make_provider("openai").sanitize_request_for_compat(json!({
        "model": "o3",
        "messages": [],
        "max_completion_tokens": 16384
    }));

    let fields = sanitized.as_object().unwrap();
    assert!(!fields.contains_key("max_tokens"));
    assert!(fields.contains_key("max_completion_tokens"));
}
664+
665+
#[test]
fn sanitize_noop_for_unknown_provider() {
    // Providers absent from the remap list keep the modern field name.
    let provider = make_provider("some_future_provider");
    let request = json!({
        "model": "future-model",
        "messages": [],
        "max_completion_tokens": 16384
    });

    let sanitized = provider.sanitize_request_for_compat(request);
    let fields = sanitized.as_object().unwrap();

    assert!(!fields.contains_key("max_tokens"));
    assert!(fields.contains_key("max_completion_tokens"));
}
680+
681+
#[test]
fn sanitize_no_token_params() {
    // With neither token field present, the remap is a no-op even for a
    // provider on the list.
    let provider = make_provider("groq");
    let request = json!({
        "model": "llama-3.3-70b-versatile",
        "messages": []
    });
    let expected = request.clone();

    assert_eq!(provider.sanitize_request_for_compat(request), expected);
}
572692

573693
#[test]
574694
fn gpt_5_2_codex_uses_responses_when_base_path_is_default() {

0 commit comments

Comments
 (0)