Skip to content

Commit 49ad115

Browse files
feat(acp): add model selection support for session/new and session/set_model (#7112)
Signed-off-by: Adrian Cole <adrian@tetrate.io>
1 parent 4abf91e commit 49ad115

12 files changed

Lines changed: 332 additions & 36 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/goose-acp/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ goose = { path = "../goose" }
1919
goose-mcp = { path = "../goose-mcp" }
2020
rmcp = { workspace = true }
2121
sacp = "10.1.0"
22+
agent-client-protocol-schema = { version = "0.10", features = ["unstable_session_model"] }
2223
anyhow = { workspace = true }
2324
tokio = { workspace = true }
2425
tokio-util = { version = "0.7.15", features = ["compat", "rt"] }

crates/goose-acp/src/server.rs

Lines changed: 163 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use goose::conversation::Conversation;
1313
use goose::mcp_utils::ToolResult;
1414
use goose::permission::permission_confirmation::PrincipalType;
1515
use goose::permission::{Permission, PermissionConfirmation};
16+
use goose::providers::base::Provider;
1617
use goose::providers::provider_registry::ProviderConstructor;
1718
use goose::session::session_manager::SessionType;
1819
use goose::session::{Session, SessionManager};
@@ -21,12 +22,13 @@ use sacp::schema::{
2122
AgentCapabilities, AuthMethod, AuthenticateRequest, AuthenticateResponse, BlobResourceContents,
2223
CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource,
2324
EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse,
24-
LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest,
25-
NewSessionResponse, PermissionOption, PermissionOptionKind, PromptCapabilities, PromptRequest,
26-
PromptResponse, RequestPermissionOutcome, RequestPermissionRequest, ResourceLink, SessionId,
27-
SessionNotification, SessionUpdate, StopReason, TextContent, TextResourceContents, ToolCall,
28-
ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus, ToolCallUpdate,
29-
ToolCallUpdateFields, ToolKind,
25+
LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, ModelId, ModelInfo,
26+
NewSessionRequest, NewSessionResponse, PermissionOption, PermissionOptionKind,
27+
PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome,
28+
RequestPermissionRequest, ResourceLink, SessionId, SessionModelState, SessionNotification,
29+
SessionUpdate, SetSessionModelRequest, SetSessionModelResponse, StopReason, TextContent,
30+
TextResourceContents, ToolCall, ToolCallContent, ToolCallId, ToolCallLocation, ToolCallStatus,
31+
ToolCallUpdate, ToolCallUpdateFields, ToolKind,
3032
};
3133
use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx};
3234
use std::collections::HashMap;
@@ -48,7 +50,7 @@ pub struct GooseAcpAgent {
4850
agent: Arc<Agent>,
4951
provider_factory: ProviderConstructor,
5052
config_dir: std::path::PathBuf,
51-
provider_initialized: tokio::sync::OnceCell<String>,
53+
provider_initialized: tokio::sync::OnceCell<Arc<dyn Provider>>,
5254
}
5355

5456
fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> {
@@ -266,6 +268,22 @@ async fn add_extensions(agent: &Agent, extensions: Vec<ExtensionConfig>) {
266268
}
267269
}
268270

