diff --git a/src/config.rs b/src/config.rs index 24487ce3..5d727316 100644 --- a/src/config.rs +++ b/src/config.rs @@ -259,16 +259,25 @@ impl Config { ); } - // Try Backend.AI environment first + // Check for Backend.AI environment first if let Some(backendai_cluster) = Self::from_backendai_env() { tracing::debug!("Using Backend.AI cluster configuration from environment"); let mut config = Self::default(); config .clusters - .insert("backendai".to_string(), backendai_cluster); + .insert("bai_auto".to_string(), backendai_cluster); return Ok(config); } + // Load configuration from standard locations + Self::load_from_standard_locations().await.or_else(|_| { + tracing::debug!("No config file found, using default empty configuration"); + Ok(Self::default()) + }) + } + + /// Load configuration from standard locations (helper method) + async fn load_from_standard_locations() -> Result { // Try current directory config.yaml let current_dir_config = PathBuf::from("config.yaml"); if current_dir_config.exists() { @@ -308,9 +317,8 @@ impl Config { } } - // Finally, try the default path (will create empty config if needed) - tracing::debug!("No config file found, using default empty configuration"); - Ok(Self::default()) + // No config file found + anyhow::bail!("No configuration file found") } pub fn get_cluster(&self, name: &str) -> Option<&Cluster> { @@ -585,7 +593,7 @@ pub struct InteractiveConfigUpdate { pub colors: Option>, } -fn expand_tilde(path: &Path) -> PathBuf { +pub fn expand_tilde(path: &Path) -> PathBuf { if let Some(path_str) = path.to_str() { if path_str.starts_with("~/") { if let Ok(home) = std::env::var("HOME") { diff --git a/src/main.rs b/src/main.rs index a618d401..956358bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ use anyhow::Result; use clap::Parser; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::time::Duration; use bssh::{ @@ -98,7 +98,7 @@ async fn main() -> Result<()> { } // Determine nodes to execute on - let nodes = resolve_nodes(&cli, &config).await?; + let (nodes, actual_cluster_name) = resolve_nodes(&cli, &config).await?; if nodes.is_empty() { anyhow::bail!( @@ -124,10 +124,19 @@ async fn main() -> Result<()> { // Handle remaining commands match cli.command { Some(Commands::Ping) => { + // Determine SSH key path: CLI argument takes precedence over config + let key_path = if let Some(identity) = &cli.identity { + Some(identity.clone()) + } else { + config + .get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key))) + }; + ping_nodes( nodes, cli.parallel, - cli.identity.as_deref(), + key_path.as_deref(), strict_mode, cli.use_agent, cli.password, @@ -139,10 +148,19 @@ async fn main() -> Result<()> { destination, recursive, }) => { + // Determine SSH key path: CLI argument takes precedence over config + let key_path = if let Some(identity) = &cli.identity { + Some(identity.clone()) + } else { + config + .get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key))) + }; + let params = FileTransferParams { nodes, max_parallel: cli.parallel, - key_path: cli.identity.as_deref(), + key_path: key_path.as_deref(), strict_mode, use_agent: cli.use_agent, use_password: cli.password, @@ -155,10 +173,19 @@ async fn main() -> Result<()> { destination, recursive, }) => { + // Determine SSH key path: CLI argument takes precedence over config + let key_path = if let Some(identity) = &cli.identity { + Some(identity.clone()) + } else { + config + .get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key))) + }; + let params = FileTransferParams { nodes, max_parallel: cli.parallel, - key_path: cli.identity.as_deref(), + key_path: key_path.as_deref(), strict_mode, use_agent: cli.use_agent, use_password: cli.password, @@ -238,14 +265,23 @@ async fn main() -> Result<()> { let timeout = if cli.timeout > 0 { Some(cli.timeout) } else { - config.get_timeout(cli.cluster.as_deref()) + config.get_timeout(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + }; + + // Determine SSH key path: CLI argument takes precedence over config + let key_path = if let Some(identity) = &cli.identity { + Some(identity.clone()) + } else { + config + .get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key))) }; let params = ExecuteCommandParams { nodes, command: &command, max_parallel: cli.parallel, - key_path: cli.identity.as_deref(), + key_path: key_path.as_deref(), verbose: cli.verbose > 0, strict_mode, use_agent: cli.use_agent, @@ -258,8 +294,9 @@ async fn main() -> Result<()> { } } -async fn resolve_nodes(cli: &Cli, config: &Config) -> Result> { +async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<(Vec, Option)> { let mut nodes = Vec::new(); + let mut cluster_name = None; if let Some(hosts) = &cli.hosts { // Parse hosts from CLI @@ -270,16 +307,18 @@ async fn resolve_nodes(cli: &Cli, config: &Config) -> Result> { nodes.push(node); } } - } else if let Some(cluster_name) = &cli.cluster { + } else if let Some(cli_cluster_name) = &cli.cluster { // Get nodes from cluster configuration - nodes = config.resolve_nodes(cluster_name)?; + nodes = config.resolve_nodes(cli_cluster_name)?; + cluster_name = Some(cli_cluster_name.clone()); } else { // Check if Backend.AI environment is detected (automatic cluster) - if config.clusters.contains_key("backendai") { + if config.clusters.contains_key("bai_auto") { // Automatically use Backend.AI cluster when no explicit cluster is specified - nodes = config.resolve_nodes("backendai")?; + nodes = config.resolve_nodes("bai_auto")?; + cluster_name = Some("bai_auto".to_string()); } } - Ok(nodes) + Ok((nodes, cluster_name)) } diff --git a/tests/backendai_env_test.rs b/tests/backendai_env_test.rs index ca573f54..1be355ac 100644 --- a/tests/backendai_env_test.rs +++ b/tests/backendai_env_test.rs @@ -45,19 +45,19 @@ async fn test_backendai_env_auto_detection() { .await .expect("Config should load with Backend.AI env"); - // Check that backendai cluster was created - assert!(config.clusters.contains_key("backendai")); + // Check that bai_auto cluster was created + assert!(config.clusters.contains_key("bai_auto")); - // Get the backendai cluster - let cluster = config.clusters.get("backendai").unwrap(); + // Get the bai_auto cluster + let cluster = config.clusters.get("bai_auto").unwrap(); // Verify nodes were parsed correctly assert_eq!(cluster.nodes.len(), 3); - // Resolve nodes for the backendai cluster + // Resolve nodes for the bai_auto cluster let nodes = config - .resolve_nodes("backendai") - .expect("Should resolve backendai nodes"); + .resolve_nodes("bai_auto") + .expect("Should resolve bai_auto nodes"); assert_eq!(nodes.len(), 3); // Check node details @@ -114,11 +114,11 @@ async fn test_backendai_env_with_single_host() { .await .expect("Config should load"); - // Verify backendai cluster exists - assert!(config.clusters.contains_key("backendai")); + // Verify bai_auto cluster exists + assert!(config.clusters.contains_key("bai_auto")); let nodes = config - .resolve_nodes("backendai") + .resolve_nodes("bai_auto") .expect("Should resolve nodes"); assert_eq!(nodes.len(), 1); assert_eq!(nodes[0].host, "single-node.ai");