diff --git a/Cargo.lock b/Cargo.lock index 27a8e5c9f437..82bb6d7d3754 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4187,6 +4187,7 @@ dependencies = [ "base64 0.22.1", "blake3", "byteorder", + "bytes", "candle-core", "candle-nn", "candle-transformers", @@ -4203,6 +4204,10 @@ dependencies = [ "fs2", "futures", "goose-test-support", + "http 1.4.0", + "http-body-util", + "hyper 1.8.1", + "hyper-util", "ignore", "include_dir", "indexmap 2.13.0", @@ -4244,6 +4249,7 @@ dependencies = [ "shell-words", "shellexpand", "sqlx", + "sse-stream", "strum 0.27.2", "symphonia", "sys-info", diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 730c1fc03b60..e6ebf60f3853 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -101,6 +101,7 @@ fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result anyhow::Result<()> { envs: Envs::new(envs), env_keys, headers, + socket: None, description, timeout: Some(timeout), bundled: None, diff --git a/crates/goose-cli/src/recipes/secret_discovery.rs b/crates/goose-cli/src/recipes/secret_discovery.rs index 56bba3212677..894dacf60771 100644 --- a/crates/goose-cli/src/recipes/secret_discovery.rs +++ b/crates/goose-cli/src/recipes/secret_discovery.rs @@ -169,6 +169,7 @@ mod tests { bundled: None, available_tools: Vec::new(), headers: HashMap::new(), + socket: None, }, ExtensionConfig::Stdio { name: "slack-mcp".to_string(), @@ -265,6 +266,7 @@ mod tests { bundled: None, available_tools: Vec::new(), headers: HashMap::new(), + socket: None, }, ExtensionConfig::Stdio { name: "service-b".to_string(), @@ -324,6 +326,7 @@ mod tests { bundled: None, available_tools: Vec::new(), headers: HashMap::new(), + socket: None, }]), sub_recipes: Some(vec![SubRecipe { name: "child-recipe".to_string(), diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index b3969a135d11..8dc024ac229c 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -337,6 +337,7 @@ impl CliSession { envs: Envs::new(HashMap::new()), env_keys: Vec::new(), headers: HashMap::new(), + socket: None, description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), timeout: Some(timeout), bundled: None, @@ -2031,6 +2032,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::new(), + socket: None, description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), timeout: Some(300), bundled: None, @@ -2046,6 +2048,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::new(), + socket: None, description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), timeout: Some(300), bundled: None, @@ -2061,6 +2064,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::new(), + socket: None, description: goose::config::DEFAULT_EXTENSION_DESCRIPTION.to_string(), timeout: Some(300), bundled: None, diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 171158799efc..f9b23b2c110e 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -155,6 +155,15 @@ encoding_rs = "0.8.35" pastey = "0.2.1" shell-words = { workspace = true } +# Unix domain socket HTTP transport for StreamableHttp extensions +[target.'cfg(unix)'.dependencies] +hyper = { version = "1", features = ["client", "http1"] } +hyper-util = { version = "0.1", features = ["tokio"] } +http-body-util = "0.1" +sse-stream = "0.2" +bytes = { workspace = true } +http = { workspace = true } + [target.'cfg(target_os = "windows")'.dependencies] winapi = { workspace = true } diff --git a/crates/goose/src/acp/provider.rs b/crates/goose/src/acp/provider.rs index 4461351c6f9e..325736404f07 100644 --- a/crates/goose/src/acp/provider.rs +++ b/crates/goose/src/acp/provider.rs @@ -1458,6 +1458,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::from([("Authorization".into(), "Bearer ghp_xxxxxxxxxxxx".into())]), + socket: None, timeout: None, bundled: Some(false), available_tools: vec![], @@ -1511,6 +1512,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::from([("Authorization".into(), "Bearer ghp_xxxxxxxxxxxx".into())]), + socket: None, timeout: None, bundled: Some(false), available_tools: vec![], diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index e30bdb334de6..a8b7aa6dade0 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -232,6 +232,11 @@ pub enum ExtensionConfig { env_keys: Vec, #[serde(default)] headers: HashMap, + /// Unix domain socket path to route HTTP through (e.g. "@egress.sock" for Envoy sidecar). + /// When set, the physical connection goes through this socket while `uri` is used for the + /// HTTP Host header and path. Useful in K8s environments where DNS only resolves via Envoy. + #[serde(default)] + socket: Option, // NOTE: set timeout to be optional for compatibility. // However, new configurations should include this field. timeout: Option, @@ -305,6 +310,7 @@ impl ExtensionConfig { envs: Envs::default(), env_keys: Vec::new(), headers: HashMap::new(), + socket: None, description: description.into(), timeout: Some(timeout.into()), bundled: None, @@ -459,6 +465,7 @@ impl ExtensionConfig { envs, env_keys, headers, + socket, timeout, bundled, available_tools, @@ -478,6 +485,7 @@ impl ExtensionConfig { envs: Envs::new(merged), env_keys: vec![], headers, + socket, timeout, bundled, available_tools, @@ -494,9 +502,12 @@ impl std::fmt::Display for ExtensionConfig { ExtensionConfig::Sse { name, .. } => { write!(f, "SSE({}: unsupported)", name) } - ExtensionConfig::StreamableHttp { name, uri, .. } => { - write!(f, "StreamableHttp({}: {})", name, uri) - } + ExtensionConfig::StreamableHttp { + name, uri, socket, .. + } => match socket { + Some(s) => write!(f, "StreamableHttp({}: {} via {})", name, uri, s), + None => write!(f, "StreamableHttp({}: {})", name, uri), + }, ExtensionConfig::Stdio { name, cmd, args, .. } => { @@ -678,6 +689,7 @@ available_tools: [] )] .into_iter() .collect(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -698,6 +710,7 @@ available_tools: [] )] .into_iter() .collect(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -771,6 +784,7 @@ available_tools: [] )] .into_iter() .collect(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -788,6 +802,7 @@ available_tools: [] headers: [("Authorization".to_string(), "Bearer secret_value".to_string())] .into_iter() .collect(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -802,6 +817,7 @@ available_tools: [] envs: extension::Envs::default(), env_keys: vec!["MY_SECRET".into()], headers: std::collections::HashMap::new(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -817,6 +833,7 @@ available_tools: [] }), env_keys: vec![], headers: std::collections::HashMap::new(), + socket: None, timeout: None, bundled: None, available_tools: vec![], @@ -867,4 +884,64 @@ available_tools: [] cfg.set("MY_SECRET", &"secret_value", true).unwrap(); assert_eq!(config.resolve(&cfg).await.unwrap(), expected); } + + #[test] + fn test_deserialize_streamable_http_with_socket() { + let config: ExtensionConfig = serde_yaml::from_str( + "type: streamable_http\nname: ai-app-info\ndescription: test\nuri: http://example.com/mcp\nsocket: \"@egress.sock\"\n", + ) + .unwrap(); + if let ExtensionConfig::StreamableHttp { socket, .. } = config { + assert_eq!(socket, Some("@egress.sock".to_string())); + } else { + panic!("unexpected variant"); + } + } + + #[test] + fn test_deserialize_streamable_http_without_socket() { + let config: ExtensionConfig = serde_yaml::from_str( + "type: streamable_http\nname: ai-app-info\ndescription: test\nuri: http://example.com/mcp\n", + ) + .unwrap(); + if let ExtensionConfig::StreamableHttp { socket, .. } = config { + assert_eq!(socket, None); + } else { + panic!("unexpected variant"); + } + } + + #[test] + fn test_display_streamable_http_without_socket() { + let config = ExtensionConfig::streamable_http( + "ai-app-info", + "http://example.com/mcp", + "test", + 300u64, + ); + assert_eq!( + format!("{config}"), + "StreamableHttp(ai-app-info: http://example.com/mcp)" + ); + } + + #[test] + fn test_display_streamable_http_with_socket() { + let config = ExtensionConfig::StreamableHttp { + name: "ai-app-info".to_string(), + uri: "http://example.com/mcp".to_string(), + description: "test".to_string(), + timeout: Some(300), + headers: Default::default(), + envs: Default::default(), + env_keys: vec![], + socket: Some("@egress.sock".to_string()), + bundled: None, + available_tools: vec![], + }; + assert_eq!( + format!("{config}"), + "StreamableHttp(ai-app-info: http://example.com/mcp via @egress.sock)" + ); + } } diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 6dd055de5277..5e02f08ddbb5 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -305,14 +305,20 @@ fn extract_auth_error( ClientInitializeError::TransportError { error: DynamicTransportError { error, .. }, .. - } => error - .downcast_ref::>() - .and_then(|auth_error| match auth_error { - StreamableHttpError::AuthRequired(auth_required_error) => { - Some(auth_required_error) - } - _ => None, - }), + } => { + if let Some(StreamableHttpError::AuthRequired(e)) = + error.downcast_ref::>() + { + return Some(e); + } + #[cfg(unix)] + if let Some(StreamableHttpError::AuthRequired(e)) = error + .downcast_ref::>() + { + return Some(e); + } + None + } _ => None, }, } @@ -412,7 +418,31 @@ async fn create_streamable_http_client( client_name: String, capabilities: GooseMcpClientCapabilities, roots_dir: &std::path::Path, + socket: Option<&str>, ) -> ExtensionResult> { + #[cfg(unix)] + if let Some(socket_path) = socket { + return create_unix_socket_http_client( + uri, + timeout, + headers, + name, + socket_path, + provider, + client_name, + capabilities, + roots_dir, + ) + .await; + } + + #[cfg(not(unix))] + if socket.is_some() { + return Err(ExtensionError::ConfigError( + "Unix domain socket transport is not supported on this platform".to_string(), + )); + } + let mut default_headers = HeaderMap::new(); default_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT); @@ -489,6 +519,97 @@ async fn create_streamable_http_client( } } +#[cfg(unix)] +#[allow(clippy::too_many_arguments)] +async fn create_unix_socket_http_client( + uri: &str, + timeout: Option, + headers: &HashMap, + name: &str, + socket_path: &str, + provider: SharedProvider, + client_name: String, + capabilities: GooseMcpClientCapabilities, + roots_dir: &std::path::Path, +) -> ExtensionResult> { + use super::unix_socket_http_client::UnixSocketHttpClient; + use http::header::HeaderValue; + + #[cfg(not(target_os = "linux"))] + if socket_path.starts_with('@') { + return Err(ExtensionError::ConfigError( + "Abstract Unix sockets (@-prefixed) are only supported on Linux".to_string(), + )); + } + + let mut default_headers = std::collections::HashMap::new(); + default_headers.insert(reqwest::header::USER_AGENT, GOOSE_USER_AGENT); + for (key, value) in headers { + let header_name = HeaderName::try_from(key) + .map_err(|_| ExtensionError::ConfigError(format!("invalid header: {key}")))?; + let val: HeaderValue = value + .parse() + .map_err(|_| ExtensionError::ConfigError(format!("invalid header value: {key}")))?; + default_headers.insert(header_name, val); + } + + let retry_headers = { + let mut h = default_headers.clone(); + h.remove(&http::header::AUTHORIZATION); + h + }; + + let unix_client = UnixSocketHttpClient::new(uri, socket_path, default_headers); + let transport = StreamableHttpClientTransport::with_client( + unix_client, + StreamableHttpClientTransportConfig { + uri: uri.into(), + ..Default::default() + }, + ); + + let timeout_duration = + Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)); + + let client_res = McpClient::connect( + transport, + timeout_duration, + provider.clone(), + client_name.clone(), + capabilities.clone(), + roots_dir.to_path_buf(), + ) + .await; + + if extract_auth_error(&client_res).is_some() { + let auth_manager = oauth_flow(&uri.to_string(), &name.to_string()) + .await + .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?; + let auth_unix_client = UnixSocketHttpClient::new(uri, socket_path, retry_headers); + let auth_client = AuthClient::new(auth_unix_client, auth_manager); + let transport = StreamableHttpClientTransport::with_client( + auth_client, + StreamableHttpClientTransportConfig { + uri: uri.into(), + ..Default::default() + }, + ); + Ok(Box::new( + McpClient::connect( + transport, + timeout_duration, + provider, + client_name, + capabilities, + roots_dir.to_path_buf(), + ) + .await?, + )) + } else { + Ok(Box::new(client_res?)) + } +} + impl ExtensionManager { pub fn new( provider: SharedProvider, @@ -585,6 +706,7 @@ impl ExtensionManager { name, envs, env_keys, + socket, .. } => { let config = Config::global(); @@ -607,6 +729,7 @@ impl ExtensionManager { self.client_name.clone(), capability, &effective_working_dir, + socket.as_ref().map(|s| substitute_env_vars(s, &all_envs)).as_deref(), ) .await? } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 1b41a743182b..6d659b10027d 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -21,6 +21,8 @@ pub(crate) mod subagent_task_config; mod tool_confirmation_router; mod tool_execution; pub mod types; +#[cfg(unix)] +pub(crate) mod unix_socket_http_client; pub mod validate_extensions; pub use agent::{Agent, AgentConfig, AgentEvent, ExtensionLoadResult, GoosePlatform}; diff --git a/crates/goose/src/agents/unix_socket_http_client.rs b/crates/goose/src/agents/unix_socket_http_client.rs new file mode 100644 index 000000000000..12333009f80b --- /dev/null +++ b/crates/goose/src/agents/unix_socket_http_client.rs @@ -0,0 +1,447 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use futures::stream::BoxStream; +use futures::StreamExt; +use http::{HeaderName, HeaderValue, Method, Request, StatusCode}; +use http_body_util::{BodyExt, Full}; +use hyper::body::Incoming; +use hyper_util::rt::TokioIo; +use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; +use rmcp::transport::common::http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, HEADER_SESSION_ID, + JSON_MIME_TYPE, +}; +use rmcp::transport::streamable_http_client::{ + AuthRequiredError, InsufficientScopeError, StreamableHttpClient, StreamableHttpError, + StreamableHttpPostResponse, +}; +use sse_stream::SseStream; +use tokio::net::UnixStream; + +#[derive(Debug, thiserror::Error)] +pub enum UnixSocketError { + #[error("hyper error: {0}")] + Hyper(#[from] hyper::Error), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("HTTP error: {0}")] + Http(#[from] http::Error), + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +#[derive(Clone, Debug)] +pub struct UnixSocketHttpClient { + socket_path: Arc, + default_headers: HashMap, +} + +impl UnixSocketHttpClient { + pub fn new( + uri: &str, + raw_socket_path: &str, + mut default_headers: HashMap, + ) -> Self { + // hyper over a raw Unix socket does not set Host automatically (unlike reqwest over TCP). + // Derive it from the URI so Envoy can route to the correct upstream cluster. + // An explicitly provided Host in default_headers takes precedence. + if let Ok(parsed) = uri.parse::() { + if let Some(authority) = parsed.authority() { + if let Ok(val) = HeaderValue::from_str(authority.as_str()) { + default_headers.entry(http::header::HOST).or_insert(val); + } + } + } + Self { + socket_path: resolve_socket_path(raw_socket_path).into(), + default_headers, + } + } +} + +/// Converts the `@`-prefixed abstract socket notation to the null-byte prefix +/// expected by the Linux kernel. Filesystem socket paths are returned unchanged. +fn resolve_socket_path(raw: &str) -> String { + if let Some(name) = raw.strip_prefix('@') { + format!("\0{name}") + } else { + raw.to_string() + } +} + +async fn connect_unix(socket_path: &str) -> Result { + #[cfg(target_os = "linux")] + if let Some(abstract_name) = socket_path.strip_prefix('\0') { + // tokio::net::UnixStream has no connect_addr; use std via spawn_blocking + // to avoid blocking a tokio worker thread during the connect syscall + let abstract_name = abstract_name.to_string(); + let std_stream = tokio::task::spawn_blocking(move || { + use std::os::linux::net::SocketAddrExt; + let addr = std::os::unix::net::SocketAddr::from_abstract_name(&abstract_name)?; + let stream = std::os::unix::net::UnixStream::connect_addr(&addr)?; + stream.set_nonblocking(true)?; + Ok::<_, std::io::Error>(stream) + }) + .await + .map_err(std::io::Error::other)??; + return UnixStream::from_std(std_stream); + } + + UnixStream::connect(socket_path).await +} + +/// Opens a new Unix socket connection and sends the HTTP request. +/// One connection per request — simple and correct. The sidecar proxy +/// handles connection pooling on its end if needed. +async fn send_http_request( + socket_path: &str, + request: Request>, +) -> Result, UnixSocketError> { + let stream = connect_unix(socket_path).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + + tokio::spawn(async move { + if let Err(e) = conn.await { + tracing::warn!("unix socket HTTP/1.1 connection error: {e}"); + } + }); + + Ok(sender.send_request(request).await?) +} + +/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value. +/// Mirrors the private helper in rmcp's reqwest implementation. +fn extract_scope_from_header(header: &str) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let scope_key = "scope="; + let pos = header_lowercase.find(scope_key)?; + let value_slice = header.get(pos + scope_key.len()..)?; + if let Some(stripped) = value_slice.strip_prefix('"') { + let end = stripped.find('"')?; + stripped.get(..end).map(str::to_string) + } else { + let end = value_slice + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(value_slice.len()); + if end > 0 { + value_slice.get(..end).map(str::to_string) + } else { + None + } + } +} + +/// Applies custom headers to a request builder, rejecting reserved headers +/// except `MCP-Protocol-Version` (which the worker injects after init). +fn apply_custom_headers( + mut builder: http::request::Builder, + custom_headers: HashMap, +) -> Result> { + const RESERVED: &[&str] = &[ + "accept", + HEADER_SESSION_ID, + HEADER_MCP_PROTOCOL_VERSION, + HEADER_LAST_EVENT_ID, + ]; + for (name, value) in custom_headers { + if RESERVED + .iter() + .any(|&r| name.as_str().eq_ignore_ascii_case(r)) + { + if name + .as_str() + .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) + { + builder = builder.header(name, value); + continue; + } + return Err(StreamableHttpError::ReservedHeaderConflict( + name.to_string(), + )); + } + builder = builder.header(name, value); + } + Ok(builder) +} + +impl StreamableHttpClient for UnixSocketHttpClient { + type Error = UnixSocketError; + + async fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_header: Option, + custom_headers: HashMap, + ) -> Result> { + let json_body = serde_json::to_string(&message) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Json(e)))?; + + let mut builder = Request::builder() + .method(Method::POST) + .uri(uri.as_ref()) + .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) + .header( + http::header::ACCEPT, + format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), + ); + + for (name, value) in &self.default_headers { + builder = builder.header(name.clone(), value.clone()); + } + + if let Some(auth) = auth_header { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + if let Some(sid) = session_id { + builder = builder.header(HEADER_SESSION_ID, sid.as_ref()); + } + + let request = builder + .body(Full::new(Bytes::from(json_body))) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + let status = response.status(); + + if status == StatusCode::UNAUTHORIZED { + if let Some(header) = response.headers().get(http::header::WWW_AUTHENTICATE) { + let www_authenticate_header = header.to_str().unwrap_or_default().to_string(); + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header, + })); + } + } + + if status == StatusCode::FORBIDDEN { + if let Some(header) = response.headers().get(http::header::WWW_AUTHENTICATE) { + let header_str = header.to_str().unwrap_or_default(); + return Err(StreamableHttpError::InsufficientScope( + InsufficientScopeError { + www_authenticate_header: header_str.to_string(), + required_scope: extract_scope_from_header(header_str), + }, + )); + } + } + + if matches!(status, StatusCode::ACCEPTED | StatusCode::NO_CONTENT) { + return Ok(StreamableHttpPostResponse::Accepted); + } + + if !status.is_success() { + return Err(StreamableHttpError::UnexpectedServerResponse( + format!("post_message returned {}", status).into(), + )); + } + + let session_id = response + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + let content_type = response.headers().get(http::header::CONTENT_TYPE).cloned(); + + match content_type { + Some(ref ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { + let sse_stream = SseStream::new(response.into_body()).boxed(); + Ok(StreamableHttpPostResponse::Sse(sse_stream, session_id)) + } + Some(ref ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { + let body = response + .into_body() + .collect() + .await + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Hyper(e)))? + .to_bytes(); + let message: ServerJsonRpcMessage = serde_json::from_slice(&body) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Json(e)))?; + Ok(StreamableHttpPostResponse::Json(message, session_id)) + } + _ => Err(StreamableHttpError::UnexpectedContentType( + content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).into_owned()), + )), + } + } + + async fn delete_session( + &self, + uri: Arc, + session_id: Arc, + auth_header: Option, + custom_headers: HashMap, + ) -> Result<(), StreamableHttpError> { + let mut builder = Request::builder() + .method(Method::DELETE) + .uri(uri.as_ref()) + .header(HEADER_SESSION_ID, session_id.as_ref()); + + for (name, value) in &self.default_headers { + builder = builder.header(name.clone(), value.clone()); + } + + if let Some(auth) = auth_header { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + let request = builder + .body(Full::new(Bytes::new())) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + // 405 means the server doesn't support session deletion — treat as success + if response.status() == StatusCode::METHOD_NOT_ALLOWED { + return Ok(()); + } + + if !response.status().is_success() { + return Err(StreamableHttpError::UnexpectedServerResponse( + format!("delete_session returned {}", response.status()).into(), + )); + } + + Ok(()) + } + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_header: Option, + custom_headers: HashMap, + ) -> Result< + BoxStream<'static, Result>, + StreamableHttpError, + > { + let mut builder = Request::builder() + .method(Method::GET) + .uri(uri.as_ref()) + .header( + http::header::ACCEPT, + format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), + ) + .header(HEADER_SESSION_ID, session_id.as_ref()); + + for (name, value) in &self.default_headers { + builder = builder.header(name.clone(), value.clone()); + } + + if let Some(last_id) = last_event_id { + builder = builder.header(HEADER_LAST_EVENT_ID, last_id); + } + + if let Some(auth) = auth_header { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + let request = builder + .body(Full::new(Bytes::new())) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + if response.status() == StatusCode::METHOD_NOT_ALLOWED { + return Err(StreamableHttpError::ServerDoesNotSupportSse); + } + + if !response.status().is_success() { + return Err(StreamableHttpError::UnexpectedServerResponse( + format!("get_stream returned {}", response.status()).into(), + )); + } + + Ok(SseStream::new(response.into_body()).boxed()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_resolve_abstract_socket() { + assert_eq!(resolve_socket_path("@egress.sock"), "\0egress.sock"); + } + + #[test] + fn test_resolve_filesystem_socket() { + assert_eq!( + resolve_socket_path("/var/run/envoy.sock"), + "/var/run/envoy.sock" + ); + } + + #[test] + fn test_resolve_empty_abstract() { + assert_eq!(resolve_socket_path("@"), "\0"); + } + + #[test] + fn test_extract_scope_quoted() { + assert_eq!( + extract_scope_from_header(r#"Bearer realm="example", scope="read write""#), + Some("read write".to_string()) + ); + } + + #[test] + fn test_extract_scope_unquoted() { + assert_eq!( + extract_scope_from_header("Bearer scope=read"), + Some("read".to_string()) + ); + } + + #[test] + fn test_extract_scope_missing() { + assert_eq!(extract_scope_from_header("Bearer realm=\"example\""), None); + } + + #[test] + fn test_host_header_auto_derived() { + let client = UnixSocketHttpClient::new( + "http://staging.ai-app-info.gns.square/mcp", + "/var/run/envoy.sock", + HashMap::new(), + ); + let host = client.default_headers.get(&http::header::HOST).unwrap(); + assert_eq!(host, "staging.ai-app-info.gns.square"); + } + + #[test] + fn test_host_header_explicit_takes_precedence() { + let mut headers = HashMap::new(); + headers.insert( + http::header::HOST, + HeaderValue::from_static("custom.example.com"), + ); + let client = UnixSocketHttpClient::new( + "http://staging.ai-app-info.gns.square/mcp", + "/var/run/envoy.sock", + headers, + ); + let host = client.default_headers.get(&http::header::HOST).unwrap(); + assert_eq!(host, "custom.example.com"); + } +} diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index ecdb44bb4ff3..03ca36af9b42 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -1148,6 +1148,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::from([("Authorization".into(), "Bearer token".into())]), + socket: None, timeout: None, bundled: Some(false), available_tools: vec![], @@ -1169,6 +1170,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::new(), + socket: None, timeout: None, bundled: None, available_tools: vec![], diff --git a/crates/goose/src/providers/codex.rs b/crates/goose/src/providers/codex.rs index d121097c6941..3d78d3adca44 100644 --- a/crates/goose/src/providers/codex.rs +++ b/crates/goose/src/providers/codex.rs @@ -791,6 +791,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::from([("Authorization".into(), "Bearer token".into())]), + socket: None, timeout: None, bundled: Some(false), available_tools: vec![], @@ -809,6 +810,7 @@ mod tests { envs: Envs::default(), env_keys: vec![], headers: HashMap::new(), + socket: None, timeout: None, bundled: None, available_tools: vec![], diff --git a/crates/goose/src/recipe/recipe_extension_adapter.rs b/crates/goose/src/recipe/recipe_extension_adapter.rs index 3b32c3e35d4f..3639b1788fd4 100644 --- a/crates/goose/src/recipe/recipe_extension_adapter.rs +++ b/crates/goose/src/recipe/recipe_extension_adapter.rs @@ -60,6 +60,8 @@ enum RecipeExtensionConfigInternal { env_keys: Vec, #[serde(default)] headers: HashMap, + #[serde(default)] + socket: Option, timeout: Option, #[serde(default)] bundled: Option, @@ -139,6 +141,7 @@ impl From for ExtensionConfig { envs, env_keys, headers, + socket, timeout, bundled, available_tools diff --git a/ui/desktop/openapi.json b/ui/desktop/openapi.json index 4a07d6bb3735..14058c1b36c0 100644 --- a/ui/desktop/openapi.json +++ b/ui/desktop/openapi.json @@ -5152,6 +5152,11 @@ "type": "string", "description": "The name used to identify this extension" }, + "socket": { + "type": "string", + "description": "Unix domain socket path to route HTTP through (e.g. \"@egress.sock\" for Envoy sidecar).\nWhen set, the physical connection goes through this socket while `uri` is used for the\nHTTP Host header and path. Useful in K8s environments where DNS only resolves via Envoy.", + "nullable": true + }, "timeout": { "type": "integer", "format": "int64", diff --git a/ui/desktop/src/api/types.gen.ts b/ui/desktop/src/api/types.gen.ts index 8938f2cf78d9..b5b5316f3f79 100644 --- a/ui/desktop/src/api/types.gen.ts +++ b/ui/desktop/src/api/types.gen.ts @@ -402,6 +402,12 @@ export type ExtensionConfig = { * The name used to identify this extension */ name: string; + /** + * Unix domain socket path to route HTTP through (e.g. "@egress.sock" for Envoy sidecar). + * When set, the physical connection goes through this socket while `uri` is used for the + * HTTP Host header and path. Useful in K8s environments where DNS only resolves via Envoy. + */ + socket?: string | null; timeout?: number | null; type: 'streamable_http'; uri: string;