diff --git a/crates/goose-acp/src/custom_requests.rs b/crates/goose-acp/src/custom_requests.rs index 816d7bac9e86..8cf3d7cbc5dc 100644 --- a/crates/goose-acp/src/custom_requests.rs +++ b/crates/goose-acp/src/custom_requests.rs @@ -119,6 +119,24 @@ pub struct GetExtensionsResponse { pub warnings: Vec, } +/// Apply system prompt instructions to an active session. +/// Equivalent to POST /agent/update_from_session in the HTTP API. +#[derive(Debug, Deserialize, JsonSchema)] +pub struct SetSessionInstructionsRequest { + pub session_id: String, + /// Instructions to prepend to the agent's system prompt (e.g. channel context, persona). + pub instructions: String, +} + +/// Health check. +#[derive(Debug, Deserialize, JsonSchema)] +pub struct HealthRequest {} + +#[derive(Debug, Serialize, JsonSchema)] +pub struct HealthResponse { + pub status: String, +} + /// Empty success response for operations that return no data. #[derive(Debug, Serialize, JsonSchema)] pub struct EmptyResponse {} diff --git a/crates/goose-acp/src/server.rs b/crates/goose-acp/src/server.rs index 9ebd0aba71d8..70b45b0c3ac7 100644 --- a/crates/goose-acp/src/server.rs +++ b/crates/goose-acp/src/server.rs @@ -15,6 +15,7 @@ use goose::mcp_utils::ToolResult; use goose::permission::permission_confirmation::PrincipalType; use goose::permission::{Permission, PermissionConfirmation}; use goose::providers::base::Provider; +use goose::providers::create as create_provider; use goose::providers::provider_registry::ProviderConstructor; use goose::session::session_manager::SessionType; use goose::session::{Session, SessionManager}; @@ -971,24 +972,38 @@ impl GooseAcpAgent { &self, session_id: &str, model_id: &str, + provider_override: Option<&str>, ) -> Result { let config_path = self.config_dir.join(CONFIG_YAML_NAME); let config = Config::new(&config_path, "goose").map_err(|e| { sacp::Error::internal_error().data(format!("Failed to read config: {}", e)) })?; - let provider_name = config.get_goose_provider().map_err(|_| { - sacp::Error::internal_error().data("No provider configured".to_string()) - })?; + let provider_name = if let Some(p) = provider_override { + p.to_string() + } else { + config.get_goose_provider().map_err(|_| { + sacp::Error::internal_error().data("No provider configured".to_string()) + })? + }; let model_config = goose::model::ModelConfig::new(model_id) .map_err(|e| { sacp::Error::invalid_params().data(format!("Invalid model config: {}", e)) })? .with_canonical_limits(&provider_name); - let provider = (self.provider_factory)(model_config, Vec::new()) - .await - .map_err(|e| { - sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) - })?; + let provider = if provider_override.is_some() { + // When switching providers, use the global registry (same as HTTP update_agent_provider). + create_provider(&provider_name, model_config, Vec::new()) + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) + })? + } else { + (self.provider_factory)(model_config, Vec::new()) + .await + .map_err(|e| { + sacp::Error::internal_error().data(format!("Failed to create provider: {}", e)) + })? + }; let agent = { let sessions = self.sessions.lock().await; @@ -1169,6 +1184,25 @@ impl GooseAcpAgent { }) } + #[custom_method("health")] + async fn on_health(&self, _req: HealthRequest) -> Result { + Ok(HealthResponse { + status: "ok".to_string(), + }) + } + + #[custom_method("session/set_instructions")] + async fn on_set_session_instructions( + &self, + req: SetSessionInstructionsRequest, + ) -> Result { + let agent = self.get_agent_for_session(&req.session_id).await?; + agent + .extend_system_prompt("recipe".to_string(), req.instructions) + .await; + Ok(EmptyResponse {}) + } + #[custom_method("config/extensions")] async fn on_get_extensions(&self) -> Result { let extensions = goose::config::extensions::get_all_extensions(); @@ -1276,12 +1310,24 @@ impl JrMessageHandler for GooseAcpHandler { MessageCx::Request(req, request_cx) if req.method == "session/set_model" => { + // Extract optional `provider` before consuming params (sacp's + // SetSessionModelRequest doesn't have this field). + let provider_override = req + .params + .get("provider") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + .map(String::from); let params: SetSessionModelRequest = serde_json::from_value(req.params).map_err(|e| { sacp::Error::invalid_params().data(e.to_string()) })?; let resp = agent - .on_set_model(¶ms.session_id.0, ¶ms.model_id.0) + .on_set_model( + ¶ms.session_id.0, + ¶ms.model_id.0, + provider_override.as_deref(), + ) .await?; let json = serde_json::to_value(resp).map_err(|e| { sacp::Error::internal_error().data(e.to_string()) diff --git a/crates/goose-acp/tests/custom_requests_test.rs b/crates/goose-acp/tests/custom_requests_test.rs index 0f7c7678a929..73738b4323c7 100644 --- a/crates/goose-acp/tests/custom_requests_test.rs +++ b/crates/goose-acp/tests/custom_requests_test.rs @@ -168,3 +168,167 @@ fn test_custom_unknown_method() { assert!(result.is_err(), "expected method_not_found error"); }); } + +#[test] +fn test_custom_health() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let result = send_custom(conn.cx(), "_goose/health", serde_json::json!({})).await; + assert!(result.is_ok(), "expected ok, got: {:?}", result); + + let response = result.unwrap(); + assert_eq!( + response.get("status").and_then(|v| v.as_str()), + Some("ok"), + "expected status 'ok'" + ); + }); +} + +#[test] +fn test_custom_set_session_instructions() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let (session, _models) = conn.new_session().await; + let session_id = session.session_id().0.clone(); + + let result = send_custom( + conn.cx(), + "_goose/session/set_instructions", + serde_json::json!({ + "session_id": session_id, + "instructions": "You are a helpful assistant for the #eng-platform Slack channel.", + }), + ) + .await; + assert!(result.is_ok(), "set_instructions failed: {:?}", result); + }); +} + +#[test] +fn test_custom_set_session_instructions_unknown_session() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let result = send_custom( + conn.cx(), + "_goose/session/set_instructions", + serde_json::json!({ + "session_id": "nonexistent-session-id", + "instructions": "some instructions", + }), + ) + .await; + assert!(result.is_err(), "expected error for unknown session"); + }); +} + +#[test] +fn test_custom_session_get_includes_token_fields() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let (session, _models) = conn.new_session().await; + let session_id = session.session_id().0.clone(); + + let result = send_custom( + conn.cx(), + "_goose/session/get", + serde_json::json!({ "session_id": session_id }), + ) + .await; + assert!(result.is_ok(), "expected ok, got: {:?}", result); + + let response = result.unwrap(); + let returned_session = response.get("session").expect("missing 'session' field"); + + // Verify token metric fields are present (may be null for a fresh session). + assert!( + returned_session.get("input_tokens").is_some(), + "missing 'input_tokens' field" + ); + assert!( + returned_session.get("output_tokens").is_some(), + "missing 'output_tokens' field" + ); + assert!( + returned_session.get("accumulated_total_tokens").is_some(), + "missing 'accumulated_total_tokens' field" + ); + assert!( + returned_session.get("accumulated_input_tokens").is_some(), + "missing 'accumulated_input_tokens' field" + ); + assert!( + returned_session.get("accumulated_output_tokens").is_some(), + "missing 'accumulated_output_tokens' field" + ); + + // model_config contains model_name and context_limit (the slackbot reads context_limit + // via provider_config.context_limit in the HTTP API equivalent). + // For a fresh session, model_config is populated from the configured provider. + assert!( + returned_session.get("model_config").is_some(), + "missing 'model_config' field" + ); + }); +} + +#[test] +fn test_session_set_model_with_provider() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let (session, _models) = conn.new_session().await; + let session_id = session.session_id().0.clone(); + + // Switch model without specifying provider (reads from config — same as before). + // Uses camelCase per sacp's SetSessionModelRequest field naming. + let result = send_custom( + conn.cx(), + "session/set_model", + serde_json::json!({ + "sessionId": session_id, + "modelId": "gpt-4o", + }), + ) + .await; + assert!(result.is_ok(), "set_model failed: {:?}", result); + }); +} + +#[test] +fn test_session_set_model_with_explicit_provider() { + run_test(async { + let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await; + let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await; + + let (session, _models) = conn.new_session().await; + let session_id = session.session_id().0.clone(); + + // Switch provider + model explicitly — the extra `provider` field is extracted from raw + // params before the sacp-typed parse. + let result = send_custom( + conn.cx(), + "session/set_model", + serde_json::json!({ + "sessionId": session_id, + "modelId": "gpt-4o", + "provider": "openai", + }), + ) + .await; + assert!( + result.is_ok(), + "set_model with provider failed: {:?}", + result + ); + }); +}