Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions crates/goose-acp/src/custom_requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,24 @@ pub struct GetExtensionsResponse {
pub warnings: Vec<String>,
}

/// Apply system prompt instructions to an active session.
/// Equivalent to POST /agent/update_from_session in the HTTP API.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct SetSessionInstructionsRequest {
pub session_id: String,
/// Instructions to prepend to the agent's system prompt (e.g. channel context, persona).
pub instructions: String,
}

/// Health check.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct HealthRequest {}

#[derive(Debug, Serialize, JsonSchema)]
pub struct HealthResponse {
pub status: String,
}

/// Empty success response for operations that return no data.
#[derive(Debug, Serialize, JsonSchema)]
pub struct EmptyResponse {}
64 changes: 55 additions & 9 deletions crates/goose-acp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use goose::mcp_utils::ToolResult;
use goose::permission::permission_confirmation::PrincipalType;
use goose::permission::{Permission, PermissionConfirmation};
use goose::providers::base::Provider;
use goose::providers::create as create_provider;
use goose::providers::provider_registry::ProviderConstructor;
use goose::session::session_manager::SessionType;
use goose::session::{Session, SessionManager};
Expand Down Expand Up @@ -971,24 +972,38 @@ impl GooseAcpAgent {
&self,
session_id: &str,
model_id: &str,
provider_override: Option<&str>,
) -> Result<SetSessionModelResponse, sacp::Error> {
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
let config = Config::new(&config_path, "goose").map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to read config: {}", e))
})?;
let provider_name = config.get_goose_provider().map_err(|_| {
sacp::Error::internal_error().data("No provider configured".to_string())
})?;
let provider_name = if let Some(p) = provider_override {
p.to_string()
} else {
config.get_goose_provider().map_err(|_| {
sacp::Error::internal_error().data("No provider configured".to_string())
})?
};
let model_config = goose::model::ModelConfig::new(model_id)
.map_err(|e| {
sacp::Error::invalid_params().data(format!("Invalid model config: {}", e))
})?
.with_canonical_limits(&provider_name);
let provider = (self.provider_factory)(model_config, Vec::new())
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?;
let provider = if provider_override.is_some() {
// When switching providers, use the global registry (same as HTTP update_agent_provider).
create_provider(&provider_name, model_config, Vec::new())
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?
} else {
(self.provider_factory)(model_config, Vec::new())
.await
.map_err(|e| {
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
})?
};

let agent = {
let sessions = self.sessions.lock().await;
Expand Down Expand Up @@ -1169,6 +1184,25 @@ impl GooseAcpAgent {
})
}

#[custom_method("health")]
async fn on_health(&self, _req: HealthRequest) -> Result<HealthResponse, sacp::Error> {
Ok(HealthResponse {
status: "ok".to_string(),
})
}

#[custom_method("session/set_instructions")]
async fn on_set_session_instructions(
&self,
req: SetSessionInstructionsRequest,
) -> Result<EmptyResponse, sacp::Error> {
let agent = self.get_agent_for_session(&req.session_id).await?;
agent
.extend_system_prompt("recipe".to_string(), req.instructions)
.await;
Ok(EmptyResponse {})
}

#[custom_method("config/extensions")]
async fn on_get_extensions(&self) -> Result<GetExtensionsResponse, sacp::Error> {
let extensions = goose::config::extensions::get_all_extensions();
Expand Down Expand Up @@ -1276,12 +1310,24 @@ impl JrMessageHandler for GooseAcpHandler {
MessageCx::Request(req, request_cx)
if req.method == "session/set_model" =>
{
// Extract optional `provider` before consuming params (sacp's
// SetSessionModelRequest doesn't have this field).
let provider_override = req
.params
.get("provider")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
.map(String::from);
let params: SetSessionModelRequest =
serde_json::from_value(req.params).map_err(|e| {
sacp::Error::invalid_params().data(e.to_string())
})?;
let resp = agent
.on_set_model(&params.session_id.0, &params.model_id.0)
.on_set_model(
&params.session_id.0,
&params.model_id.0,
provider_override.as_deref(),
)
.await?;
let json = serde_json::to_value(resp).map_err(|e| {
sacp::Error::internal_error().data(e.to_string())
Expand Down
164 changes: 164 additions & 0 deletions crates/goose-acp/tests/custom_requests_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,167 @@ fn test_custom_unknown_method() {
assert!(result.is_err(), "expected method_not_found error");
});
}

#[test]
fn test_custom_health() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let result = send_custom(conn.cx(), "_goose/health", serde_json::json!({})).await;
assert!(result.is_ok(), "expected ok, got: {:?}", result);

let response = result.unwrap();
assert_eq!(
response.get("status").and_then(|v| v.as_str()),
Some("ok"),
"expected status 'ok'"
);
});
}

