Skip to content

Commit 7ec07f2

Browse files
fresh3nough and Douwe Osinga authored
fix: prevent Ollama provider from hanging on tool-calling requests (#7723)
Signed-off-by: fre <anonwurcod@proton.me> Signed-off-by: fresh3nough <nicholasanthony742@gmail.com> Signed-off-by: Douwe Osinga <douwe@squareup.com> Co-authored-by: Douwe Osinga <douwe@squareup.com>
1 parent c936514 commit 7ec07f2

3 files changed

Lines changed: 204 additions & 18 deletions

File tree

Cargo.lock

Lines changed: 13 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/goose/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -215,6 +215,8 @@ env-lock = { workspace = true }
215215
rmcp = { workspace = true, features = ["transport-streamable-http-server"] }
216216
opentelemetry_sdk = { workspace = true, features = ["testing"] }
217217
goose-test-support = { path = "../goose-test-support" }
218+
bytes.workspace = true
219+
http.workspace = true
218220

219221
[[example]]
220222
name = "agent"

crates/goose/src/providers/ollama.rs

Lines changed: 189 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -68,14 +68,28 @@ fn resolve_ollama_num_ctx(model_config: &ModelConfig) -> Option<usize> {
6868
}
6969

7070
fn apply_ollama_options(payload: &mut Value, model_config: &ModelConfig) {
71-
let Some(limit) = resolve_ollama_num_ctx(model_config) else {
72-
return;
73-
};
74-
7571
if let Some(obj) = payload.as_object_mut() {
76-
let options = obj.entry("options").or_insert_with(|| json!({}));
77-
if let Some(options_obj) = options.as_object_mut() {
78-
options_obj.insert("num_ctx".to_string(), json!(limit));
72+
// Ollama does not support stream_options; remove it to prevent hangs.
73+
obj.remove("stream_options");
74+
75+
// Convert max_completion_tokens / max_tokens to Ollama's options.num_predict.
76+
// Reasoning models emit max_completion_tokens; non-reasoning models emit max_tokens.
77+
let max_tokens = obj
78+
.remove("max_completion_tokens")
79+
.or_else(|| obj.remove("max_tokens"));
80+
if let Some(max_tokens) = max_tokens {
81+
let options = obj.entry("options").or_insert_with(|| json!({}));
82+
if let Some(options_obj) = options.as_object_mut() {
83+
options_obj.entry("num_predict").or_insert(max_tokens);
84+
}
85+
}
86+
87+
// Apply num_ctx from context limit settings.
88+
if let Some(limit) = resolve_ollama_num_ctx(model_config) {
89+
let options = obj.entry("options").or_insert_with(|| json!({}));
90+
if let Some(options_obj) = options.as_object_mut() {
91+
options_obj.insert("num_ctx".to_string(), json!(limit));
92+
}
7993
}
8094
}
8195
}
@@ -300,9 +314,49 @@ impl Provider for OllamaProvider {
300314
}
301315
}
302316

317+
/// Per-chunk timeout for Ollama streaming responses.
318+
/// If no new raw SSE data arrives within this duration, the connection is considered dead.
319+
const OLLAMA_CHUNK_TIMEOUT_SECS: u64 = 30;
320+
321+
/// Wraps a line stream with a per-item timeout at the raw SSE level.
322+
/// This detects dead connections without false-positive stalls during long
323+
/// tool-call generations where response_to_streaming_message_ollama buffers.
324+
fn with_line_timeout(
325+
stream: impl futures::Stream<Item = anyhow::Result<String>> + Unpin + Send + 'static,
326+
timeout_secs: u64,
327+
) -> std::pin::Pin<Box<dyn futures::Stream<Item = anyhow::Result<String>> + Send>> {
328+
let timeout = Duration::from_secs(timeout_secs);
329+
Box::pin(try_stream! {
330+
let mut stream = stream;
331+
332+
// Allow time-to-first-token to be governed by the request timeout.
333+
// Only enforce per-chunk timeout after first SSE line arrives.
334+
match stream.next().await {
335+
Some(first_item) => yield first_item?,
336+
None => return,
337+
}
338+
loop {
339+
match tokio::time::timeout(timeout, stream.next()).await {
340+
Ok(Some(item)) => yield item?,
341+
Ok(None) => break,
342+
Err(_) => {
343+
Err::<(), anyhow::Error>(anyhow::anyhow!(
344+
"Ollama stream stalled: no data received for {}s. \
345+
This may indicate the model is overwhelmed by the request payload. \
346+
Try a smaller model or reduce the number of tools.",
347+
timeout_secs
348+
))?;
349+
}
350+
}
351+
}
352+
})
353+
}
354+
303355
/// Ollama-specific streaming handler with XML tool call fallback.
304356
/// Uses the Ollama format module which buffers text when XML tool calls are detected,
305357
/// preventing duplicate content from being emitted to the UI.
358+
/// Timeout is applied at the raw SSE line level via with_line_timeout so that
359+
/// buffering inside response_to_streaming_message_ollama does not cause false stalls.
306360
fn stream_ollama(response: Response, mut log: RequestLog) -> Result<MessageStream, ProviderError> {
307361
let stream = response.bytes_stream().map_err(std::io::Error::other);
308362

@@ -311,8 +365,10 @@ fn stream_ollama(response: Response, mut log: RequestLog) -> Result<MessageStrea
311365
let framed = FramedRead::new(stream_reader, LinesCodec::new())
312366
.map_err(Error::from);
313367

314-
let message_stream = response_to_streaming_message_ollama(framed);
368+
let timed_lines = with_line_timeout(framed, OLLAMA_CHUNK_TIMEOUT_SECS);
369+
let message_stream = response_to_streaming_message_ollama(timed_lines);
315370
pin!(message_stream);
371+
316372
while let Some(message) = message_stream.next().await {
317373
let (message, usage) = message.map_err(|e|
318374
ProviderError::RequestFailed(format!("Stream decode error: {}", e))
@@ -359,6 +415,131 @@ mod tests {
359415
assert!(payload.get("options").is_none());
360416
}
361417

418+
#[test]
419+
fn test_raw_create_request_contains_unsupported_ollama_fields() {
420+
use crate::providers::formats::ollama::create_request;
421+
use crate::providers::utils::ImageFormat;
422+
423+
let model_config = ModelConfig::new("llama3.1")
424+
.unwrap()
425+
.with_max_tokens(Some(4096));
426+
let messages = vec![crate::conversation::message::Message::user().with_text("hi")];
427+
428+
let payload = create_request(
429+
&model_config,
430+
"You are a helpful assistant.",
431+
&messages,
432+
&[],
433+
&ImageFormat::OpenAi,
434+
true,
435+
)
436+
.unwrap();
437+
438+
assert!(
439+
payload.get("stream_options").is_some(),
440+
"create_request should produce stream_options (unsupported by Ollama)"
441+
);
442+
assert!(
443+
payload.get("max_tokens").is_some(),
444+
"create_request should produce max_tokens (unsupported by Ollama)"
445+
);
446+
}
447+
448+
#[test]
449+
fn test_apply_ollama_options_strips_unsupported_fields() {
450+
use crate::providers::formats::ollama::create_request;
451+
use crate::providers::utils::ImageFormat;
452+
453+
let _guard = env_lock::lock_env([("GOOSE_INPUT_LIMIT", None::<&str>)]);
454+
let model_config = ModelConfig::new("llama3.1")
455+
.unwrap()
456+
.with_max_tokens(Some(4096));
457+
let messages = vec![crate::conversation::message::Message::user().with_text("hi")];
458+
459+
let mut payload = create_request(
460+
&model_config,
461+
"You are a helpful assistant.",
462+
&messages,
463+
&[],
464+
&ImageFormat::OpenAi,
465+
true,
466+
)
467+
.unwrap();
468+
469+
apply_ollama_options(&mut payload, &model_config);
470+
471+
assert!(
472+
payload.get("stream_options").is_none(),
473+
"stream_options should be removed for Ollama"
474+
);
475+
assert!(
476+
payload.get("max_tokens").is_none(),
477+
"max_tokens should be removed for Ollama"
478+
);
479+
assert!(
480+
payload.get("max_completion_tokens").is_none(),
481+
"max_completion_tokens should be removed for Ollama"
482+
);
483+
assert_eq!(
484+
payload["options"]["num_predict"], 4096,
485+
"max_tokens should be moved to options.num_predict"
486+
);
487+
assert_eq!(payload["stream"], true, "stream field should be preserved");
488+
}
489+
490+
#[tokio::test]
491+
async fn test_stream_ollama_timeout_on_stall() {
492+
use std::convert::Infallible;
493+
494+
let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, Infallible>>(1);
495+
tx.send(Ok(bytes::Bytes::from(
496+
"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"},\"index\":0}],\
497+
\"model\":\"test\",\"object\":\"chat.completion.chunk\",\"created\":0}\n",
498+
)))
499+
.await
500+
.unwrap();
501+
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
502+
let body = reqwest::Body::wrap_stream(stream);
503+
let response = http::Response::builder().status(200).body(body).unwrap();
504+
let response: reqwest::Response = response.into();
505+
506+
let log = RequestLog::start(
507+
&ModelConfig::new("test").unwrap(),
508+
&json!({"model": "test"}),
509+
)
510+
.unwrap();
511+
512+
let mut msg_stream = stream_ollama(response, log).unwrap();
513+
514+
let result =
515+
tokio::time::timeout(Duration::from_secs(OLLAMA_CHUNK_TIMEOUT_SECS + 5), async {
516+
let mut last_err = None;
517+
while let Some(item) = msg_stream.next().await {
518+
if let Err(e) = item {
519+
last_err = Some(e);
520+
break;
521+
}
522+
}
523+
last_err
524+
})
525+
.await;
526+
527+
match result {
528+
Ok(Some(err)) => {
529+
let err_msg = err.to_string();
530+
assert!(
531+
err_msg.contains("stream stalled"),
532+
"Expected stall timeout error, got: {}",
533+
err_msg
534+
);
535+
}
536+
Ok(None) => panic!("Expected timeout error but stream completed normally"),
537+
Err(_) => panic!("Outer timeout elapsed -- per-chunk timeout did not fire"),
538+
}
539+
540+
drop(tx);
541+
}
542+
362543
#[test]
363544
fn test_ollama_retry_config_is_transient_only() {
364545
let config = RetryConfig::new(

0 commit comments

Comments (0)