diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs
index adf7403d03..a93f966bb4 100644
--- a/src/agent/dispatcher.rs
+++ b/src/agent/dispatcher.rs
@@ -625,6 +625,31 @@ impl<'a> LoopDelegate for ChatDelegate<'a> {
         }
 
         let llm_call_start = std::time::Instant::now();
+
+        // Wire up real-time token streaming to the channel layer.
+        // A bounded channel caps memory usage when the consumer (channel
+        // layer) is slower than the LLM; the producer drops chunks on
+        // overflow via `try_send`.
+        {
+            let (chunk_tx, mut chunk_rx) =
+                tokio::sync::mpsc::channel::<String>(256);
+            let channels = Arc::clone(&self.agent.channels);
+            let channel_name = self.message.channel.clone();
+            let metadata = self.message.metadata.clone();
+            tokio::spawn(async move {
+                while let Some(chunk) = chunk_rx.recv().await {
+                    let _ = channels
+                        .send_status(
+                            &channel_name,
+                            crate::channels::StatusUpdate::StreamChunk(chunk),
+                            &metadata,
+                        )
+                        .await;
+                }
+            });
+            reason_ctx.chunk_sender = Some(chunk_tx);
+        }
+
         let output = match reasoning.respond_with_tools(reason_ctx).await {
             Ok(output) => output,
             Err(crate::error::LlmError::ContextLengthExceeded { used, limit }) => {
diff --git a/src/llm/circuit_breaker.rs b/src/llm/circuit_breaker.rs
index 4740fe02f0..211eb2b1ec 100644
--- a/src/llm/circuit_breaker.rs
+++ b/src/llm/circuit_breaker.rs
@@ -292,6 +292,42 @@ impl LlmProvider for CircuitBreakerProvider {
         }
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.check_allowed().await?;
+        match self.inner.complete_stream(request, on_chunk).await {
+            Ok(resp) => {
+                self.record_success().await;
+                Ok(resp)
+            }
+            Err(err) => {
+                self.record_failure(&err).await;
+                Err(err)
+            }
+        }
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.check_allowed().await?;
+        match self.inner.complete_with_tools_stream(request, on_chunk).await {
+            Ok(resp) => {
+                self.record_success().await;
+                Ok(resp)
+            }
+            Err(err) => {
+                self.record_failure(&err).await;
+                Err(err)
+            }
+        }
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.inner.list_models().await
     }
diff --git a/src/llm/failover.rs b/src/llm/failover.rs
index a23934d1e7..02e6ef7a2f 100644
--- a/src/llm/failover.rs
+++ b/src/llm/failover.rs
@@ -329,6 +329,26 @@ impl LlmProvider for FailoverProvider {
         Ok(response)
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.providers[self.last_used.load(Ordering::Relaxed)]
+            .complete_stream(request, on_chunk)
+            .await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.providers[self.last_used.load(Ordering::Relaxed)]
+            .complete_with_tools_stream(request, on_chunk)
+            .await
+    }
+
     fn active_model_name(&self) -> String {
         self.providers[self.last_used.load(Ordering::Relaxed)].active_model_name()
    }
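
The dispatcher hunk above hinges on a bounded mpsc channel: a spawned task forwards chunks to the channel layer, while the producer never blocks. A minimal, self-contained sketch of that pattern (the `print!` consumer stands in for the channel layer's `send_status` loop; all names here are illustrative, not dispatcher code):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Bounded buffer: at most 256 chunks queue up before the producer
    // starts dropping, which is the memory-safety argument in the comment above.
    let (chunk_tx, mut chunk_rx) = mpsc::channel::<String>(256);

    // Consumer task: stand-in for the channel layer forwarding StreamChunk updates.
    let consumer = tokio::spawn(async move {
        while let Some(chunk) = chunk_rx.recv().await {
            print!("{chunk}");
        }
    });

    // Producer side: a synchronous `FnMut(String)` callback, the shape the
    // streaming providers expect. `try_send` never blocks; a full or closed
    // channel simply drops the chunk.
    let mut on_chunk = move |chunk: String| {
        let _ = chunk_tx.try_send(chunk);
    };
    for token in ["Hello", ", ", "world"] {
        on_chunk(token.to_string());
    }
    drop(on_chunk); // closes the channel so the consumer task exits
    consumer.await.unwrap();
}
```
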
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 2019d23455..d641ab60a2 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -25,6 +25,7 @@ mod nearai_chat;
 pub mod oauth_helpers;
 pub mod openai_codex_provider;
 pub mod openai_codex_session;
+mod openai_compat_stream;
 mod provider;
 mod reasoning;
 pub mod recording;
@@ -293,7 +294,7 @@ fn create_openai_compat_from_registry(
         builder = builder.base_url(&base_url);
     }
     if !extra_headers.is_empty() {
-        builder = builder.http_headers(extra_headers);
+        builder = builder.http_headers(extra_headers.clone());
     }
 
     let client: openai::Client = builder.build().map_err(|e| LlmError::RequestFailed {
@@ -316,7 +317,41 @@ fn create_openai_compat_from_registry(
     let adapter = RigAdapter::new(model, &config.model)
         .with_unsupported_params(config.unsupported_params.clone());
 
-    Ok(Arc::new(adapter))
+    // Re-use the already-validated header map: iterate it to build the
+    // (String, String) pairs for the streaming provider, skipping any that
+    // produced warnings above.
+    let extra_headers_vec: Vec<(String, String)> = extra_headers
+        .iter()
+        .filter_map(|(name, value)| {
+            value
+                .to_str()
+                .ok()
+                .map(|v| (name.as_str().to_string(), v.to_string()))
+        })
+        .collect();
+    let unsupported: std::collections::HashSet<String> =
+        config.unsupported_params.iter().cloned().collect();
+    // Normalize the base_url the same way the rig-core client does so the
+    // streaming path hits the exact same endpoint.
+    let streaming_base_url = if config.base_url.is_empty() {
+        String::new()
+    } else {
+        normalize_openai_base_url(&config.base_url)
+    };
+    let streaming = openai_compat_stream::OpenAiCompatStreamingProvider::new(
+        Arc::new(adapter),
+        api_key,
+        streaming_base_url,
+        config.model.clone(),
+        config.provider_id.clone(),
+        extra_headers_vec,
+        unsupported,
+    )
+    .map_err(|e| LlmError::RequestFailed {
+        provider: config.provider_id.clone(),
+        reason: format!("Failed to build streaming HTTP client: {e}"),
+    })?;
+    Ok(Arc::new(streaming))
 }
 
 fn create_anthropic_from_registry(
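
The factory converts the already-validated reqwest `HeaderMap` into owned `(String, String)` pairs, silently dropping any value that is not `to_str()`-able. A standalone sketch of that conversion (the `headers_to_pairs` helper and the header names are hypothetical, mirroring the `filter_map` in the hunk above):

```rust
use reqwest::header::{HeaderMap, HeaderValue};

/// Convert a validated `HeaderMap` into owned (name, value) pairs, skipping
/// values that are not visible ASCII -- the shape the streaming provider takes.
fn headers_to_pairs(headers: &HeaderMap) -> Vec<(String, String)> {
    headers
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.as_str().to_string(), v.to_string()))
        })
        .collect()
}

fn main() {
    let mut headers = HeaderMap::new();
    headers.insert("x-title", HeaderValue::from_static("IronClaw"));
    // Bytes in 0x80..=0xFF are legal header octets but fail `to_str()`,
    // so this entry is dropped by the conversion.
    headers.insert("x-binary", HeaderValue::from_bytes(&[0xC3, 0x28]).unwrap());
    for (name, value) in headers_to_pairs(&headers) {
        println!("{name}: {value}");
    }
}
```
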
diff --git a/src/llm/openai_compat_stream.rs b/src/llm/openai_compat_stream.rs
new file mode 100644
index 0000000000..3799b6b524
--- /dev/null
+++ b/src/llm/openai_compat_stream.rs
@@ -0,0 +1,588 @@
+//! Streaming-capable wrapper for OpenAI-compatible Chat Completions providers.
+//!
+//! Wraps an existing [`LlmProvider`] (typically a [`RigAdapter`]) and overrides
+//! [`complete_stream`] / [`complete_with_tools_stream`] with real SSE streaming
+//! via a direct HTTP POST to the provider's `/v1/chat/completions` endpoint.
+//! All non-streaming methods are forwarded to the inner provider unchanged.
+//!
+//! This is used for registry providers with protocol `OpenAiCompletions`
+//! (OpenRouter, Groq, NVIDIA NIM, etc.) where the upstream endpoint supports
+//! the standard `"stream": true` / SSE delta format.
+
+use std::collections::{BTreeMap, HashSet};
+use std::sync::Arc;
+use std::time::Duration;
+
+use async_trait::async_trait;
+use eventsource_stream::Eventsource;
+use futures::StreamExt;
+use rust_decimal::Decimal;
+
+use crate::llm::error::LlmError;
+use crate::llm::provider::{
+    sanitize_tool_messages, ChatMessage, CompletionRequest, CompletionResponse, FinishReason,
+    LlmProvider, Role, ToolCall, ToolCompletionRequest, ToolCompletionResponse, ToolDefinition,
+};
+
+/// Wraps any [`LlmProvider`] backed by an OpenAI-compatible endpoint and adds
+/// real token-level SSE streaming.
+///
+/// Non-streaming calls (`complete`, `complete_with_tools`) are delegated to
+/// the inner provider. Streaming calls bypass the inner provider and POST
+/// directly to `base_url/chat/completions` with `"stream": true`, then parse
+/// the OpenAI SSE delta protocol.
+pub struct OpenAiCompatStreamingProvider {
+    inner: Arc<dyn LlmProvider>,
+    client: reqwest::Client,
+    base_url: String,
+    api_key: String,
+    model_name: String,
+    /// Human-readable provider identifier used in error attribution
+    /// (e.g. "openrouter", "groq").
+    provider_id: String,
+    /// Raw (key, value) pairs sent as additional HTTP headers on every request.
+    extra_headers: Vec<(String, String)>,
+    /// Parameter names that this provider does not accept (e.g. `"temperature"`).
+    unsupported_params: HashSet<String>,
+}
+
+impl OpenAiCompatStreamingProvider {
+    pub fn new(
+        inner: Arc<dyn LlmProvider>,
+        api_key: String,
+        base_url: String,
+        model_name: String,
+        provider_id: String,
+        extra_headers: Vec<(String, String)>,
+        unsupported_params: HashSet<String>,
+    ) -> Result<Self, reqwest::Error> {
+        // `connect_timeout` bounds the TCP handshake; `timeout` bounds the
+        // total duration of a single streaming request (including reading
+        // the full SSE stream) so a hung upstream cannot leak tasks forever.
+        let client = reqwest::Client::builder()
+            .connect_timeout(Duration::from_secs(30))
+            .timeout(Duration::from_secs(600))
+            .build()?;
+        Ok(Self {
+            inner,
+            client,
+            base_url,
+            api_key,
+            model_name,
+            provider_id,
+            extra_headers,
+            unsupported_params,
+        })
+    }
+
+    fn completions_url(&self) -> String {
+        // Empty base_url → OpenAI default (matches rig-core behavior).
+        // Every provider's base_url already includes the API version prefix
+        // (e.g. `/v1`, `/api/v1`, `/v1beta/openai`), so just append the path.
+        let base = self.base_url.trim_end_matches('/');
+        let base = if base.is_empty() {
+            "https://api.openai.com/v1"
+        } else {
+            base
+        };
+        format!("{}/chat/completions", base)
+    }
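
To make the joining rule in `completions_url` concrete, here is the same logic as a free function with the two interesting cases, an empty base URL and a trailing slash; the OpenRouter URL is just an example value:

```rust
/// Mirror of the URL-joining rule above, as a standalone function for illustration.
fn completions_url(base_url: &str) -> String {
    let base = base_url.trim_end_matches('/');
    let base = if base.is_empty() {
        "https://api.openai.com/v1"
    } else {
        base
    };
    format!("{}/chat/completions", base)
}

fn main() {
    // An empty base_url falls back to the OpenAI default endpoint.
    assert_eq!(
        completions_url(""),
        "https://api.openai.com/v1/chat/completions"
    );
    // A trailing slash does not produce a double slash in the final URL.
    assert_eq!(
        completions_url("https://openrouter.ai/api/v1/"),
        "https://openrouter.ai/api/v1/chat/completions"
    );
}
```
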
5xx response", + ); + LlmError::BadGateway { + provider: self.provider_id.clone(), + status: code, + retry_after, + } + } + _ => LlmError::RequestFailed { + provider: self.provider_id.clone(), + reason: format!("HTTP {}: {}", status, truncated), + }, + }); + } + + let mut result = OaiStreamResult::default(); + // BTreeMap keyed by tool_call index — OpenAI streams tool_call arguments + // as incremental string deltas that must be concatenated in order. + let mut tool_acc: BTreeMap = BTreeMap::new(); + + let stream = response + .bytes_stream() + .map(|chunk| chunk.map_err(|e| e.to_string())); + let mut event_stream = stream.eventsource(); + + while let Some(event) = event_stream.next().await { + let event = event.map_err(|e| LlmError::RequestFailed { + provider: self.provider_id.clone(), + reason: format!("SSE stream error: {}", e), + })?; + + let data = event.data.trim(); + if data == "[DONE]" { + break; + } + if data.is_empty() { + continue; + } + + let parsed: serde_json::Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(_) => continue, + }; + + if let Some(choices) = parsed.get("choices").and_then(|c| c.as_array()) + && let Some(choice) = choices.first() + { + if let Some(fr) = choice.get("finish_reason").and_then(|v| v.as_str()) { + result.finish_reason = match fr { + "stop" => FinishReason::Stop, + "length" => FinishReason::Length, + "tool_calls" => FinishReason::ToolUse, + "content_filter" => FinishReason::ContentFilter, + _ => result.finish_reason, + }; + } + + if let Some(delta) = choice.get("delta") { + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) + && !content.is_empty() + { + result.content.push_str(content); + on_chunk(content.to_string()); + } + + if let Some(tcs) = delta.get("tool_calls").and_then(|tc| tc.as_array()) { + for tc in tcs { + let idx = tc + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + let entry = tool_acc.entry(idx).or_insert_with(|| PartialTool { + index: idx, + ..Default::default() + }); + if let Some(id) = tc.get("id").and_then(|v| v.as_str()) + && !id.is_empty() + { + entry.id = id.to_string(); + } + if let Some(func) = tc.get("function") { + if let Some(name) = + func.get("name").and_then(|v| v.as_str()) + && !name.is_empty() + { + entry.name = name.to_string(); + } + if let Some(args) = + func.get("arguments").and_then(|v| v.as_str()) + { + entry.arguments.push_str(args); + } + } + } + } + } + } + + // Usage is typically in the last chunk when stream_options.include_usage is set. + if let Some(usage) = parsed.get("usage") { + result.input_tokens = saturate_u32( + usage + .get("prompt_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + ); + result.output_tokens = saturate_u32( + usage + .get("completion_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + ); + } + } + + result.tool_calls = tool_acc + .into_values() + .filter(|p| !p.name.is_empty()) + .map(|p| { + // Prefer parsed JSON; on parse failure preserve the raw string + // (wrapped as JSON string) so the downstream tool executor can + // surface the actual malformed payload instead of a silent {}. 
+// ── Internal helpers ─────────────────────────────────────────────────────────
+
+#[derive(Debug, Default)]
+struct PartialTool {
+    index: u32,
+    id: String,
+    name: String,
+    arguments: String,
+}
+
+#[derive(Debug)]
+struct OaiStreamResult {
+    content: String,
+    tool_calls: Vec<ToolCall>,
+    finish_reason: FinishReason,
+    input_tokens: u32,
+    output_tokens: u32,
+}
+
+impl Default for OaiStreamResult {
+    fn default() -> Self {
+        Self {
+            content: String::new(),
+            tool_calls: Vec::new(),
+            finish_reason: FinishReason::Unknown,
+            input_tokens: 0,
+            output_tokens: 0,
+        }
+    }
+}
+
+fn saturate_u32(v: u64) -> u32 {
+    v.min(u32::MAX as u64) as u32
+}
+
+/// Serialize IronClaw [`ChatMessage`]s into OpenAI Chat Completions JSON format.
+fn messages_to_json(messages: &[ChatMessage]) -> Vec<serde_json::Value> {
+    messages
+        .iter()
+        .map(|msg| {
+            let role = match msg.role {
+                Role::System => "system",
+                Role::User => "user",
+                Role::Assistant => "assistant",
+                Role::Tool => "tool",
+            };
+
+            // Multimodal: serialize content as an array of parts; text-only: plain string.
+            // Assistant messages with tool_calls and empty text use null content.
+            let content: serde_json::Value = if !msg.content_parts.is_empty() {
+                let mut parts =
+                    vec![serde_json::json!({"type": "text", "text": msg.content})];
+                for p in &msg.content_parts {
+                    match serde_json::to_value(p) {
+                        Ok(v) => parts.push(v),
+                        Err(e) => tracing::warn!(
+                            role = %role,
+                            error = %e,
+                            "Failed to serialize content part; skipping",
+                        ),
+                    }
+                }
+                serde_json::Value::Array(parts)
+            } else if role == "assistant"
+                && msg.tool_calls.is_some()
+                && msg.content.is_empty()
+            {
+                serde_json::Value::Null
+            } else {
+                serde_json::Value::String(msg.content.clone())
+            };
+
+            let mut obj = serde_json::json!({"role": role, "content": content});
+
+            if let Some(id) = &msg.tool_call_id {
+                obj["tool_call_id"] = serde_json::json!(id);
+            }
+            if let Some(name) = &msg.name {
+                obj["name"] = serde_json::json!(name);
+            }
+            if let Some(tcs) = &msg.tool_calls {
+                let arr: Vec<serde_json::Value> = tcs
+                    .iter()
+                    .map(|tc| {
+                        serde_json::json!({
+                            "id": tc.id,
+                            "type": "function",
+                            "function": {
+                                "name": tc.name,
+                                "arguments": tc.arguments.to_string(),
+                            },
+                        })
+                    })
+                    .collect();
+                obj["tool_calls"] = serde_json::Value::Array(arr);
+            }
+
+            obj
+        })
+        .collect()
+}
+
+/// Serialize IronClaw [`ToolDefinition`]s into OpenAI Chat Completions JSON format.
+///
+/// Schemas are run through [`normalize_schema_strict`] so top-level
+/// `oneOf`/`anyOf`/`allOf`/`enum`/`not` (which OpenAI rejects with
+/// `invalid_function_parameters`) are flattened into a permissive object
+/// envelope. The non-streaming rig-based path normalizes via the same helper
+/// inside `RigAdapter::convert_tools`; this keeps the streaming path in sync.
+fn tools_to_json(tools: &[ToolDefinition]) -> Vec<serde_json::Value> {
+    tools
+        .iter()
+        .map(|t| {
+            let mut description = t.description.clone();
+            let parameters =
+                crate::llm::rig_adapter::normalize_schema_strict(&t.parameters, &mut description);
+            serde_json::json!({
+                "type": "function",
+                "function": {
+                    "name": t.name,
+                    "description": description,
+                    "parameters": parameters,
+                },
+            })
+        })
+        .collect()
+}
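
For reference, this is roughly the wire shape a single entry of the `tools` array takes after schema normalization; the `get_weather` tool and its schema are made up for illustration:

```rust
fn main() {
    // One element of the "tools" array, as sent alongside "stream": true.
    let tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": { "type": "string" }
                },
                "required": ["city"]
            }
        }
    });
    println!("{}", serde_json::to_string_pretty(&tool).unwrap());
}
```
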
+// ── LlmProvider impl ─────────────────────────────────────────────────────────
+
+#[async_trait]
+impl LlmProvider for OpenAiCompatStreamingProvider {
+    fn model_name(&self) -> &str {
+        &self.model_name
+    }
+
+    fn cost_per_token(&self) -> (Decimal, Decimal) {
+        self.inner.cost_per_token()
+    }
+
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
+        self.inner.complete(request).await
+    }
+
+    async fn complete_with_tools(
+        &self,
+        request: ToolCompletionRequest,
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.inner.complete_with_tools(request).await
+    }
+
+    async fn complete_stream(
+        &self,
+        mut req: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        let model = req
+            .take_model_override()
+            .unwrap_or_else(|| self.model_name.clone());
+        // Match RigAdapter behavior: rewrite orphaned tool_result messages as
+        // user messages so OpenAI-compatible endpoints do not reject the
+        // request with 400 "messages with role 'tool' must be a response to
+        // a preceding message with 'tool_calls'".
+        sanitize_tool_messages(&mut req.messages);
+        let messages = messages_to_json(&req.messages);
+
+        let mut body = serde_json::json!({
+            "model": model,
+            "messages": messages,
+            "stream": true,
+            "stream_options": {"include_usage": true},
+        });
+
+        if !self.unsupported_params.contains("temperature") {
+            if let Some(t) = req.temperature {
+                body["temperature"] = serde_json::json!(t);
+            }
+        }
+        if !self.unsupported_params.contains("max_tokens") {
+            if let Some(mt) = req.max_tokens {
+                body["max_tokens"] = serde_json::json!(mt);
+            }
+        }
+        if !self.unsupported_params.contains("stop_sequences")
+            && let Some(stop) = req.stop_sequences
+            && !stop.is_empty()
+        {
+            body["stop"] = serde_json::json!(stop);
+        }
+
+        let result = self.stream_request(body, on_chunk).await?;
+
+        Ok(CompletionResponse {
+            content: result.content,
+            finish_reason: result.finish_reason,
+            input_tokens: result.input_tokens,
+            output_tokens: result.output_tokens,
+            cache_read_input_tokens: 0,
+            cache_creation_input_tokens: 0,
+        })
+    }
+    async fn complete_with_tools_stream(
+        &self,
+        mut req: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        let model = req
+            .take_model_override()
+            .unwrap_or_else(|| self.model_name.clone());
+        sanitize_tool_messages(&mut req.messages);
+        let messages = messages_to_json(&req.messages);
+        let tools = tools_to_json(&req.tools);
+
+        let mut body = serde_json::json!({
+            "model": model,
+            "messages": messages,
+            "tools": tools,
+            "stream": true,
+            "stream_options": {"include_usage": true},
+        });
+
+        if let Some(tc) = req.tool_choice {
+            body["tool_choice"] = serde_json::json!(tc);
+        }
+        if !self.unsupported_params.contains("temperature") {
+            if let Some(t) = req.temperature {
+                body["temperature"] = serde_json::json!(t);
+            }
+        }
+        if !self.unsupported_params.contains("max_tokens") {
+            if let Some(mt) = req.max_tokens {
+                body["max_tokens"] = serde_json::json!(mt);
+            }
+        }
+        if !self.unsupported_params.contains("stop_sequences")
+            && let Some(stop) = req.stop_sequences
+            && !stop.is_empty()
+        {
+            body["stop"] = serde_json::json!(stop);
+        }
+
+        let result = self.stream_request(body, on_chunk).await?;
+
+        let content = if !result.content.is_empty() {
+            Some(result.content)
+        } else {
+            None
+        };
+
+        Ok(ToolCompletionResponse {
+            content,
+            tool_calls: result.tool_calls,
+            finish_reason: result.finish_reason,
+            input_tokens: result.input_tokens,
+            output_tokens: result.output_tokens,
+            cache_read_input_tokens: 0,
+            cache_creation_input_tokens: 0,
+        })
+    }
+
+    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
+        self.inner.list_models().await
+    }
+
+    fn active_model_name(&self) -> String {
+        self.inner.active_model_name()
+    }
+
+    fn set_model(&self, model: &str) -> Result<(), LlmError> {
+        self.inner.set_model(model)
+    }
+}
diff --git a/src/llm/provider.rs b/src/llm/provider.rs
index 3daffdc3a3..aef9ccafbe 100644
--- a/src/llm/provider.rs
+++ b/src/llm/provider.rs
@@ -421,6 +421,34 @@ pub trait LlmProvider: Send + Sync {
         request: ToolCompletionRequest,
     ) -> Result<ToolCompletionResponse, LlmError>;
 
+    /// Stream a chat completion, calling `on_chunk` for each text delta.
+    /// Default implementation falls back to a single-chunk (non-streaming) call.
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        let resp = self.complete(request).await?;
+        on_chunk(resp.content.clone());
+        Ok(resp)
+    }
+
+    /// Stream a tool-enabled completion, calling `on_chunk` for each text delta.
+    /// Default implementation falls back to a single-chunk (non-streaming) call.
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        let resp = self.complete_with_tools(request).await?;
+        if let Some(ref content) = resp.content {
+            if !content.is_empty() {
+                on_chunk(content.clone());
+            }
+        }
+        Ok(resp)
+    }
+
     /// List available models from the provider.
     /// Default implementation returns empty list.
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
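
The default-method contract above means a provider that never overrides the streaming methods still works: callers simply receive the whole response as one chunk. A deliberately tiny, self-contained model of that pattern (`MiniProvider` and `Echo` are illustrative, not types from this codebase):

```rust
use async_trait::async_trait;

#[async_trait]
trait MiniProvider: Send + Sync {
    async fn complete(&self, prompt: &str) -> String;

    // Default streaming implementation: delegate to the blocking call and
    // deliver the entire response as a single chunk.
    async fn complete_stream(
        &self,
        prompt: &str,
        on_chunk: &mut (dyn FnMut(String) + Send),
    ) -> String {
        let resp = self.complete(prompt).await;
        on_chunk(resp.clone());
        resp
    }
}

struct Echo;

#[async_trait]
impl MiniProvider for Echo {
    async fn complete(&self, prompt: &str) -> String {
        format!("echo: {prompt}")
    }
    // No complete_stream override: callers still get exactly one chunk.
}

#[tokio::main]
async fn main() {
    let mut chunks = Vec::new();
    let full = Echo.complete_stream("hi", &mut |c| chunks.push(c)).await;
    assert_eq!(chunks, vec!["echo: hi".to_string()]);
    assert_eq!(full, "echo: hi");
}
```
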
diff --git a/src/llm/reasoning.rs b/src/llm/reasoning.rs
index 8b54477e3d..f712c058df 100644
--- a/src/llm/reasoning.rs
+++ b/src/llm/reasoning.rs
@@ -266,6 +266,12 @@ pub struct ReasoningContext {
     /// batch failed. Used by the duplicate tool call tracker in the agentic loop.
     /// Reset to `false` at the start of each iteration.
     pub last_tool_batch_all_failed: bool,
+    /// When set, each text token/chunk from streaming LLM calls is sent to this
+    /// channel so callers can forward it to the client in real time.
+    ///
+    /// Typically backed by a bounded channel to cap memory usage; chunks that
+    /// cannot be queued may be dropped when the buffer is full.
+    pub chunk_sender: Option<tokio::sync::mpsc::Sender<String>>,
 }
 
 impl ReasoningContext {
@@ -282,6 +288,7 @@ impl ReasoningContext {
             model_override: None,
             temperature: None,
             last_tool_batch_all_failed: false,
+            chunk_sender: None,
         }
     }
 
@@ -327,6 +334,15 @@ impl ReasoningContext {
         self.temperature = Some(temperature);
         self
     }
+
+    /// Set a chunk sender for real-time token streaming to the client.
+    pub fn with_chunk_sender(
+        mut self,
+        sender: tokio::sync::mpsc::Sender<String>,
+    ) -> Self {
+        self.chunk_sender = Some(sender);
+        self
+    }
 }
 
 impl Default for ReasoningContext {
@@ -805,7 +821,19 @@ Respond in JSON format:
             request.model = Some(model.clone());
         }
 
-        let response = self.llm.complete_with_tools(request).await?;
+        let response = if let Some(ref sender) = context.chunk_sender {
+            let sender = sender.clone();
+            let mut on_chunk = move |chunk: String| {
+                // Non-blocking send: drop on a full or closed channel so the
+                // sync callback never blocks the streaming task.
+                if let Err(e) = sender.try_send(chunk) {
+                    tracing::trace!(error = %e, "stream chunk dropped (channel full/closed)");
+                }
+            };
+            self.llm.complete_with_tools_stream(request, &mut on_chunk).await?
+        } else {
+            self.llm.complete_with_tools(request).await?
+        };
         let usage = TokenUsage {
             input_tokens: response.input_tokens,
             output_tokens: response.output_tokens,
@@ -921,7 +949,17 @@ Respond in JSON format:
             request.model = Some(model.clone());
         }
 
-        let response = self.llm.complete(request).await?;
+        let response = if let Some(ref sender) = context.chunk_sender {
+            let sender = sender.clone();
+            let mut on_chunk = move |chunk: String| {
+                if let Err(e) = sender.try_send(chunk) {
+                    tracing::trace!(error = %e, "stream chunk dropped (channel full/closed)");
+                }
+            };
+            self.llm.complete_stream(request, &mut on_chunk).await?
+        } else {
+            self.llm.complete(request).await?
+        };
         let pre_truncated = truncate_at_tool_tags(&response.content);
         let cleaned = clean_response(&pre_truncated);
         let metadata = if cleaned.trim().is_empty() {
diff --git a/src/llm/recording.rs b/src/llm/recording.rs
index d01ee2f591..18b2b81611 100644
--- a/src/llm/recording.rs
+++ b/src/llm/recording.rs
@@ -1115,6 +1115,23 @@ impl LlmProvider for RecordingLlm {
         Ok(response)
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        // Streaming bypasses recording (chunks can't be replayed as a sequence).
+        self.inner.complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.inner.complete_with_tools_stream(request, on_chunk).await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.inner.list_models().await
     }
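
The reasoning-loop callbacks rely on `try_send` never blocking: when the bounded buffer is full, the chunk is simply dropped and the stream keeps flowing. A tiny standalone demonstration of that semantics with a capacity-1 channel:

```rust
use tokio::sync::mpsc::{channel, error::TrySendError};

#[tokio::main]
async fn main() {
    // Capacity of 1 so the second send overflows immediately.
    let (tx, mut rx) = channel::<String>(1);

    assert!(tx.try_send("first".into()).is_ok());
    // The buffer is full and nothing has been received yet: the chunk is
    // dropped, but the callback returns immediately instead of blocking.
    assert!(matches!(
        tx.try_send("second".into()),
        Err(TrySendError::Full(_))
    ));

    assert_eq!(rx.recv().await.as_deref(), Some("first"));
}
```
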
diff --git a/src/llm/response_cache.rs b/src/llm/response_cache.rs
index d7746f606b..d60be87cda 100644
--- a/src/llm/response_cache.rs
+++ b/src/llm/response_cache.rs
@@ -275,6 +275,23 @@ impl LlmProvider for CachedProvider {
         self.inner.complete_with_tools(request).await
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        // Bypass the cache for streaming: real-time token delivery takes priority.
+        self.inner.complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.inner.complete_with_tools_stream(request, on_chunk).await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.inner.list_models().await
     }
diff --git a/src/llm/retry.rs b/src/llm/retry.rs
index cc4af3af6f..c83daed330 100644
--- a/src/llm/retry.rs
+++ b/src/llm/retry.rs
@@ -253,6 +253,22 @@ impl LlmProvider for RetryProvider {
             .await
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.inner.complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.inner.complete_with_tools_stream(request, on_chunk).await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.inner.list_models().await
     }
diff --git a/src/llm/rig_adapter.rs b/src/llm/rig_adapter.rs
index 13f9e24ec0..8cd380e4f8 100644
--- a/src/llm/rig_adapter.rs
+++ b/src/llm/rig_adapter.rs
@@ -802,6 +802,16 @@ fn normalized_tool_call_id(raw: Option<&str>, seed: usize) -> String {
     super::provider::generate_tool_call_id(seed, 0)
 }
 
+/// Normalize a streamed tool-call ID into the `[a-zA-Z0-9]{9}` shape that
+/// OpenAI-compatible backends require. Delegates to [`normalized_tool_call_id`].
+///
+/// `index` is the tool_call's position in the stream and is used as a
+/// deterministic seed when the upstream ID is absent or non-conforming.
+pub(crate) fn normalize_tool_call_id_for_streaming(raw: &str, index: usize) -> String {
+    let raw_opt = if raw.is_empty() { None } else { Some(raw) };
+    normalized_tool_call_id(raw_opt, index)
+}
+
 /// Convert IronClaw tool definitions to rig-core format.
 ///
 /// Applies `normalize_schema_strict` at the boundary, which both
diff --git a/src/llm/runtime.rs b/src/llm/runtime.rs
index 8d417b1265..57bf57c7b0 100644
--- a/src/llm/runtime.rs
+++ b/src/llm/runtime.rs
@@ -200,6 +200,22 @@ impl LlmProvider for SwappableLlmProvider {
         self.current().complete_with_tools(request).await
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.current().complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.current().complete_with_tools_stream(request, on_chunk).await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.current().list_models().await
     }
diff --git a/src/llm/smart_routing.rs b/src/llm/smart_routing.rs
index 463bf5a0ae..8df69749d6 100644
--- a/src/llm/smart_routing.rs
+++ b/src/llm/smart_routing.rs
@@ -985,6 +985,22 @@ impl LlmProvider for SmartRoutingProvider {
         self.primary.complete_with_tools(request).await
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.primary.complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.primary.complete_with_tools_stream(request, on_chunk).await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.primary.list_models().await
     }
diff --git a/src/llm/token_refreshing.rs b/src/llm/token_refreshing.rs
index c39ad3243c..945cca6645 100644
--- a/src/llm/token_refreshing.rs
+++ b/src/llm/token_refreshing.rs
@@ -109,6 +109,28 @@ impl LlmProvider for TokenRefreshingProvider {
         }
     }
 
+    async fn complete_stream(
+        &self,
+        request: CompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<CompletionResponse, LlmError> {
+        self.ensure_fresh_token().await;
+        // Streaming requests are not retried on auth failure (the partial
+        // chunks already sent can't be un-sent); delegate straight through.
+        self.inner.complete_stream(request, on_chunk).await
+    }
+
+    async fn complete_with_tools_stream(
+        &self,
+        request: ToolCompletionRequest,
+        on_chunk: &mut (dyn FnMut(String) + Send),
+    ) -> Result<ToolCompletionResponse, LlmError> {
+        self.ensure_fresh_token().await;
+        self.inner
+            .complete_with_tools_stream(request, on_chunk)
+            .await
+    }
+
     async fn list_models(&self) -> Result<Vec<String>, LlmError> {
         self.ensure_fresh_token().await;
         self.inner.list_models().await
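
The rig_adapter hunk describes the normalization rule without showing `normalized_tool_call_id` / `generate_tool_call_id`, whose exact output format is not part of this diff. The following is an illustrative stand-in only, showing the behavior the doc comment describes: keep an upstream ID that already matches `[a-zA-Z0-9]{9}`, otherwise derive a deterministic replacement from the tool call's stream index.

```rust
/// Hypothetical sketch of the normalization rule; not the project's helper.
fn normalize_id_sketch(raw: &str, index: usize) -> String {
    let conforming = raw.len() == 9 && raw.chars().all(|c| c.is_ascii_alphanumeric());
    if conforming {
        raw.to_string()
    } else {
        // Deterministic: the same (missing or malformed) input and index
        // always map to the same replacement ID.
        format!("call{:05}", index % 100_000)
    }
}

fn main() {
    assert_eq!(normalize_id_sketch("abc123XYZ", 0), "abc123XYZ");
    assert_eq!(normalize_id_sketch("", 2), "call00002");
    assert_eq!(normalize_id_sketch("call_with_underscores", 2), "call00002");
}
```
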