271+
async fn build_model_state(
272+
provider: &dyn Provider,
273+
current_model: &str,
274+
) -> Result<SessionModelState, sacp::Error> {
275+
let models = provider.fetch_recommended_models().await.map_err(|e| {
276+
sacp::Error::internal_error().data(format!("Failed to fetch models: {}", e))
277+
})?;
278+
Ok(SessionModelState::new(
279+
ModelId::new(current_model),
280+
models
281+
.iter()
282+
.map(|name| ModelInfo::new(ModelId::new(&**name), &**name))
283+
.collect(),
284+
))
285+
}
286+
269287
impl GooseAcpAgent {
270288
pub fn permission_manager(&self) -> Arc<PermissionManager> {
271289
Arc::clone(&self.agent.config.permission_manager)
@@ -682,7 +700,7 @@ impl GooseAcpAgent {
682700
.map_err(|e| {
683701
sacp::Error::internal_error().data(format!("Failed to create session: {}", e))
684702
})?;
685-
self.ensure_provider(&goose_session).await.map_err(|e| {
703+
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
686704
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
687705
})?;
688706

@@ -715,25 +733,28 @@ impl GooseAcpAgent {
715733
"Session started"
716734
);
717735

718-
Ok(NewSessionResponse::new(SessionId::new(goose_session.id)))
736+
let model_state =
737+
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
738+
739+
Ok(NewSessionResponse::new(SessionId::new(goose_session.id)).models(model_state))
719740
}
720741

721-
// Called at most once via OnceCell; returns the model_id used.
722-
async fn create_provider(&self, session: &Session) -> Result<String> {
742+
async fn create_provider(&self, session: &Session) -> Result<Arc<dyn Provider>> {
723743
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
724744
let config = Config::new(&config_path, "goose")?;
725745
let model_id = config.get_goose_model()?;
726746
let model_config = goose::model::ModelConfig::new(&model_id)?;
727747
let provider = (self.provider_factory)(model_config).await?;
728-
self.agent.update_provider(provider, &session.id).await?;
729-
Ok(model_id)
748+
self.agent
749+
.update_provider(provider.clone(), &session.id)
750+
.await?;
751+
Ok(provider)
730752
}
731753

732-
async fn ensure_provider(&self, session: &Session) -> Result<()> {
754+
async fn ensure_provider(&self, session: &Session) -> Result<&Arc<dyn Provider>> {
733755
self.provider_initialized
734756
.get_or_try_init(|| self.create_provider(session))
735-
.await?;
736-
Ok(())
757+
.await
737758
}
738759

739760
async fn on_load_session(
@@ -750,7 +771,7 @@ impl GooseAcpAgent {
750771
sacp::Error::invalid_params()
751772
.data(format!("Failed to load session {}: {}", session_id, e))
752773
})?;
753-
self.ensure_provider(&goose_session).await.map_err(|e| {
774+
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
754775
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
755776
})?;
756777

@@ -830,7 +851,10 @@ impl GooseAcpAgent {
830851
"Session loaded"
831852
);
832853

833-
Ok(LoadSessionResponse::new())
854+
let model_state =
855+
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
856+
857+
Ok(LoadSessionResponse::new().models(model_state))
834858
}
835859

836860
async fn on_prompt(
@@ -928,6 +952,28 @@ impl GooseAcpAgent {
928952

929953
Ok(())
930954
}
955+
956+
async fn on_set_model(
957+
&self,
958+
session_id: &str,
959+
model_id: &str,
960+
) -> Result<SetSessionModelResponse, sacp::Error> {
961+
let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| {
962+
sacp::Error::internal_error().data(format!("Invalid model config: {}", e))
963+
})?;
964+
let provider = (self.provider_factory)(model_config).await.map_err(|e| {
965+
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
966+
})?;
967+
self.agent
968+
.update_provider(provider, session_id)
969+
.await
970+
.map_err(|e| {
971+
sacp::Error::internal_error().data(format!("Failed to update provider: {}", e))
972+
})?;
973+
974+
info!(session_id = %session_id, model_id = %model_id, "Model switched");
975+
Ok(SetSessionModelResponse::new())
976+
}
931977
}
932978

933979
pub struct GooseAcpHandler {
@@ -997,7 +1043,30 @@ impl JrMessageHandler for GooseAcpHandler {
9971043
self.agent.on_cancel(notif).await
9981044
})
9991045
.await
1000-
.done()
1046+
// HACK: sacp doesn't support session/set_model yet, so we handle it as untyped JSON.
1047+
.otherwise({
1048+
let agent = self.agent.clone();
1049+
|message: MessageCx| async move {
1050+
match message {
1051+
MessageCx::Request(req, request_cx)
1052+
if req.method == "session/set_model" =>
1053+
{
1054+
let params: SetSessionModelRequest = serde_json::from_value(req.params)
1055+
.map_err(|e| sacp::Error::invalid_params().data(e.to_string()))?;
1056+
let resp = agent
1057+
.on_set_model(&params.session_id.0, &params.model_id.0)
1058+
.await?;
1059+
let json = serde_json::to_value(resp)
1060+
.map_err(|e| sacp::Error::internal_error().data(e.to_string()))?;
1061+
request_cx.respond(json)?;
1062+
Ok(())
1063+
}
1064+
_ => Err(sacp::Error::method_not_found()),
1065+
}
1066+
}
1067+
})
1068+
.await
1069+
.map(|()| Handled::Yes)
10011070
}
10021071
}
10031072