#[test]
fn test_custom_set_session_instructions() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let (session, _models) = conn.new_session().await;
let session_id = session.session_id().0.clone();

let result = send_custom(
conn.cx(),
"_goose/session/set_instructions",
serde_json::json!({
"session_id": session_id,
"instructions": "You are a helpful assistant for the #eng-platform Slack channel.",
}),
)
.await;
assert!(result.is_ok(), "set_instructions failed: {:?}", result);
});
}

#[test]
fn test_custom_set_session_instructions_unknown_session() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let result = send_custom(
conn.cx(),
"_goose/session/set_instructions",
serde_json::json!({
"session_id": "nonexistent-session-id",
"instructions": "some instructions",
}),
)
.await;
assert!(result.is_err(), "expected error for unknown session");
});
}

#[test]
fn test_custom_session_get_includes_token_fields() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let (session, _models) = conn.new_session().await;
let session_id = session.session_id().0.clone();

let result = send_custom(
conn.cx(),
"_goose/session/get",
serde_json::json!({ "session_id": session_id }),
)
.await;
assert!(result.is_ok(), "expected ok, got: {:?}", result);

let response = result.unwrap();
let returned_session = response.get("session").expect("missing 'session' field");

// Verify token metric fields are present (may be null for a fresh session).
assert!(
returned_session.get("input_tokens").is_some(),
"missing 'input_tokens' field"
);
assert!(
returned_session.get("output_tokens").is_some(),
"missing 'output_tokens' field"
);
assert!(
returned_session.get("accumulated_total_tokens").is_some(),
"missing 'accumulated_total_tokens' field"
);
assert!(
returned_session.get("accumulated_input_tokens").is_some(),
"missing 'accumulated_input_tokens' field"
);
assert!(
returned_session.get("accumulated_output_tokens").is_some(),
"missing 'accumulated_output_tokens' field"
);

// model_config contains model_name and context_limit (the slackbot reads context_limit
// via provider_config.context_limit in the HTTP API equivalent).
// For a fresh session, model_config is populated from the configured provider.
assert!(
returned_session.get("model_config").is_some(),
"missing 'model_config' field"
);
});
}

#[test]
fn test_session_set_model_with_provider() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let (session, _models) = conn.new_session().await;
let session_id = session.session_id().0.clone();

// Switch model without specifying provider (reads from config — same as before).
// Uses camelCase per sacp's SetSessionModelRequest field naming.
let result = send_custom(
conn.cx(),
"session/set_model",
serde_json::json!({
"sessionId": session_id,
"modelId": "gpt-4o",
}),
)
.await;
assert!(result.is_ok(), "set_model failed: {:?}", result);
});
}

#[test]
fn test_session_set_model_with_explicit_provider() {
run_test(async {
let openai = OpenAiFixture::new(vec![], ExpectedSessionId::default()).await;
let mut conn = ClientToAgentConnection::new(TestConnectionConfig::default(), openai).await;

let (session, _models) = conn.new_session().await;
let session_id = session.session_id().0.clone();

// Switch provider + model explicitly — the extra `provider` field is extracted from raw
// params before the sacp-typed parse.
let result = send_custom(
conn.cx(),
"session/set_model",
serde_json::json!({
"sessionId": session_id,
"modelId": "gpt-4o",
"provider": "openai",
}),
)
.await;
assert!(
result.is_ok(),
"set_model with provider failed: {:?}",
result
);
});
}
Loading