Skip to content

Commit 4ad0aa4

Browse files
fix(acp): per-session Agent for model isolation and load_session restore (#7115)
Signed-off-by: Adrian Cole <adrian@tetrate.io>
1 parent dc50b7b commit 4ad0aa4

7 files changed

Lines changed: 502 additions & 407 deletions

File tree

crates/goose-acp/src/server.rs

Lines changed: 95 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,24 @@ use tokio_util::sync::CancellationToken;
3939
use tracing::{debug, error, info, warn};
4040
use url::Url;
4141

42+
// Agent binds provider, extensions, and permission channels to a single session.
43+
// ACP has no session/close, so sessions accumulate until transport closes.
4244
struct GooseAcpSession {
45+
agent: Arc<Agent>,
4346
messages: Conversation,
4447
tool_requests: HashMap<String, goose::conversation::message::ToolRequest>,
4548
cancel_token: Option<CancellationToken>,
4649
}
4750

4851
pub struct GooseAcpAgent {
4952
sessions: Arc<Mutex<HashMap<String, GooseAcpSession>>>,
50-
agent: Arc<Agent>,
5153
provider_factory: ProviderConstructor,
5254
config_dir: std::path::PathBuf,
53-
provider_initialized: tokio::sync::OnceCell<Arc<dyn Provider>>,
55+
session_manager: Arc<SessionManager>,
56+
permission_manager: Arc<PermissionManager>,
57+
goose_mode: goose::config::GooseMode,
58+
disable_session_naming: bool,
59+
builtins: Vec<String>,
5460
}
5561

5662
fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> {
@@ -286,7 +292,7 @@ async fn build_model_state(
286292

287293
impl GooseAcpAgent {
288294
pub fn permission_manager(&self) -> Arc<PermissionManager> {
289-
Arc::clone(&self.agent.config.permission_manager)
295+
Arc::clone(&self.permission_manager)
290296
}
291297

292298
pub async fn new(
@@ -300,60 +306,36 @@ impl GooseAcpAgent {
300306
let session_manager = Arc::new(SessionManager::new(data_dir));
301307
let permission_manager = Arc::new(PermissionManager::new(config_dir.clone()));
302308

303-
let agent = Agent::with_config(AgentConfig::new(
304-
Arc::clone(&session_manager),
305-
permission_manager,
306-
None,
307-
goose_mode,
308-
disable_session_naming,
309-
));
310-
311-
let agent_ptr = Arc::new(agent);
312-
313-
let config_path = config_dir.join(CONFIG_YAML_NAME);
314-
let config_file = Config::new(&config_path, "goose")?;
315-
let extensions = get_enabled_extensions_with_config(&config_file);
316-
317-
add_builtins(&agent_ptr, builtins).await;
318-
add_extensions(&agent_ptr, extensions).await;
319-
320309
Ok(Self {
321310
sessions: Arc::new(Mutex::new(HashMap::new())),
322-
agent: agent_ptr,
323311
provider_factory,
324312
config_dir,
325-
provider_initialized: tokio::sync::OnceCell::new(),
313+
session_manager,
314+
permission_manager,
315+
goose_mode,
316+
disable_session_naming,
317+
builtins,
326318
})
327319
}
328320

329-
pub async fn create_session(&self) -> Result<String> {
330-
let manager = self.agent.config.session_manager.clone();
331-
let goose_session = manager
332-
.create_session(
333-
std::env::current_dir().unwrap_or_default(),
334-
"ACP Session".to_string(),
335-
SessionType::User,
336-
)
337-
.await?;
338-
339-
self.ensure_provider(&goose_session).await?;
340-
341-
let session = GooseAcpSession {
342-
messages: Conversation::new_unvalidated(Vec::new()),
343-
tool_requests: HashMap::new(),
344-
cancel_token: None,
345-
};
346-
347-
let mut sessions = self.sessions.lock().await;
348-
sessions.insert(goose_session.id.clone(), session);
321+
async fn create_agent_for_session(&self) -> Arc<Agent> {
322+
let agent = Agent::with_config(AgentConfig::new(
323+
Arc::clone(&self.session_manager),
324+
Arc::clone(&self.permission_manager),
325+
None,
326+
self.goose_mode,
327+
self.disable_session_naming,
328+
));
329+
let agent = Arc::new(agent);
349330

350-
info!(
351-
session_id = %goose_session.id,
352-
session_type = "acp",
353-
"Session created"
354-
);
331+
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
332+
if let Ok(config_file) = Config::new(&config_path, "goose") {
333+
let extensions = get_enabled_extensions_with_config(&config_file);
334+
add_extensions(&agent, extensions).await;
335+
}
336+
add_builtins(&agent, self.builtins.clone()).await;
355337

356-
Ok(goose_session.id)
338+
agent
357339
}
358340

359341
pub async fn has_session(&self, session_id: &str) -> bool {
@@ -433,12 +415,13 @@ impl GooseAcpAgent {
433415
} = &action_required.data
434416
{
435417
self.handle_tool_permission_request(
418+
cx,
419+
&session.agent,
420+
session_id,
436421
id.clone(),
437422
tool_name.clone(),
438423
arguments.clone(),
439424
prompt.clone(),
440-
session_id,
441-
cx,
442425
)?;
443426
}
444427
}
@@ -513,17 +496,19 @@ impl GooseAcpAgent {
513496
Ok(())
514497
}
515498

499+
#[allow(clippy::too_many_arguments)]
516500
fn handle_tool_permission_request(
517501
&self,
502+
cx: &JrConnectionCx<AgentToClient>,
503+
agent: &Arc<Agent>,
504+
session_id: &SessionId,
518505
request_id: String,
519506
tool_name: String,
520507
arguments: serde_json::Map<String, serde_json::Value>,
521508
prompt: Option<String>,
522-
session_id: &SessionId,
523-
cx: &JrConnectionCx<AgentToClient>,
524509
) -> Result<(), sacp::Error> {
525510
let cx = cx.clone();
526-
let agent = self.agent.clone();
511+
let agent = agent.clone();
527512
let session_id = session_id.clone();
528513

529514
let formatted_name = format_tool_name(&tool_name);
@@ -689,8 +674,8 @@ impl GooseAcpAgent {
689674
) -> Result<NewSessionResponse, sacp::Error> {
690675
debug!(?args, "new session request");
691676

692-
let manager = self.agent.config.session_manager.clone();
693-
let goose_session = manager
677+
let goose_session = self
678+
.session_manager
694679
.create_session(
695680
args.cwd.clone(),
696681
"ACP Session".to_string(),
@@ -700,9 +685,14 @@ impl GooseAcpAgent {
700685
.map_err(|e| {
701686
sacp::Error::internal_error().data(format!("Failed to create session: {}", e))
702687
})?;
703-
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
704-
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
705-
})?;
688+
689+
let agent = self.create_agent_for_session().await;
690+
let provider = self
691+
.init_provider(&agent, &goose_session)
692+
.await
693+
.map_err(|e| {
694+
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
695+
})?;
706696

707697
for mcp_server in args.mcp_servers {
708698
let config = match mcp_server_to_extension_config(mcp_server) {
@@ -712,13 +702,14 @@ impl GooseAcpAgent {
712702
}
713703
};
714704
let name = config.name().to_string();
715-
if let Err(e) = self.agent.add_extension(config, &goose_session.id).await {
705+
if let Err(e) = agent.add_extension(config, &goose_session.id).await {
716706
return Err(sacp::Error::internal_error()
717707
.data(format!("Failed to add MCP server '{}': {}", name, e)));
718708
}
719709
}
720710

721711
let session = GooseAcpSession {
712+
agent,
722713
messages: Conversation::new_unvalidated(Vec::new()),
723714
tool_requests: HashMap::new(),
724715
cancel_token: None,
@@ -734,29 +725,26 @@ impl GooseAcpAgent {
734725
);
735726

736727
let model_state =
737-
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
728+
build_model_state(&*provider, &provider.get_model_config().model_name).await?;
738729

739730
Ok(NewSessionResponse::new(SessionId::new(goose_session.id)).models(model_state))
740731
}
741732

742-
async fn create_provider(&self, session: &Session) -> Result<Arc<dyn Provider>> {
743-
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
744-
let config = Config::new(&config_path, "goose")?;
745-
let model_id = config.get_goose_model()?;
746-
let model_config = goose::model::ModelConfig::new(&model_id)?;
733+
async fn init_provider(&self, agent: &Agent, session: &Session) -> Result<Arc<dyn Provider>> {
734+
let model_config = match &session.model_config {
735+
Some(config) => config.clone(),
736+
None => {
737+
let config_path = self.config_dir.join(CONFIG_YAML_NAME);
738+
let config = Config::new(&config_path, "goose")?;
739+
let model_id = config.get_goose_model()?;
740+
goose::model::ModelConfig::new(&model_id)?
741+
}
742+
};
747743
let provider = (self.provider_factory)(model_config).await?;
748-
self.agent
749-
.update_provider(provider.clone(), &session.id)
750-
.await?;
744+
agent.update_provider(provider.clone(), &session.id).await?;
751745
Ok(provider)
752746
}
753747

754-
async fn ensure_provider(&self, session: &Session) -> Result<&Arc<dyn Provider>> {
755-
self.provider_initialized
756-
.get_or_try_init(|| self.create_provider(session))
757-
.await
758-
}
759-
760748
async fn on_load_session(
761749
&self,
762750
args: LoadSessionRequest,
@@ -766,21 +754,29 @@ impl GooseAcpAgent {
766754

767755
let session_id = args.session_id.0.to_string();
768756

769-
let manager = self.agent.config.session_manager.clone();
770-
let goose_session = manager.get_session(&session_id, true).await.map_err(|e| {
771-
sacp::Error::invalid_params()
772-
.data(format!("Failed to load session {}: {}", session_id, e))
773-
})?;
774-
let provider = self.ensure_provider(&goose_session).await.map_err(|e| {
775-
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
776-
})?;
757+
let goose_session = self
758+
.session_manager
759+
.get_session(&session_id, true)
760+
.await
761+
.map_err(|e| {
762+
sacp::Error::invalid_params()
763+
.data(format!("Failed to load session {}: {}", session_id, e))
764+
})?;
765+
766+
let agent = self.create_agent_for_session().await;
767+
let provider = self
768+
.init_provider(&agent, &goose_session)
769+
.await
770+
.map_err(|e| {
771+
sacp::Error::internal_error().data(format!("Failed to set provider: {}", e))
772+
})?;
777773

778774
let conversation = goose_session.conversation.ok_or_else(|| {
779775
sacp::Error::internal_error()
780776
.data(format!("Session {} has no conversation data", session_id))
781777
})?;
782778

783-
manager
779+
self.session_manager
784780
.update(&session_id)
785781
.working_dir(args.cwd.clone())
786782
.apply()
@@ -791,6 +787,7 @@ impl GooseAcpAgent {
791787
})?;
792788

793789
let mut session = GooseAcpSession {
790+
agent,
794791
messages: conversation.clone(),
795792
tool_requests: HashMap::new(),
796793
cancel_token: None,
@@ -852,7 +849,7 @@ impl GooseAcpAgent {
852849
);
853850

854851
let model_state =
855-
build_model_state(&**provider, &provider.get_model_config().model_name).await?;
852+
build_model_state(&*provider, &provider.get_model_config().model_name).await?;
856853

857854
Ok(LoadSessionResponse::new().models(model_state))
858855
}
@@ -865,13 +862,14 @@ impl GooseAcpAgent {
865862
let session_id = args.session_id.0.to_string();
866863
let cancel_token = CancellationToken::new();
867864

868-
{
865+
let agent = {
869866
let mut sessions = self.sessions.lock().await;
870867
let session = sessions.get_mut(&session_id).ok_or_else(|| {
871868
sacp::Error::invalid_params().data(format!("Session not found: {}", session_id))
872869
})?;
873870
session.cancel_token = Some(cancel_token.clone());
874-
}
871+
session.agent.clone()
872+
};
875873

876874
let user_message = self.convert_acp_prompt_to_message(args.prompt);
877875

@@ -882,8 +880,7 @@ impl GooseAcpAgent {
882880
retry_config: None,
883881
};
884882

885-
let mut stream = self
886-
.agent
883+
let mut stream = agent
887884
.reply(user_message, session_config, Some(cancel_token.clone()))
888885
.await
889886
.map_err(|e| {
@@ -959,12 +956,20 @@ impl GooseAcpAgent {
959956
model_id: &str,
960957
) -> Result<SetSessionModelResponse, sacp::Error> {
961958
let model_config = goose::model::ModelConfig::new(model_id).map_err(|e| {
962-
sacp::Error::internal_error().data(format!("Invalid model config: {}", e))
959+
sacp::Error::invalid_params().data(format!("Invalid model config: {}", e))
963960
})?;
964961
let provider = (self.provider_factory)(model_config).await.map_err(|e| {
965962
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
966963
})?;
967-
self.agent
964+
965+
let agent = {
966+
let sessions = self.sessions.lock().await;
967+
let session = sessions.get(session_id).ok_or_else(|| {
968+
sacp::Error::invalid_params().data(format!("Session not found: {}", session_id))
969+
})?;
970+
session.agent.clone()
971+
};
972+
agent
968973
.update_provider(provider, session_id)
969974
.await
970975
.map_err(|e| {

0 commit comments

Comments
 (0)