diff --git a/crates/goose/src/acp/server.rs b/crates/goose/src/acp/server.rs index 0089b8a79fec..21fbdf7642de 100644 --- a/crates/goose/src/acp/server.rs +++ b/crates/goose/src/acp/server.rs @@ -7,11 +7,10 @@ use crate::agents::extension_manager::TRUSTED_TOOL_UPDATE_META_KEY; use crate::agents::mcp_client::{GooseMcpHostInfo, McpClientTrait}; use crate::agents::platform_extensions::developer::DeveloperClient; use crate::agents::{Agent, AgentConfig, ExtensionConfig, GoosePlatform, SessionConfig}; -use crate::config::base::CONFIG_YAML_NAME; use crate::config::extensions::get_enabled_extensions_with_config; use crate::config::paths::Paths; use crate::config::permission::PermissionManager; -use crate::config::{Config, GooseMode}; +use crate::config::{Config, ConfigHandle, GooseMode}; use crate::conversation::message::{ActionRequiredData, Message, MessageContent}; #[cfg(feature = "local-inference")] use crate::dictation::providers::transcribe_local; @@ -201,10 +200,10 @@ pub struct GooseAcpAgent { sessions: Arc>>, provider_factory: AcpProviderFactory, builtins: Vec, + config: ConfigHandle, client_fs_capabilities: OnceCell, client_terminal: OnceCell, client_mcp_host_info: OnceCell, - config_dir: std::path::PathBuf, session_manager: Arc, thread_manager: Arc, permission_manager: Arc, @@ -833,17 +832,6 @@ async fn resolve_provider_and_model_from_config( Ok((provider_name, model_config)) } -/// Convenience wrapper: reads config from disk, then resolves provider + model. -/// Cheap enough to call from `on_new_session` (file + registry reads, no network). -async fn resolve_provider_and_model( - config_dir: &std::path::Path, - goose_session: &Session, -) -> Result<(String, crate::model::ModelConfig), String> { - let config = - Config::new(config_dir.join(CONFIG_YAML_NAME), "goose").map_err(|e| e.to_string())?; - resolve_provider_and_model_from_config(&config, goose_session).await -} - fn build_mode_state(current_mode: GooseMode) -> Result { let mut available = Vec::with_capacity(GooseMode::VARIANTS.len()); for &name in GooseMode::VARIANTS { @@ -956,6 +944,8 @@ impl GooseAcpAgent { disable_session_naming: bool, goose_platform: GoosePlatform, ) -> Result { + let config = Config::for_config_dir(config_dir.clone())?; + let session_manager = Arc::new(SessionManager::new(data_dir)); // Eagerly initialize the SQLite pool so it's ready when providers/sessions need it. @@ -967,17 +957,17 @@ impl GooseAcpAgent { let thread_manager = Arc::new(crate::session::ThreadManager::new( session_manager.storage().clone(), )); - let permission_manager = Arc::new(PermissionManager::new(config_dir.clone())); + let permission_manager = Arc::new(PermissionManager::new(config_dir)); let provider_inventory = ProviderInventoryService::new(session_manager.storage().clone()); Ok(Self { sessions: Arc::new(Mutex::new(HashMap::new())), provider_factory, builtins, + config, client_fs_capabilities: OnceCell::new(), client_terminal: OnceCell::new(), client_mcp_host_info: OnceCell::new(), - config_dir, session_manager, thread_manager, permission_manager, @@ -988,14 +978,6 @@ impl GooseAcpAgent { }) } - fn load_config(&self) -> Result { - Config::new(self.config_dir.join(CONFIG_YAML_NAME), "goose").map_err(Into::into) - } - - fn config(&self) -> Result { - self.load_config().internal_err_ctx("Failed to read config") - } - async fn create_provider( &self, provider_name: &str, @@ -1031,135 +1013,125 @@ impl GooseAcpAgent { let mut prebuilt_provider = None; if should_refresh_inventory_for_session_init(&inventory) { - match self.load_config() { - Ok(config) => { - let ext_state = EnabledExtensionsState::extensions_or_default( - Some(&goose_session.extension_data), - &config, - ); - Config::global().invalidate_secrets_cache(); + let config = self.config.as_ref(); + let ext_state = EnabledExtensionsState::extensions_or_default( + Some(&goose_session.extension_data), + config, + ); + match self + .create_provider(provider_name, model_config.clone(), ext_state) + .await + { + Ok(provider) => { + let provider_id = provider_name.clone(); + prebuilt_provider = Some(provider.clone()); match self - .create_provider(provider_name, model_config.clone(), ext_state) + .provider_inventory + .plan_refresh_jobs(std::slice::from_ref(&provider_id)) .await { - Ok(provider) => { - let provider_id = provider_name.clone(); - prebuilt_provider = Some(provider.clone()); - match self - .provider_inventory - .plan_refresh_jobs(std::slice::from_ref(&provider_id)) - .await - { - Ok(plan) - if plan - .started - .iter() - .any(|job| job.provider_id == provider_id) => - { - let refresh_job = plan - .started - .into_iter() - .find(|job| job.provider_id == provider_id); - if let Some(refresh_job) = refresh_job { - let mut refresh_guard = self + Ok(plan) + if plan + .started + .iter() + .any(|job| job.provider_id == provider_id) => + { + let refresh_job = plan + .started + .into_iter() + .find(|job| job.provider_id == provider_id); + if let Some(refresh_job) = refresh_job { + let mut refresh_guard = + self.provider_inventory.refresh_guard(&refresh_job.identity); + let fetch_result: Result> = + match ensure_refresh_identity_current( + &provider_id, + &refresh_job.identity, + ) + .await + { + Ok(()) => match AssertUnwindSafe( + provider.fetch_recommended_models(), + ) + .catch_unwind() + .await + { + Ok(Ok(models)) => Ok(models), + Ok(Err(error)) => { + Err(anyhow::anyhow!(error.to_string())) + } + Err(_) => Err(anyhow::anyhow!( + "provider inventory refresh task panicked" + )), + }, + Err(error) => Err(error), + }; + match fetch_result { + Ok(models) => { + if let Err(error) = self .provider_inventory - .refresh_guard(&refresh_job.identity); - let fetch_result: Result> = - match ensure_refresh_identity_current( - &provider_id, + .store_refreshed_models_for_identity( &refresh_job.identity, + &models, ) .await - { - Ok(()) => match AssertUnwindSafe( - provider.fetch_recommended_models(), - ) - .catch_unwind() - .await - { - Ok(Ok(models)) => Ok(models), - Ok(Err(error)) => { - Err(anyhow::anyhow!(error.to_string())) - } - Err(_) => Err(anyhow::anyhow!( - "provider inventory refresh task panicked" - )), - }, - Err(error) => Err(error), - }; - match fetch_result { - Ok(models) => { - if let Err(error) = self - .provider_inventory - .store_refreshed_models_for_identity( - &refresh_job.identity, - &models, - ) - .await - { - warn!( - provider = %provider_id, - error = %error, - "failed to store refreshed provider inventory during session init" - ); - } else { - refresh_guard.complete(); - } - } - Err(error) => { - let error_message = error.to_string(); - if let Err(store_error) = self - .provider_inventory - .store_refresh_error_for_identity( - &refresh_job.identity, - error_message.clone(), - ) - .await - { - warn!( - provider = %provider_id, - error = %store_error, - "failed to store provider inventory refresh error during session init" - ); - } else { - refresh_guard.complete(); - } - warn!( - provider = %provider_id, - error = %error_message, - "provider inventory refresh failed during session init" - ); - } + { + warn!( + provider = %provider_id, + error = %error, + "failed to store refreshed provider inventory during session init" + ); + } else { + refresh_guard.complete(); + } + } + Err(error) => { + let error_message = error.to_string(); + if let Err(store_error) = self + .provider_inventory + .store_refresh_error_for_identity( + &refresh_job.identity, + error_message.clone(), + ) + .await + { + warn!( + provider = %provider_id, + error = %store_error, + "failed to store provider inventory refresh error during session init" + ); + } else { + refresh_guard.complete(); } + warn!( + provider = %provider_id, + error = %error_message, + "provider inventory refresh failed during session init" + ); } } - Ok(_) => {} - Err(error) => warn!( - provider = %provider_id, - error = %error, - "failed to plan provider inventory refresh during session init" - ), - } - - if let Ok(Some(refreshed_inventory)) = self - .provider_inventory - .entry_for_provider(provider_name) - .await - { - inventory = refreshed_inventory; } } + Ok(_) => {} Err(error) => warn!( - provider = %provider_name, + provider = %provider_id, error = %error, - "failed to initialize provider during synchronous inventory refresh" + "failed to plan provider inventory refresh during session init" ), + }; + + if let Ok(Some(refreshed_inventory)) = self + .provider_inventory + .entry_for_provider(provider_name) + .await + { + inventory = refreshed_inventory; } } Err(error) => warn!( provider = %provider_name, error = %error, - "failed to load config during synchronous inventory refresh" + "failed to initialize provider during synchronous inventory refresh" ), } } @@ -1198,7 +1170,6 @@ impl GooseAcpAgent { let sessions = Arc::clone(&self.sessions); let session_manager = Arc::clone(&self.session_manager); let permission_manager = Arc::clone(&self.permission_manager); - let config_dir = self.config_dir.clone(); let builtins = self.builtins.clone(); let client_fs_capabilities = self .client_fs_capabilities @@ -1210,20 +1181,12 @@ impl GooseAcpAgent { let provider_factory = Arc::clone(&self.provider_factory); let disable_session_naming = self.disable_session_naming; let goose_platform = self.goose_platform.clone(); + let config = self.config.clone(); tokio::spawn(async move { let t_setup = std::time::Instant::now(); debug!(target: "perf", sid = %sid, "perf: agent_setup start (background)"); - // Shared config — read once, used by both phases. - let config = match Config::new(config_dir.join(CONFIG_YAML_NAME), "goose") { - Ok(c) => c, - Err(e) => { - let msg = e.to_string(); - error!(error = %msg, "Background agent setup failed (config)"); - let _ = agent_tx.send(Some(Err(msg))); - return; - } - }; + let config = config.as_ref(); // ── Phase 1: create agent + init provider (fast, ~55ms) ────── let phase1: Result, String> = async { @@ -1244,11 +1207,11 @@ impl GooseAcpAgent { // fall back to reading config (e.g. load_session path). let (provider_name, model_config) = match resolved_provider { Some(resolved) => resolved, - None => resolve_provider_and_model_from_config(&config, &goose_session).await?, + None => resolve_provider_and_model_from_config(config, &goose_session).await?, }; let ext_state = EnabledExtensionsState::extensions_or_default( Some(&goose_session.extension_data), - &config, + config, ); let provider = match prebuilt_provider { Some(provider) => provider, @@ -1289,7 +1252,7 @@ impl GooseAcpAgent { // ── Phase 2: load extensions (slow, may take seconds) ──────── let phase2: Result<(), String> = async { - let mut extensions = get_enabled_extensions_with_config(&config); + let mut extensions = get_enabled_extensions_with_config(config); extensions.extend(builtins.iter().map(|b| builtin_to_extension_config(b))); let acp_developer = if (client_fs_capabilities.read_text_file @@ -1975,7 +1938,8 @@ impl GooseAcpAgent { // Resolve provider + model from config so we can include the current // model in the response without waiting for the full agent setup. - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; + let resolved = + resolve_provider_and_model_from_config(self.config.as_ref(), &goose_session).await; let initial_usage_update = resolved .as_ref() .ok() @@ -2377,7 +2341,8 @@ impl GooseAcpAgent { let mode_state = build_mode_state(loaded_mode)?; - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; + let resolved = + resolve_provider_and_model_from_config(self.config.as_ref(), &goose_session).await; let initial_usage_update = resolved .as_ref() .ok() @@ -2598,7 +2563,7 @@ impl GooseAcpAgent { model_id: &str, ) -> Result { let internal_id = self.internal_session_id(thread_id).await?; - let config = self.config()?; + let config = self.config.as_ref(); let agent = self.get_session_agent_provider_ready(thread_id).await?; let current_provider = agent .provider() @@ -2606,7 +2571,7 @@ impl GooseAcpAgent { .internal_err_ctx("Failed to get provider")?; let provider_name = current_provider.get_name().to_string(); let extensions = - EnabledExtensionsState::for_session(&self.session_manager, &internal_id, &config).await; + EnabledExtensionsState::for_session(&self.session_manager, &internal_id, config).await; let model_config = crate::model::ModelConfig::new(model_id) .invalid_params_err_ctx("Invalid model config")? .with_canonical_limits(&provider_name); @@ -2732,7 +2697,7 @@ impl GooseAcpAgent { request_params: Option>, ) -> Result<(), sacp::Error> { let internal_id = self.internal_session_id(thread_id).await?; - let config = self.config()?; + let config = self.config.as_ref(); let agent = self.get_session_agent_provider_ready(thread_id).await?; let current_provider = agent .provider() @@ -2770,7 +2735,7 @@ impl GooseAcpAgent { .with_request_params(request_params); let extensions = - EnabledExtensionsState::for_session(&self.session_manager, &internal_id, &config).await; + EnabledExtensionsState::for_session(&self.session_manager, &internal_id, config).await; let new_provider = self .create_provider(&resolved_provider_name, model_config, extensions) .await @@ -2880,7 +2845,8 @@ impl GooseAcpAgent { .insert(new_thread_id.clone(), session); let mode_state = build_mode_state(self.goose_mode)?; - let resolved = resolve_provider_and_model(&self.config_dir, &goose_session).await; + let resolved = + resolve_provider_and_model_from_config(self.config.as_ref(), &goose_session).await; let (model_state, config_options, prebuilt_provider) = self .prepare_session_init_config(&resolved, &mode_state, &goose_session) .await; @@ -3440,7 +3406,7 @@ impl GooseAcpAgent { &self, req: ReadConfigRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); let response = match config.get_param::(&req.key) { Ok(value) => ReadConfigResponse { value }, Err(crate::config::ConfigError::NotFound(_)) => ReadConfigResponse { @@ -3456,7 +3422,7 @@ impl GooseAcpAgent { &self, req: UpsertConfigRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); config.set_param(&req.key, &req.value).internal_err()?; Ok(EmptyResponse {}) } @@ -3466,7 +3432,7 @@ impl GooseAcpAgent { &self, req: RemoveConfigRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); config.delete(&req.key).internal_err()?; Ok(EmptyResponse {}) } @@ -3476,7 +3442,7 @@ impl GooseAcpAgent { &self, req: CheckSecretRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); let exists = config.get_secret::(&req.key).is_ok(); Ok(CheckSecretResponse { exists }) } @@ -3486,7 +3452,7 @@ impl GooseAcpAgent { &self, req: UpsertSecretRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); config.set_secret(&req.key, &req.value).internal_err()?; Config::global().invalidate_secrets_cache(); Ok(EmptyResponse {}) @@ -3497,7 +3463,7 @@ impl GooseAcpAgent { &self, req: RemoveSecretRequest, ) -> Result { - let config = self.config()?; + let config = self.config.as_ref(); config.delete_secret(&req.key).internal_err()?; Config::global().invalidate_secrets_cache(); Ok(EmptyResponse {}) diff --git a/crates/goose/src/acp/server_factory.rs b/crates/goose/src/acp/server_factory.rs index 2e9686a708a9..3d9d2aa6e7de 100644 --- a/crates/goose/src/acp/server_factory.rs +++ b/crates/goose/src/acp/server_factory.rs @@ -21,11 +21,7 @@ impl AcpServer { } pub async fn create_agent(&self) -> Result> { - let config_path = self - .config - .config_dir - .join(crate::config::base::CONFIG_YAML_NAME); - let config = crate::config::Config::new(&config_path, "goose")?; + let config = crate::config::Config::for_config_dir(self.config.config_dir.clone())?; let goose_mode = config .get_goose_mode() diff --git a/crates/goose/src/config/base.rs b/crates/goose/src/config/base.rs index 5266c5dc6aeb..a6f081a3a18f 100644 --- a/crates/goose/src/config/base.rs +++ b/crates/goose/src/config/base.rs @@ -2,7 +2,7 @@ use crate::config::paths::Paths; use crate::config::GooseMode; use fs2::FileExt; use keyring::Entry; -use once_cell::sync::OnceCell; +use once_cell::sync::{Lazy, OnceCell}; use serde::{Deserialize, Serialize}; use serde_json::Value; use serde_yaml::Mapping; @@ -130,6 +130,29 @@ pub struct Config { secrets_cache: Arc>>>, } +#[derive(Clone)] +pub enum ConfigHandle { + Global, + Cached(Arc), +} + +impl AsRef for ConfigHandle { + fn as_ref(&self) -> &Config { + match self { + Self::Global => Config::global(), + Self::Cached(config) => config.as_ref(), + } + } +} + +impl std::ops::Deref for ConfigHandle { + type Target = Config; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + enum SecretStorage { Keyring { service: String }, File { path: PathBuf }, @@ -137,6 +160,8 @@ enum SecretStorage { // Global instance static GLOBAL_CONFIG: OnceCell = OnceCell::new(); +static CONFIG_CACHE: Lazy>>> = + Lazy::new(|| Mutex::new(HashMap::new())); fn system_config_path() -> PathBuf { #[cfg(unix)] @@ -163,43 +188,7 @@ fn bundled_defaults_path() -> Option { impl Default for Config { fn default() -> Self { - let config_dir = Paths::config_dir(); - let user_config_path = config_dir.join(CONFIG_YAML_NAME); - - let mut config_paths = vec![system_config_path()]; - if let Some(defaults) = bundled_defaults_path() { - config_paths.insert(0, defaults); - } - config_paths.push(user_config_path.clone()); - - let no_secrets_config = Self { - config_paths: config_paths.clone(), - secrets: SecretStorage::File { - path: Default::default(), - }, - guard: Mutex::new(()), - secrets_cache: Arc::new(Mutex::new(None)), - }; - - let secrets = if env::var("GOOSE_DISABLE_KEYRING").is_ok() - || no_secrets_config - .get_param::("GOOSE_DISABLE_KEYRING") - .is_ok_and(|v| keyring_disabled_value(&v)) - { - SecretStorage::File { - path: config_dir.join("secrets.yaml"), - } - } else { - SecretStorage::Keyring { - service: KEYRING_SERVICE.to_string(), - } - }; - Self { - config_paths, - secrets, - guard: Mutex::new(()), - secrets_cache: Arc::new(Mutex::new(None)), - } + Self::with_config_dir(Paths::config_dir()) } } @@ -343,6 +332,45 @@ fn keyring_disabled_in_config(config_path: &Path) -> bool { .unwrap_or(false) } +fn normalize_writable_config_path(path: &Path) -> Result { + if path.exists() { + return Ok(path.canonicalize()?); + } + + if let Some(parent) = path + .parent() + .filter(|parent| !parent.as_os_str().is_empty()) + { + if parent.exists() { + let file_name = path + .file_name() + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from(CONFIG_YAML_NAME)); + return Ok(parent.canonicalize()?.join(file_name)); + } + } + + let absolute = if path.is_absolute() { + path.to_path_buf() + } else { + env::current_dir()?.join(path) + }; + + let mut normalized = PathBuf::new(); + for component in absolute.components() { + match component { + std::path::Component::CurDir => {} + std::path::Component::ParentDir => { + normalized.pop(); + } + std::path::Component::Normal(_) + | std::path::Component::RootDir + | std::path::Component::Prefix(_) => normalized.push(component.as_os_str()), + } + } + Ok(normalized) +} + impl Config { /// Get the global configuration instance. /// @@ -352,6 +380,69 @@ impl Config { GLOBAL_CONFIG.get_or_init(Config::default) } + pub fn for_config_dir(config_dir: PathBuf) -> Result { + let config_path = config_dir.join(CONFIG_YAML_NAME); + let cache_key = normalize_writable_config_path(&config_path)?; + let default_key = + normalize_writable_config_path(&Paths::config_dir().join(CONFIG_YAML_NAME))?; + + if cache_key == default_key { + return Ok(ConfigHandle::Global); + } + + let mut cache = CONFIG_CACHE.lock().unwrap(); + if let Some(config) = cache.get(&cache_key) { + return Ok(ConfigHandle::Cached(Arc::clone(config))); + } + + let config = Arc::new(Self::new(config_path, KEYRING_SERVICE)?); + cache.insert(cache_key, Arc::clone(&config)); + Ok(ConfigHandle::Cached(config)) + } + + fn config_paths_for_dir(config_dir: &Path) -> Vec { + let mut config_paths = vec![system_config_path()]; + if let Some(defaults) = bundled_defaults_path() { + config_paths.insert(0, defaults); + } + config_paths.push(config_dir.join(CONFIG_YAML_NAME)); + config_paths + } + + fn with_config_dir(config_dir: PathBuf) -> Self { + let config_paths = Self::config_paths_for_dir(&config_dir); + + let no_secrets_config = Self { + config_paths: config_paths.clone(), + secrets: SecretStorage::File { + path: Default::default(), + }, + guard: Mutex::new(()), + secrets_cache: Arc::new(Mutex::new(None)), + }; + + let secrets = if env::var("GOOSE_DISABLE_KEYRING").is_ok() + || no_secrets_config + .get_param::("GOOSE_DISABLE_KEYRING") + .is_ok_and(|v| keyring_disabled_value(&v)) + { + SecretStorage::File { + path: config_dir.join("secrets.yaml"), + } + } else { + SecretStorage::Keyring { + service: KEYRING_SERVICE.to_string(), + } + }; + + Self { + config_paths, + secrets, + guard: Mutex::new(()), + secrets_cache: Arc::new(Mutex::new(None)), + } + } + /// Create a new configuration instance with custom paths /// /// This is primarily useful for testing or for applications that need diff --git a/crates/goose/src/config/mod.rs b/crates/goose/src/config/mod.rs index cd731c2ae3e2..667f304bf232 100644 --- a/crates/goose/src/config/mod.rs +++ b/crates/goose/src/config/mod.rs @@ -12,7 +12,7 @@ pub mod signup_openrouter; pub mod signup_tetrate; pub use crate::agents::ExtensionConfig; -pub use base::{merge_config_values, Config, ConfigError}; +pub use base::{merge_config_values, Config, ConfigError, ConfigHandle}; pub use declarative_providers::DeclarativeProviderConfig; pub use experiments::ExperimentManager; pub use extensions::{ diff --git a/crates/goose/tests/acp_common_tests/mod.rs b/crates/goose/tests/acp_common_tests/mod.rs index 34272c69f6ab..be26add78f3c 100644 --- a/crates/goose/tests/acp_common_tests/mod.rs +++ b/crates/goose/tests/acp_common_tests/mod.rs @@ -948,18 +948,19 @@ pub async fn run_permission_persistence() { }; let mut conn = C::new(config, openai).await; + let permission_path = conn.permission_config_path(); let SessionData { mut session, .. } = conn.new_session().await.unwrap(); expected_session_id.set(&session.session_id().0); for (decision, expected_status, expected_yaml) in cases { conn.reset_openai(); conn.reset_permissions(); - let _ = fs::remove_file(temp_dir.path().join("permission.yaml")); + let _ = fs::remove_file(&permission_path); let output = session.prompt(prompt, decision).await.unwrap(); assert_eq!(output.tool_status.unwrap(), expected_status); assert_eq!( - fs::read_to_string(temp_dir.path().join("permission.yaml")).unwrap_or_default(), + fs::read_to_string(&permission_path).unwrap_or_default(), expected_yaml, ); } diff --git a/crates/goose/tests/acp_custom_requests_test.rs b/crates/goose/tests/acp_custom_requests_test.rs index a2a5951647ff..8c500ff6fd1d 100644 --- a/crates/goose/tests/acp_custom_requests_test.rs +++ b/crates/goose/tests/acp_custom_requests_test.rs @@ -15,6 +15,7 @@ use goose_test_support::{EnforceSessionId, IgnoreSessionId}; use std::sync::{Arc, Mutex}; use common_tests::fixtures::OpenAiFixture; +use goose::config::Config; struct MockProvider { name: String, @@ -142,8 +143,16 @@ fn test_custom_provider_inventory_includes_metadata() { #[test] fn test_custom_config_crud() { run_test(async { + let temp_dir = tempfile::tempdir().unwrap(); let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; - let conn = AcpServerConnection::new(TestConnectionConfig::default(), openai).await; + let conn = AcpServerConnection::new( + TestConnectionConfig { + data_root: temp_dir.path().to_path_buf(), + ..Default::default() + }, + openai, + ) + .await; send_custom( conn.cx(), @@ -156,6 +165,12 @@ fn test_custom_config_crud() { .await .expect("config upsert should succeed"); + let cached_config = Config::for_config_dir(temp_dir.path().to_path_buf()).unwrap(); + assert_eq!( + cached_config.get_param::("GOOSE_PROVIDER").unwrap(), + "anthropic" + ); + let response = send_custom( conn.cx(), "_goose/config/read", @@ -190,6 +205,44 @@ fn test_custom_config_crud() { }); } +#[test] +fn test_concurrent_custom_config_upserts_preserve_all_keys() { + run_test(async { + let temp_dir = tempfile::tempdir().unwrap(); + let openai = OpenAiFixture::new(vec![], Arc::new(EnforceSessionId::default())).await; + let conn = AcpServerConnection::new( + TestConnectionConfig { + data_root: temp_dir.path().to_path_buf(), + ..Default::default() + }, + openai, + ) + .await; + + let requests = (0..24).map(|i| { + let key = format!("CONCURRENT_CONFIG_KEY_{i}"); + send_custom( + conn.cx(), + "_goose/config/upsert", + serde_json::json!({ + "key": key, + "value": i, + }), + ) + }); + + for result in futures::future::join_all(requests).await { + result.expect("config upsert should succeed"); + } + + let cached_config = Config::for_config_dir(temp_dir.path().to_path_buf()).unwrap(); + for i in 0..24 { + let key = format!("CONCURRENT_CONFIG_KEY_{i}"); + assert_eq!(cached_config.get_param::(&key).unwrap(), i); + } + }); +} + #[test] fn test_provider_switching_updates_session_state() { run_test(async { diff --git a/crates/goose/tests/acp_fixtures/mod.rs b/crates/goose/tests/acp_fixtures/mod.rs index 082ae8166459..16ace8e8a6c5 100644 --- a/crates/goose/tests/acp_fixtures/mod.rs +++ b/crates/goose/tests/acp_fixtures/mod.rs @@ -23,7 +23,7 @@ use sacp::schema::{ }; use std::collections::VecDeque; use std::future::Future; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use tokio::task::JoinHandle; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; @@ -130,6 +130,29 @@ pub type DuplexTransport = sacp::ByteStreams< tokio_util::compat::Compat, >; +static ACP_TEST_LOCK: Mutex<()> = Mutex::new(()); + +fn prepare_acp_config_dir(data_root: &Path, config_dir: &Path, current_model: &str) { + fs::create_dir_all(config_dir).unwrap(); + + let source_config = data_root.join(goose::config::base::CONFIG_YAML_NAME); + let target_config = config_dir.join(goose::config::base::CONFIG_YAML_NAME); + + if source_config.exists() { + if source_config != target_config { + fs::copy(source_config, target_config).unwrap(); + } + } else { + fs::write( + target_config, + format!("GOOSE_MODEL: {current_model}\nGOOSE_PROVIDER: openai\n"), + ) + .unwrap(); + } + + let _ = fs::remove_file(config_dir.join("permission.yaml")); +} + /// Wires up duplex streams, spawns `serve` for the given agent, and returns /// a ready-to-use sacp transport plus the server handle. #[allow(dead_code)] @@ -161,14 +184,8 @@ pub async fn spawn_acp_server_in_process( fs::create_dir_all(data_root).unwrap(); // TODO: Paths::in_state_dir is global, ignoring per-test data_root fs::create_dir_all(Paths::in_state_dir("logs")).unwrap(); - let config_path = data_root.join(goose::config::base::CONFIG_YAML_NAME); - if !config_path.exists() { - fs::write( - &config_path, - format!("GOOSE_MODEL: {current_model}\nGOOSE_PROVIDER: openai\n"), - ) - .unwrap(); - } + let config_dir = data_root.to_path_buf(); + prepare_acp_config_dir(data_root, &config_dir, current_model); let provider_factory = provider_factory.unwrap_or_else(|| { let base_url = openai_base_url.to_string(); Arc::new(move |_provider_name, model_config, _extensions| { @@ -188,7 +205,7 @@ pub async fn spawn_acp_server_in_process( provider_factory, builtins.to_vec(), data_root.to_path_buf(), - data_root.to_path_buf(), + config_dir, goose_mode, true, GoosePlatform::GooseCli, @@ -526,6 +543,7 @@ pub trait Connection: Sized { value: &str, ) -> anyhow::Result<()>; fn data_root(&self) -> std::path::PathBuf; + fn permission_config_path(&self) -> std::path::PathBuf; fn reset_openai(&self); fn reset_permissions(&self); } @@ -556,6 +574,8 @@ where { register_builtin_extensions(goose_mcp::BUILTIN_EXTENSIONS.clone()); + let _guard = ACP_TEST_LOCK.lock().unwrap_or_else(|err| err.into_inner()); + let handle = std::thread::Builder::new() .name("acp-test".to_string()) .stack_size(8 * 1024 * 1024) diff --git a/crates/goose/tests/acp_fixtures/provider.rs b/crates/goose/tests/acp_fixtures/provider.rs index 6f5b658aae78..113bccba663e 100644 --- a/crates/goose/tests/acp_fixtures/provider.rs +++ b/crates/goose/tests/acp_fixtures/provider.rs @@ -287,6 +287,10 @@ impl Connection for AcpProviderConnection { self.data_root.clone() } + fn permission_config_path(&self) -> std::path::PathBuf { + self.permission_manager.get_config_path().to_path_buf() + } + async fn set_mode(&self, _session_id: &str, _mode_id: &str) -> anyhow::Result<()> { Err(anyhow::anyhow!("not implemented for AcpProviderConnection")) } diff --git a/crates/goose/tests/acp_fixtures/server.rs b/crates/goose/tests/acp_fixtures/server.rs index f394a03d1c5b..570291ce7436 100644 --- a/crates/goose/tests/acp_fixtures/server.rs +++ b/crates/goose/tests/acp_fixtures/server.rs @@ -439,6 +439,10 @@ impl Connection for AcpServerConnection { self.data_root.clone() } + fn permission_config_path(&self) -> std::path::PathBuf { + self.permission_manager.get_config_path().to_path_buf() + } + fn reset_openai(&self) { self._openai.reset(); } diff --git a/crates/goose/tests/acp_global_config_test.rs b/crates/goose/tests/acp_global_config_test.rs new file mode 100644 index 000000000000..f458d3a9fc1c --- /dev/null +++ b/crates/goose/tests/acp_global_config_test.rs @@ -0,0 +1,127 @@ +use goose::config::base::CONFIG_YAML_NAME; +use goose::config::paths::Paths; +use goose::config::{Config, ConfigHandle}; +use std::sync::Arc; + +#[test] +fn default_config_dir_returns_global_handle() { + let handle = Config::for_config_dir(Paths::config_dir()).unwrap(); + + assert!(matches!(handle, ConfigHandle::Global)); +} + +#[test] +fn same_custom_config_dir_returns_same_cached_config() { + let dir = tempfile::tempdir().unwrap(); + + let first = Config::for_config_dir(dir.path().to_path_buf()).unwrap(); + let second = Config::for_config_dir(dir.path().to_path_buf()).unwrap(); + + match (first, second) { + (ConfigHandle::Cached(first), ConfigHandle::Cached(second)) => { + assert!(Arc::ptr_eq(&first, &second)); + } + _ => panic!("expected cached config handles"), + } +} + +#[test] +fn equivalent_missing_custom_config_dirs_share_cached_config() { + let root = tempfile::tempdir().unwrap(); + let direct = root.path().join("missing"); + let with_parent = root.path().join("missing").join("..").join("missing"); + + let first = Config::for_config_dir(direct).unwrap(); + let second = Config::for_config_dir(with_parent).unwrap(); + + match (first, second) { + (ConfigHandle::Cached(first), ConfigHandle::Cached(second)) => { + assert!(Arc::ptr_eq(&first, &second)); + } + _ => panic!("expected cached config handles"), + } +} + +#[test] +fn custom_config_dirs_can_coexist() { + let first_dir = tempfile::tempdir().unwrap(); + let second_dir = tempfile::tempdir().unwrap(); + + let first = Config::for_config_dir(first_dir.path().to_path_buf()).unwrap(); + let second = Config::for_config_dir(second_dir.path().to_path_buf()).unwrap(); + + first.set_param("CUSTOM_KEY", "first").unwrap(); + second.set_param("CUSTOM_KEY", "second").unwrap(); + + assert_eq!( + first.get_param::("CUSTOM_KEY").unwrap(), + "first".to_string() + ); + assert_eq!( + second.get_param::("CUSTOM_KEY").unwrap(), + "second".to_string() + ); + assert_eq!( + first.path(), + first_dir + .path() + .join(CONFIG_YAML_NAME) + .display() + .to_string() + ); + assert_eq!( + second.path(), + second_dir + .path() + .join(CONFIG_YAML_NAME) + .display() + .to_string() + ); +} + +#[test] +fn cached_handles_for_same_custom_path_share_written_state() { + let dir = tempfile::tempdir().unwrap(); + + let first = Config::for_config_dir(dir.path().to_path_buf()).unwrap(); + let second = Config::for_config_dir(dir.path().to_path_buf()).unwrap(); + + first.set_param("FIRST_KEY", "first").unwrap(); + second.set_param("SECOND_KEY", "second").unwrap(); + + assert_eq!( + first.get_param::("SECOND_KEY").unwrap(), + "second".to_string() + ); + assert_eq!( + second.get_param::("FIRST_KEY").unwrap(), + "first".to_string() + ); +} + +#[test] +fn concurrent_cached_config_writes_to_same_custom_path_do_not_lose_updates() { + let dir = tempfile::tempdir().unwrap(); + let dir_path = dir.path().to_path_buf(); + + let handles = (0..24) + .map(|i| { + let dir_path = dir_path.clone(); + std::thread::spawn(move || { + let config = Config::for_config_dir(dir_path).unwrap(); + let key = format!("CONCURRENT_KEY_{i}"); + config.set_param(&key, i).unwrap(); + }) + }) + .collect::>(); + + for handle in handles { + handle.join().unwrap(); + } + + let config = Config::for_config_dir(dir.path().to_path_buf()).unwrap(); + for i in 0..24 { + let key = format!("CONCURRENT_KEY_{i}"); + assert_eq!(config.get_param::(&key).unwrap(), i); + } +}