diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index a1e30eabcb09..b7148f8864a5 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -1,10 +1,11 @@ use anyhow::Result; use async_trait::async_trait; +use futures::future::BoxFuture; use rmcp::model::Role; use serde_json::{json, Value}; use std::path::PathBuf; use std::process::Stdio; -use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::process::Command; use super::base::{ConfigKey, Provider, ProviderDef, ProviderMetadata, ProviderUsage, Usage}; @@ -16,7 +17,6 @@ use crate::config::{Config, GooseMode}; use crate::conversation::message::{Message, MessageContent}; use crate::model::ModelConfig; use crate::subprocess::configure_command_no_window; -use futures::future::BoxFuture; use rmcp::model::Tool; const CLAUDE_CODE_PROVIDER_NAME: &str = "claude-code"; @@ -24,12 +24,35 @@ pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "claude-sonnet-4-20250514"; pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["sonnet", "opus"]; pub const CLAUDE_CODE_DOC_URL: &str = "https://code.claude.com/docs/en/setup"; +#[derive(Debug)] +struct CliProcess { + child: tokio::process::Child, + stdin: tokio::process::ChildStdin, + reader: BufReader, + #[allow(dead_code)] + stderr_handle: tokio::task::JoinHandle, + messages_sent: usize, +} + +impl Drop for CliProcess { + fn drop(&mut self) { + let _ = self.child.start_kill(); + } +} + +/// Spawns the Claude Code CLI (`claude`) as a persistent child process using +/// `--input-format stream-json --output-format stream-json`. The CLI stays alive +/// across turns, maintaining conversation state internally. Messages are sent as +/// NDJSON on stdin with content arrays supporting text and image blocks. Responses +/// are NDJSON on stdout (`assistant` + `result` events per turn). #[derive(Debug, serde::Serialize)] pub struct ClaudeCodeProvider { command: PathBuf, model: ModelConfig, #[serde(skip)] name: String, + #[serde(skip)] + cli_process: tokio::sync::OnceCell>, } impl ClaudeCodeProvider { @@ -42,76 +65,60 @@ impl ClaudeCodeProvider { command: resolved_command, model, name: CLAUDE_CODE_PROVIDER_NAME.to_string(), + cli_process: tokio::sync::OnceCell::new(), }) } - /// Convert goose messages to the format expected by claude CLI - fn messages_to_claude_format(&self, _system: &str, messages: &[Message]) -> Result { - let mut claude_messages = Vec::new(); - + /// Build Anthropic content blocks from goose messages, supporting text and images. + fn messages_to_content_blocks(&self, messages: &[Message]) -> Vec { + let mut blocks: Vec = Vec::new(); for message in messages.iter().filter(|m| m.is_agent_visible()) { - let role = match message.role { - Role::User => "user", - Role::Assistant => "assistant", + let prefix = match message.role { + Role::User => "Human: ", + Role::Assistant => "Assistant: ", }; - - let mut content_parts = Vec::new(); + let mut text_parts = Vec::new(); for content in &message.content { match content { - MessageContent::Text(text_content) => { - content_parts.push(json!({ - "type": "text", - "text": text_content.text - })); + MessageContent::Text(t) => text_parts.push(t.text.clone()), + MessageContent::Image(img) => { + if !text_parts.is_empty() { + blocks.push(json!({"type":"text","text":format!("{}{}", prefix, text_parts.join("\n"))})); + text_parts.clear(); + } + blocks.push(json!({"type":"image","source":{"type":"base64","media_type":img.mime_type,"data":img.data}})); } - MessageContent::ToolRequest(tool_request) => { - if let Ok(tool_call) = &tool_request.tool_call { - content_parts.push(json!({ - "type": "tool_use", - "id": tool_request.id, - "name": tool_call.name, - "input": tool_call.arguments - })); + MessageContent::ToolRequest(req) => { + if let Ok(call) = &req.tool_call { + text_parts.push(format!("[tool_use: {} id={}]", call.name, req.id)); } } - MessageContent::ToolResponse(tool_response) => { - if let Ok(result) = &tool_response.tool_result { - // Convert tool result contents to text - let content_text = result + MessageContent::ToolResponse(resp) => { + if let Ok(result) = &resp.tool_result { + let text: String = result .content .iter() - .filter_map(|content| match &content.raw { - rmcp::model::RawContent::Text(text_content) => { - Some(text_content.text.as_str()) - } + .filter_map(|c| match &c.raw { + rmcp::model::RawContent::Text(t) => Some(t.text.as_str()), _ => None, }) .collect::>() .join("\n"); - - content_parts.push(json!({ - "type": "tool_result", - "tool_use_id": tool_response.id, - "content": content_text - })); + text_parts.push(format!("[tool_result id={}] {}", resp.id, text)); } } - _ => { - // Skip other content types for now - } + _ => {} } } - - claude_messages.push(json!({ - "role": role, - "content": content_parts - })); + if !text_parts.is_empty() { + blocks.push( + json!({"type":"text","text":format!("{}{}", prefix, text_parts.join("\n"))}), + ); + } } - - Ok(json!(claude_messages)) + blocks } - /// Parse the JSON response from claude CLI fn apply_permission_flags(cmd: &mut Command) -> Result<(), ProviderError> { let config = Config::global(); let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); @@ -139,6 +146,7 @@ impl ClaudeCodeProvider { Ok(()) } + /// Parse NDJSON stream-json response from Claude CLI fn parse_claude_response( &self, json_lines: &[String], @@ -146,33 +154,23 @@ impl ClaudeCodeProvider { let mut all_text_content = Vec::new(); let mut usage = Usage::default(); - // Join all lines and parse as a single JSON array - let full_response = json_lines.join(""); - let json_array: Vec = serde_json::from_str(&full_response).map_err(|e| { - ProviderError::RequestFailed(format!("Failed to parse JSON response: {}", e)) - })?; - - for parsed in json_array { - if let Some(msg_type) = parsed.get("type").and_then(|t| t.as_str()) { - match msg_type { - "assistant" => { + for line in json_lines { + if let Ok(parsed) = serde_json::from_str::(line) { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("assistant") => { if let Some(message) = parsed.get("message") { // Extract text content from this assistant message if let Some(content) = message.get("content").and_then(|c| c.as_array()) { for item in content { - if let Some(content_type) = - item.get("type").and_then(|t| t.as_str()) - { - if content_type == "text" { - if let Some(text) = - item.get("text").and_then(|t| t.as_str()) - { - all_text_content.push(text.to_string()); - } + if item.get("type").and_then(|t| t.as_str()) == Some("text") { + if let Some(text) = + item.get("text").and_then(|t| t.as_str()) + { + all_text_content.push(text.to_string()); } - // Skip tool_use - those are claude CLI's internal tools } + // Skip tool_use - those are claude CLI's internal tools } } @@ -187,7 +185,6 @@ impl ClaudeCodeProvider { .and_then(|v| v.as_i64()) .map(|v| v as i32); - // Calculate total if not provided if usage.total_tokens.is_none() { if let (Some(input), Some(output)) = (usage.input_tokens, usage.output_tokens) @@ -198,7 +195,7 @@ impl ClaudeCodeProvider { } } } - "result" => { + Some("result") => { // Extract additional usage info from result if available if let Some(result_usage) = parsed.get("usage") { if usage.input_tokens.is_none() { @@ -215,7 +212,23 @@ impl ClaudeCodeProvider { } } } - _ => {} // Ignore other message types + Some("error") => { + let error_msg = parsed + .get("error") + .and_then(|e| e.as_str()) + .unwrap_or("Unknown error"); + if error_msg.contains("context") && error_msg.contains("exceeded") { + return Err(ProviderError::ContextLengthExceeded( + error_msg.to_string(), + )); + } + return Err(ProviderError::RequestFailed(format!( + "Claude CLI error: {}", + error_msg + ))); + } + Some("system") => {} // Ignore system init events + _ => {} // Ignore other event types } } } @@ -245,12 +258,6 @@ impl ClaudeCodeProvider { messages: &[Message], _tools: &[Tool], ) -> Result, ProviderError> { - let messages_json = self - .messages_to_claude_format(system, messages) - .map_err(|e| { - ProviderError::RequestFailed(format!("Failed to format messages: {}", e)) - })?; - let filtered_system = filter_extensions_from_system_prompt(system); if std::env::var("GOOSE_CLAUDE_CODE_DEBUG").is_ok() { @@ -262,57 +269,125 @@ impl ClaudeCodeProvider { filtered_system.len() ); println!("Filtered system prompt: {}", filtered_system); - println!( - "Messages JSON: {}", - serde_json::to_string_pretty(&messages_json) - .unwrap_or_else(|_| "Failed to serialize".to_string()) - ); println!("================================"); } - let mut cmd = Command::new(&self.command); - configure_command_no_window(&mut cmd); - cmd.arg("-p") - .arg(messages_json.to_string()) - .arg("--system-prompt") - .arg(&filtered_system); - - // Only pass model parameter if it's in the known models list - if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { - cmd.arg("--model").arg(&self.model.model_name); - } - - cmd.arg("--verbose").arg("--output-format").arg("json"); - - // Add permission mode based on GOOSE_MODE setting - Self::apply_permission_flags(&mut cmd)?; - - cmd.stdout(Stdio::piped()).stderr(Stdio::piped()); + // Spawn lazily on first call (OnceCell ensures exactly once) + let process_mutex = self + .cli_process + .get_or_try_init(|| async { + let mut cmd = Command::new(&self.command); + // NO -p flag — persistent mode + configure_command_no_window(&mut cmd); + cmd.arg("--input-format") + .arg("stream-json") + .arg("--output-format") + .arg("stream-json") + .arg("--verbose") + // System prompt is set once at process start. The provider + // instance is not reused across sessions with different prompts. + .arg("--system-prompt") + .arg(&filtered_system); + + // Only pass model parameter if it's in the known models list + if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { + cmd.arg("--model").arg(&self.model.model_name); + } - let mut child = cmd.spawn().map_err(|e| { - ProviderError::RequestFailed(format!( - "Failed to spawn Claude CLI command '{:?}': {}.", - self.command, e - )) + // Add permission mode based on GOOSE_MODE setting + Self::apply_permission_flags(&mut cmd)?; + + cmd.stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + let mut child = cmd.spawn().map_err(|e| { + ProviderError::RequestFailed(format!( + "Failed to spawn Claude CLI command '{:?}': {}.", + self.command, e + )) + })?; + + let stdin = child.stdin.take().ok_or_else(|| { + ProviderError::RequestFailed("Failed to capture stdin".to_string()) + })?; + let stdout = child.stdout.take().ok_or_else(|| { + ProviderError::RequestFailed("Failed to capture stdout".to_string()) + })?; + + // Drain stderr concurrently to prevent pipe buffer deadlock + let stderr = child.stderr.take(); + let stderr_handle = tokio::spawn(async move { + let mut output = String::new(); + if let Some(mut stderr) = stderr { + use tokio::io::AsyncReadExt; + let _ = stderr.read_to_string(&mut output).await; + } + output + }); + + Ok::<_, ProviderError>(tokio::sync::Mutex::new(CliProcess { + child, + stdin, + reader: BufReader::new(stdout), + stderr_handle, + messages_sent: 0, + })) + }) + .await?; + + let mut process = process_mutex.lock().await; + + // Build content from new messages only (skip already-sent ones). + // If messages is shorter than messages_sent, the caller started a fresh + // conversation on the same provider instance — send everything. + let new_messages = if process.messages_sent > 0 && process.messages_sent < messages.len() { + &messages[process.messages_sent..] + } else { + messages + }; + let new_blocks = self.messages_to_content_blocks(new_messages); + + // Write NDJSON line to stdin + let ndjson_line = build_stream_json_input(&new_blocks); + process + .stdin + .write_all(ndjson_line.as_bytes()) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write to stdin: {}", e)) + })?; + process.stdin.write_all(b"\n").await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write newline to stdin: {}", e)) })?; - let stdout = child - .stdout - .take() - .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; - - let mut reader = BufReader::new(stdout); + // Read lines until we see a "result" event let mut lines = Vec::new(); let mut line = String::new(); loop { line.clear(); - match reader.read_line(&mut line).await { - Ok(0) => break, // EOF + match process.reader.read_line(&mut line).await { + Ok(0) => { + // EOF means the process died + return Err(ProviderError::RequestFailed( + "Claude CLI process terminated unexpectedly".to_string(), + )); + } Ok(_) => { let trimmed = line.trim(); - if !trimmed.is_empty() { - lines.push(trimmed.to_string()); + if trimmed.is_empty() { + continue; + } + lines.push(trimmed.to_string()); + + // Check if this is a result event (end of turn) + if let Ok(parsed) = serde_json::from_str::(trimmed) { + match parsed.get("type").and_then(|t| t.as_str()) { + Some("result") => break, + Some("error") => break, + _ => {} + } } } Err(e) => { @@ -324,16 +399,8 @@ impl ClaudeCodeProvider { } } - let exit_status = child.wait().await.map_err(|e| { - ProviderError::RequestFailed(format!("Failed to wait for command: {}", e)) - })?; - - if !exit_status.success() { - return Err(ProviderError::RequestFailed(format!( - "Command failed with exit code: {:?}", - exit_status.code() - ))); - } + // Update messages_sent for next turn + process.messages_sent = messages.len(); tracing::debug!("Command executed successfully, got {} lines", lines.len()); for (i, line) in lines.iter().enumerate() { @@ -389,6 +456,12 @@ impl ClaudeCodeProvider { } } +fn build_stream_json_input(content_blocks: &[Value]) -> String { + let msg = json!({"type":"user","message":{"role":"user","content":content_blocks}}); + serde_json::to_string(&msg).expect("serializing JSON content blocks cannot fail") +} + +#[async_trait] impl ProviderDef for ClaudeCodeProvider { type Provider = Self; @@ -463,3 +536,210 @@ impl Provider for ClaudeCodeProvider { )) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use test_case::test_case; + + /// (role, text, optional (image_data, mime_type)) + type MsgSpec<'a> = (&'a str, &'a str, Option<(&'a str, &'a str)>); + + fn build_messages(specs: &[MsgSpec]) -> Vec { + specs + .iter() + .map(|(role, text, image)| { + let role = if *role == "user" { + Role::User + } else { + Role::Assistant + }; + let mut msg = Message::new(role, 0, vec![]); + if !text.is_empty() { + msg = Message::new(msg.role.clone(), 0, vec![MessageContent::text(*text)]); + } + if let Some((data, mime)) = image { + msg.content.push(MessageContent::image(*data, *mime)); + } + msg + }) + .collect() + } + + #[test_case( + &[], + &[] + ; "empty" + )] + #[test_case( + &[("user", "Hello", None)], + &[json!({"type":"text","text":"Human: Hello"})] + ; "single_user" + )] + #[test_case( + &[("user", "Hello", None), ("assistant", "Hi there!", None)], + &[json!({"type":"text","text":"Human: Hello"}), json!({"type":"text","text":"Assistant: Hi there!"})] + ; "user_and_assistant" + )] + #[test_case( + &[("user", "Describe this", Some(("base64data", "image/png")))], + &[json!({"type":"text","text":"Human: Describe this"}), + json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"base64data"}})] + ; "user_with_image" + )] + #[test_case( + &[("user", "", Some(("iVBORw0KGgo", "image/png")))], + &[json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"iVBORw0KGgo"}})] + ; "image_only" + )] + fn test_messages_to_content_blocks(pairs: &[MsgSpec], expected: &[Value]) { + let provider = make_provider(); + let messages = build_messages(pairs); + let blocks = provider.messages_to_content_blocks(&messages); + assert_eq!(blocks, expected); + } + + #[test] + fn test_messages_to_content_blocks_tool_request() { + use rmcp::model::CallToolRequestParams; + let provider = make_provider(); + let tool_call = Ok(CallToolRequestParams { + name: "developer__shell".into(), + arguments: Some(serde_json::from_value(json!({"cmd": "ls"})).unwrap()), + meta: None, + task: None, + }); + let msg = Message::new( + Role::Assistant, + 0, + vec![MessageContent::tool_request("call_123", tool_call)], + ); + let blocks = provider.messages_to_content_blocks(&[msg]); + assert_eq!( + blocks, + vec![ + json!({"type":"text","text":"Assistant: [tool_use: developer__shell id=call_123]"}) + ] + ); + } + + #[test] + fn test_messages_to_content_blocks_tool_response() { + use rmcp::model::{CallToolResult, Content}; + let provider = make_provider(); + let result = CallToolResult { + content: vec![Content::text("file1.txt\nfile2.txt")], + is_error: None, + structured_content: None, + meta: None, + }; + let msg = Message::new( + Role::User, + 0, + vec![MessageContent::tool_response("call_123", Ok(result))], + ); + let blocks = provider.messages_to_content_blocks(&[msg]); + assert_eq!( + blocks, + vec![ + json!({"type":"text","text":"Human: [tool_result id=call_123] file1.txt\nfile2.txt"}) + ] + ); + } + + #[test_case( + &[json!({"type":"text","text":"Hello"})], + json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Hello"}]}}) + ; "text_block" + )] + #[test_case( + &[json!({"type":"text","text":"Look"}), json!({"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}})], + json!({"type":"user","message":{"role":"user","content":[{"type":"text","text":"Look"},{"type":"image","source":{"type":"base64","media_type":"image/png","data":"abc"}}]}}) + ; "text_and_image_blocks" + )] + fn test_build_stream_json_input(blocks: &[Value], expected: Value) { + let line = build_stream_json_input(blocks); + let parsed: Value = serde_json::from_str(&line).unwrap(); + assert_eq!(parsed, expected); + } + + #[test_case( + &[ + r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"The answer is 2."}],"usage":{"input_tokens":100,"output_tokens":20}}}"#, + r#"{"type":"result","subtype":"success","result":"The answer is 2.","session_id":"abc"}"#, + ], + "The answer is 2.", + Some(100), Some(20) + ; "assistant_with_usage" + )] + #[test_case( + &[ + r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}}"#, + ], + "First\n\nSecond", + None, None + ; "multiple_text_blocks" + )] + #[test_case( + &[ + r#"{"type":"system","subtype":"init","session_id":"abc"}"#, + r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello"}]}}"#, + r#"{"type":"result","subtype":"success","result":"Hello","session_id":"abc"}"#, + ], + "Hello", + None, None + ; "system_init_filtered" + )] + fn test_parse_claude_response_ok( + lines: &[&str], + expected_text: &str, + expected_input: Option, + expected_output: Option, + ) { + let provider = make_provider(); + let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); + let (message, usage) = provider.parse_claude_response(&lines).unwrap(); + assert_eq!(message.role, Role::Assistant); + if let MessageContent::Text(t) = &message.content[0] { + assert_eq!(t.text, expected_text); + } else { + panic!("expected text content"); + } + assert_eq!(usage.input_tokens, expected_input); + assert_eq!(usage.output_tokens, expected_output); + } + + #[test_case( + &[], + ProviderError::RequestFailed("No text content found in response".into()) + ; "empty_lines" + )] + #[test_case( + &[r#"{"type":"error","error":"context window exceeded"}"#], + ProviderError::ContextLengthExceeded("context window exceeded".into()) + ; "context_length" + )] + #[test_case( + &[r#"{"type":"error","error":"Model not supported"}"#], + ProviderError::RequestFailed("Claude CLI error: Model not supported".into()) + ; "generic_error" + )] + fn test_parse_claude_response_err(lines: &[&str], expected: ProviderError) { + let provider = make_provider(); + let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); + assert_eq!( + provider.parse_claude_response(&lines).unwrap_err(), + expected + ); + } + + fn make_provider() -> ClaudeCodeProvider { + ClaudeCodeProvider { + command: PathBuf::from("claude"), + model: ModelConfig::new("sonnet").unwrap(), + name: "claude-code".to_string(), + cli_process: tokio::sync::OnceCell::new(), + } + } +}