@@ -13,6 +13,7 @@ use goose::conversation::Conversation;
1313use goose:: mcp_utils:: ToolResult ;
1414use goose:: permission:: permission_confirmation:: PrincipalType ;
1515use goose:: permission:: { Permission , PermissionConfirmation } ;
16+ use goose:: providers:: base:: Provider ;
1617use goose:: providers:: provider_registry:: ProviderConstructor ;
1718use goose:: session:: session_manager:: SessionType ;
1819use goose:: session:: { Session , SessionManager } ;
@@ -21,12 +22,13 @@ use sacp::schema::{
2122 AgentCapabilities , AuthMethod , AuthenticateRequest , AuthenticateResponse , BlobResourceContents ,
2223 CancelNotification , Content , ContentBlock , ContentChunk , EmbeddedResource ,
2324 EmbeddedResourceResource , ImageContent , InitializeRequest , InitializeResponse ,
24- LoadSessionRequest , LoadSessionResponse , McpCapabilities , McpServer , NewSessionRequest ,
25- NewSessionResponse , PermissionOption , PermissionOptionKind , PromptCapabilities , PromptRequest ,
26- PromptResponse , RequestPermissionOutcome , RequestPermissionRequest , ResourceLink , SessionId ,
27- SessionNotification , SessionUpdate , StopReason , TextContent , TextResourceContents , ToolCall ,
28- ToolCallContent , ToolCallId , ToolCallLocation , ToolCallStatus , ToolCallUpdate ,
29- ToolCallUpdateFields , ToolKind ,
25+ LoadSessionRequest , LoadSessionResponse , McpCapabilities , McpServer , ModelId , ModelInfo ,
26+ NewSessionRequest , NewSessionResponse , PermissionOption , PermissionOptionKind ,
27+ PromptCapabilities , PromptRequest , PromptResponse , RequestPermissionOutcome ,
28+ RequestPermissionRequest , ResourceLink , SessionId , SessionModelState , SessionNotification ,
29+ SessionUpdate , SetSessionModelRequest , SetSessionModelResponse , StopReason , TextContent ,
30+ TextResourceContents , ToolCall , ToolCallContent , ToolCallId , ToolCallLocation , ToolCallStatus ,
31+ ToolCallUpdate , ToolCallUpdateFields , ToolKind ,
3032} ;
3133use sacp:: { AgentToClient , ByteStreams , Handled , JrConnectionCx , JrMessageHandler , MessageCx } ;
3234use std:: collections:: HashMap ;
@@ -48,7 +50,7 @@ pub struct GooseAcpAgent {
4850 agent : Arc < Agent > ,
4951 provider_factory : ProviderConstructor ,
5052 config_dir : std:: path:: PathBuf ,
51- provider_initialized : tokio:: sync:: OnceCell < String > ,
53+ provider_initialized : tokio:: sync:: OnceCell < Arc < dyn Provider > > ,
5254}
5355
5456fn mcp_server_to_extension_config ( mcp_server : McpServer ) -> Result < ExtensionConfig , String > {
@@ -266,6 +268,22 @@ async fn add_extensions(agent: &Agent, extensions: Vec<ExtensionConfig>) {
266268 }
267269}
268270
271+ async fn build_model_state (
272+ provider : & dyn Provider ,
273+ current_model : & str ,
274+ ) -> Result < SessionModelState , sacp:: Error > {
275+ let models = provider. fetch_recommended_models ( ) . await . map_err ( |e| {
276+ sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to fetch models: {}" , e) )
277+ } ) ?;
278+ Ok ( SessionModelState :: new (
279+ ModelId :: new ( current_model) ,
280+ models
281+ . iter ( )
282+ . map ( |name| ModelInfo :: new ( ModelId :: new ( & * * name) , & * * name) )
283+ . collect ( ) ,
284+ ) )
285+ }
286+
269287impl GooseAcpAgent {
270288 pub fn permission_manager ( & self ) -> Arc < PermissionManager > {
271289 Arc :: clone ( & self . agent . config . permission_manager )
@@ -682,7 +700,7 @@ impl GooseAcpAgent {
682700 . map_err ( |e| {
683701 sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to create session: {}" , e) )
684702 } ) ?;
685- self . ensure_provider ( & goose_session) . await . map_err ( |e| {
703+ let provider = self . ensure_provider ( & goose_session) . await . map_err ( |e| {
686704 sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to set provider: {}" , e) )
687705 } ) ?;
688706
@@ -715,25 +733,28 @@ impl GooseAcpAgent {
715733 "Session started"
716734 ) ;
717735
718- Ok ( NewSessionResponse :: new ( SessionId :: new ( goose_session. id ) ) )
736+ let model_state =
737+ build_model_state ( & * * provider, & provider. get_model_config ( ) . model_name ) . await ?;
738+
739+ Ok ( NewSessionResponse :: new ( SessionId :: new ( goose_session. id ) ) . models ( model_state) )
719740 }
720741
721- // Called at most once via OnceCell; returns the model_id used.
722- async fn create_provider ( & self , session : & Session ) -> Result < String > {
742+ async fn create_provider ( & self , session : & Session ) -> Result < Arc < dyn Provider > > {
723743 let config_path = self . config_dir . join ( CONFIG_YAML_NAME ) ;
724744 let config = Config :: new ( & config_path, "goose" ) ?;
725745 let model_id = config. get_goose_model ( ) ?;
726746 let model_config = goose:: model:: ModelConfig :: new ( & model_id) ?;
727747 let provider = ( self . provider_factory ) ( model_config) . await ?;
728- self . agent . update_provider ( provider, & session. id ) . await ?;
729- Ok ( model_id)
748+ self . agent
749+ . update_provider ( provider. clone ( ) , & session. id )
750+ . await ?;
751+ Ok ( provider)
730752 }
731753
732- async fn ensure_provider ( & self , session : & Session ) -> Result < ( ) > {
754+ async fn ensure_provider ( & self , session : & Session ) -> Result < & Arc < dyn Provider > > {
733755 self . provider_initialized
734756 . get_or_try_init ( || self . create_provider ( session) )
735- . await ?;
736- Ok ( ( ) )
757+ . await
737758 }
738759
739760 async fn on_load_session (
@@ -750,7 +771,7 @@ impl GooseAcpAgent {
750771 sacp:: Error :: invalid_params ( )
751772 . data ( format ! ( "Failed to load session {}: {}" , session_id, e) )
752773 } ) ?;
753- self . ensure_provider ( & goose_session) . await . map_err ( |e| {
774+ let provider = self . ensure_provider ( & goose_session) . await . map_err ( |e| {
754775 sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to set provider: {}" , e) )
755776 } ) ?;
756777
@@ -830,7 +851,10 @@ impl GooseAcpAgent {
830851 "Session loaded"
831852 ) ;
832853
833- Ok ( LoadSessionResponse :: new ( ) )
854+ let model_state =
855+ build_model_state ( & * * provider, & provider. get_model_config ( ) . model_name ) . await ?;
856+
857+ Ok ( LoadSessionResponse :: new ( ) . models ( model_state) )
834858 }
835859
836860 async fn on_prompt (
@@ -928,6 +952,28 @@ impl GooseAcpAgent {
928952
929953 Ok ( ( ) )
930954 }
955+
956+ async fn on_set_model (
957+ & self ,
958+ session_id : & str ,
959+ model_id : & str ,
960+ ) -> Result < SetSessionModelResponse , sacp:: Error > {
961+ let model_config = goose:: model:: ModelConfig :: new ( model_id) . map_err ( |e| {
962+ sacp:: Error :: internal_error ( ) . data ( format ! ( "Invalid model config: {}" , e) )
963+ } ) ?;
964+ let provider = ( self . provider_factory ) ( model_config) . await . map_err ( |e| {
965+ sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to create provider: {}" , e) )
966+ } ) ?;
967+ self . agent
968+ . update_provider ( provider, session_id)
969+ . await
970+ . map_err ( |e| {
971+ sacp:: Error :: internal_error ( ) . data ( format ! ( "Failed to update provider: {}" , e) )
972+ } ) ?;
973+
974+ info ! ( session_id = %session_id, model_id = %model_id, "Model switched" ) ;
975+ Ok ( SetSessionModelResponse :: new ( ) )
976+ }
931977}
932978
933979pub struct GooseAcpHandler {
@@ -997,7 +1043,30 @@ impl JrMessageHandler for GooseAcpHandler {
9971043 self . agent . on_cancel ( notif) . await
9981044 } )
9991045 . await
1000- . done ( )
1046+ // HACK: sacp doesn't support session/set_model yet, so we handle it as untyped JSON.
1047+ . otherwise ( {
1048+ let agent = self . agent . clone ( ) ;
1049+ |message : MessageCx | async move {
1050+ match message {
1051+ MessageCx :: Request ( req, request_cx)
1052+ if req. method == "session/set_model" =>
1053+ {
1054+ let params: SetSessionModelRequest = serde_json:: from_value ( req. params )
1055+ . map_err ( |e| sacp:: Error :: invalid_params ( ) . data ( e. to_string ( ) ) ) ?;
1056+ let resp = agent
1057+ . on_set_model ( & params. session_id . 0 , & params. model_id . 0 )
1058+ . await ?;
1059+ let json = serde_json:: to_value ( resp)
1060+ . map_err ( |e| sacp:: Error :: internal_error ( ) . data ( e. to_string ( ) ) ) ?;
1061+ request_cx. respond ( json) ?;
1062+ Ok ( ( ) )
1063+ }
1064+ _ => Err ( sacp:: Error :: method_not_found ( ) ) ,
1065+ }
1066+ }
1067+ } )
1068+ . await
1069+ . map ( |( ) | Handled :: Yes )
10011070 }
10021071}
10031072
@@ -1189,4 +1258,79 @@ print(\"hello, world\")
11891258 ) {
11901259 assert_eq ! ( outcome_to_confirmation( & input) , expected) ;
11911260 }
1261+
1262+ use goose:: providers:: errors:: ProviderError ;
1263+
1264+ struct MockModelProvider {
1265+ models : Result < Vec < String > , ProviderError > ,
1266+ }
1267+
1268+ #[ async_trait:: async_trait]
1269+ impl goose:: providers:: base:: Provider for MockModelProvider {
1270+ fn get_name ( & self ) -> & str {
1271+ "mock"
1272+ }
1273+
1274+ async fn complete_with_model (
1275+ & self ,
1276+ _session_id : Option < & str > ,
1277+ _model_config : & goose:: model:: ModelConfig ,
1278+ _system : & str ,
1279+ _messages : & [ goose:: conversation:: message:: Message ] ,
1280+ _tools : & [ rmcp:: model:: Tool ] ,
1281+ ) -> Result <
1282+ (
1283+ goose:: conversation:: message:: Message ,
1284+ goose:: providers:: base:: ProviderUsage ,
1285+ ) ,
1286+ ProviderError ,
1287+ > {
1288+ unimplemented ! ( )
1289+ }
1290+
1291+ fn get_model_config ( & self ) -> goose:: model:: ModelConfig {
1292+ goose:: model:: ModelConfig :: new_or_fail ( "unused" )
1293+ }
1294+
1295+ async fn fetch_recommended_models ( & self ) -> Result < Vec < String > , ProviderError > {
1296+ self . models . clone ( )
1297+ }
1298+ }
1299+
1300+ #[ test_case(
1301+ "model-a" , Ok ( vec![ "model-a" . into( ) , "model-b" . into( ) ] )
1302+ => Ok ( SessionModelState :: new(
1303+ ModelId :: new( "model-a" ) ,
1304+ vec![ ModelInfo :: new( ModelId :: new( "model-a" ) , "model-a" ) ,
1305+ ModelInfo :: new( ModelId :: new( "model-b" ) , "model-b" ) ] ,
1306+ ) )
1307+ ; "returns current and available models"
1308+ ) ]
1309+ #[ test_case(
1310+ "model-a" , Ok ( vec![ ] )
1311+ => Ok ( SessionModelState :: new( ModelId :: new( "model-a" ) , vec![ ] ) )
1312+ ; "empty model list"
1313+ ) ]
1314+ #[ test_case(
1315+ "model-a" , Err ( ProviderError :: ExecutionError ( "fail" . into( ) ) )
1316+ => matches Err ( _)
1317+ ; "fetch error propagates"
1318+ ) ]
1319+ #[ test_case(
1320+ "switched-model" , Ok ( vec![ "model-a" . into( ) , "switched-model" . into( ) ] )
1321+ => Ok ( SessionModelState :: new(
1322+ ModelId :: new( "switched-model" ) ,
1323+ vec![ ModelInfo :: new( ModelId :: new( "model-a" ) , "model-a" ) ,
1324+ ModelInfo :: new( ModelId :: new( "switched-model" ) , "switched-model" ) ] ,
1325+ ) )
1326+ ; "current model reflects switched model"
1327+ ) ]
1328+ #[ tokio:: test]
1329+ async fn test_build_model_state (
1330+ current_model : & str ,
1331+ models : Result < Vec < String > , ProviderError > ,
1332+ ) -> Result < SessionModelState , sacp:: Error > {
1333+ let provider = MockModelProvider { models } ;
1334+ build_model_state ( & provider, current_model) . await
1335+ }
11921336}
0 commit comments