@@ -1189,4 +1258,79 @@ print(\"hello, world\")
11891258
) {
11901259
assert_eq!(outcome_to_confirmation(&input), expected);
11911260
}
1261+
1262+
use goose::providers::errors::ProviderError;
1263+
1264+
struct MockModelProvider {
1265+
models: Result<Vec<String>, ProviderError>,
1266+
}
1267+
1268+
#[async_trait::async_trait]
1269+
impl goose::providers::base::Provider for MockModelProvider {
1270+
fn get_name(&self) -> &str {
1271+
"mock"
1272+
}
1273+
1274+
async fn complete_with_model(
1275+
&self,
1276+
_session_id: Option<&str>,
1277+
_model_config: &goose::model::ModelConfig,
1278+
_system: &str,
1279+
_messages: &[goose::conversation::message::Message],
1280+
_tools: &[rmcp::model::Tool],
1281+
) -> Result<
1282+
(
1283+
goose::conversation::message::Message,
1284+
goose::providers::base::ProviderUsage,
1285+
),
1286+
ProviderError,
1287+
> {
1288+
unimplemented!()
1289+
}
1290+
1291+
fn get_model_config(&self) -> goose::model::ModelConfig {
1292+
goose::model::ModelConfig::new_or_fail("unused")
1293+
}
1294+
1295+
async fn fetch_recommended_models(&self) -> Result<Vec<String>, ProviderError> {
1296+
self.models.clone()
1297+
}
1298+
}
1299+
1300+
#[test_case(
1301+
"model-a", Ok(vec!["model-a".into(), "model-b".into()])
1302+
=> Ok(SessionModelState::new(
1303+
ModelId::new("model-a"),
1304+
vec![ModelInfo::new(ModelId::new("model-a"), "model-a"),
1305+
ModelInfo::new(ModelId::new("model-b"), "model-b")],
1306+
))
1307+
; "returns current and available models"
1308+
)]
1309+
#[test_case(
1310+
"model-a", Ok(vec![])
1311+
=> Ok(SessionModelState::new(ModelId::new("model-a"), vec![]))
1312+
; "empty model list"
1313+
)]
1314+
#[test_case(
1315+
"model-a", Err(ProviderError::ExecutionError("fail".into()))
1316+
=> matches Err(_)
1317+
; "fetch error propagates"
1318+
)]
1319+
#[test_case(
1320+
"switched-model", Ok(vec!["model-a".into(), "switched-model".into()])
1321+
=> Ok(SessionModelState::new(
1322+
ModelId::new("switched-model"),
1323+
vec![ModelInfo::new(ModelId::new("model-a"), "model-a"),
1324+
ModelInfo::new(ModelId::new("switched-model"), "switched-model")],
1325+
))
1326+
; "current model reflects switched model"
1327+
)]
1328+
#[tokio::test]
1329+
async fn test_build_model_state(
1330+
current_model: &str,
1331+
models: Result<Vec<String>, ProviderError>,
1332+
) -> Result<SessionModelState, sacp::Error> {
1333+
let provider = MockModelProvider { models };
1334+
build_model_state(&provider, current_model).await
1335+
}
11921336
}

0 commit comments

Comments
 (0)