diff --git a/examples/interactive_demo.rs b/examples/interactive_demo.rs index c8cd797a..611bc5f1 100644 --- a/examples/interactive_demo.rs +++ b/examples/interactive_demo.rs @@ -17,6 +17,7 @@ use bssh::commands::interactive::InteractiveCommand; use bssh::config::{Config, InteractiveConfig}; use bssh::node::Node; +use bssh::ssh::known_hosts::StrictHostKeyChecking; use std::path::PathBuf; #[tokio::main] @@ -48,6 +49,10 @@ async fn main() -> anyhow::Result<()> { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; println!("Starting interactive session..."); diff --git a/src/commands/interactive.rs b/src/commands/interactive.rs index 67751cbe..4eb55b88 100644 --- a/src/commands/interactive.rs +++ b/src/commands/interactive.rs @@ -22,7 +22,7 @@ use rustyline::config::Configurer; use rustyline::error::ReadlineError; use rustyline::DefaultEditor; use std::io::{self, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use tokio::sync::mpsc; @@ -52,6 +52,11 @@ pub struct InteractiveCommand { pub config: Config, pub interactive_config: InteractiveConfig, pub cluster_name: Option, + // Authentication parameters (consistent with exec mode) + pub key_path: Option, + pub use_agent: bool, + pub use_password: bool, + pub strict_mode: StrictHostKeyChecking, } /// Result of an interactive session @@ -194,11 +199,11 @@ impl InteractiveCommand { /// Connect to a single node and establish an interactive shell async fn connect_to_node(&self, node: Node) -> Result { - // Determine authentication method + // Determine authentication method using the same logic as exec mode let auth_method = self.determine_auth_method(&node)?; - // Set up host key checking - let check_method = get_check_method(StrictHostKeyChecking::AcceptNew); + // Set up host key checking using the configured strict mode + let check_method = get_check_method(self.strict_mode); // Connect with timeout let addr = (node.host.as_str(), node.port); @@ -252,29 +257,122 @@ impl InteractiveCommand { }) } - /// Determine authentication method based on node and config + /// Determine authentication method based on node and config (same logic as exec mode) fn determine_auth_method(&self, node: &Node) -> Result { - // Check if SSH agent is available - if std::env::var("SSH_AUTH_SOCK").is_ok() { + // If password authentication is explicitly requested + if self.use_password { + tracing::debug!("Using password authentication"); + let password = rpassword::prompt_password(format!( + "Enter password for {}@{}: ", + node.username, node.host + )) + .with_context(|| "Failed to read password")?; + return Ok(AuthMethod::with_password(&password)); + } + + // If SSH agent is explicitly requested, try that first + if self.use_agent { + #[cfg(not(target_os = "windows"))] + { + // Check if SSH_AUTH_SOCK is available + if std::env::var("SSH_AUTH_SOCK").is_ok() { + tracing::debug!("Using SSH agent for authentication"); + return Ok(AuthMethod::Agent); + } + tracing::warn!( + "SSH agent requested but SSH_AUTH_SOCK environment variable not set" + ); + // Fall through to key file authentication + } + #[cfg(target_os = "windows")] + { + anyhow::bail!("SSH agent authentication is not supported on Windows"); + } + } + + // Try key file authentication + if let Some(ref key_path) = self.key_path { + tracing::debug!("Authenticating with key: {:?}", key_path); + + // Check if the key is encrypted by attempting to read it + let key_contents = std::fs::read_to_string(key_path) + .with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?; + + let passphrase = if key_contents.contains("ENCRYPTED") + || key_contents.contains("Proc-Type: 4,ENCRYPTED") + { + tracing::debug!("Detected encrypted SSH key, prompting for passphrase"); + let pass = + rpassword::prompt_password(format!("Enter passphrase for key {key_path:?}: ")) + .with_context(|| "Failed to read passphrase")?; + Some(pass) + } else { + None + }; + + return Ok(AuthMethod::with_key_file(key_path, passphrase.as_deref())); + } + + // If no explicit key path, try SSH agent if available (auto-detect) + #[cfg(not(target_os = "windows"))] + if !self.use_agent && std::env::var("SSH_AUTH_SOCK").is_ok() { + tracing::debug!("SSH agent detected, attempting agent authentication"); return Ok(AuthMethod::Agent); } - // Try to find SSH key - let ssh_key_paths = vec![ - dirs::home_dir().map(|h| h.join(".ssh/id_rsa")), - dirs::home_dir().map(|h| h.join(".ssh/id_ed25519")), + // Fallback to default key locations + let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string()); + let home_path = Path::new(&home).join(".ssh"); + + // Try common key files in order of preference + let default_keys = [ + home_path.join("id_ed25519"), + home_path.join("id_rsa"), + home_path.join("id_ecdsa"), + home_path.join("id_dsa"), ]; - for key_path in ssh_key_paths.into_iter().flatten() { - if key_path.exists() { - return Ok(AuthMethod::with_key_file(key_path, None)); + for default_key in &default_keys { + if default_key.exists() { + tracing::debug!("Using default key: {:?}", default_key); + + // Check if the key is encrypted + let key_contents = std::fs::read_to_string(default_key) + .with_context(|| format!("Failed to read SSH key file: {default_key:?}"))?; + + let passphrase = if key_contents.contains("ENCRYPTED") + || key_contents.contains("Proc-Type: 4,ENCRYPTED") + { + tracing::debug!("Detected encrypted SSH key, prompting for passphrase"); + let pass = rpassword::prompt_password(format!( + "Enter passphrase for key {default_key:?}: " + )) + .with_context(|| "Failed to read passphrase")?; + Some(pass) + } else { + None + }; + + return Ok(AuthMethod::with_key_file( + default_key, + passphrase.as_deref(), + )); } } - // If no key found, prompt for password - let password = - rpassword::prompt_password(format!("Password for {}@{}: ", node.username, node.host))?; - Ok(AuthMethod::with_password(&password)) + anyhow::bail!( + "SSH authentication failed: No authentication method available.\n\ + Tried:\n\ + - SSH agent: {}\n\ + - SSH keys: {:?}\n\ + Please ensure you have a valid SSH key or SSH agent running.", + if std::env::var("SSH_AUTH_SOCK").is_ok() { + "Available but no identities" + } else { + "Not available (SSH_AUTH_SOCK not set)" + }, + default_keys + ) } /// Run interactive mode with a single node @@ -832,6 +930,10 @@ mod tests { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; let path = PathBuf::from("~/test/file.txt"); @@ -856,6 +958,10 @@ mod tests { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; let node = Node::new(String::from("example.com"), 22, String::from("alice")); diff --git a/src/main.rs b/src/main.rs index 956358bd..3ffdb626 100644 --- a/src/main.rs +++ b/src/main.rs @@ -241,6 +241,15 @@ async fn main() -> Result<()> { let merged_work_dir = work_dir.or(interactive_config.work_dir.clone()); + // Determine SSH key path: CLI argument takes precedence over config + let key_path = if let Some(identity) = &cli.identity { + Some(identity.clone()) + } else { + config + .get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref())) + .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key))) + }; + let interactive_cmd = InteractiveCommand { single_node: merged_mode.0, multiplex: merged_mode.1, @@ -251,6 +260,10 @@ async fn main() -> Result<()> { config: config.clone(), interactive_config, cluster_name: cluster_name.map(String::from), + key_path, + use_agent: cli.use_agent, + use_password: cli.password, + strict_mode, }; let result = interactive_cmd.execute().await?; println!("\nInteractive session ended."); diff --git a/tests/interactive_integration_test.rs b/tests/interactive_integration_test.rs index f0490154..d16213e7 100644 --- a/tests/interactive_integration_test.rs +++ b/tests/interactive_integration_test.rs @@ -17,6 +17,7 @@ use bssh::commands::interactive::InteractiveCommand; use bssh::config::{Config, InteractiveConfig}; use bssh::node::Node; +use bssh::ssh::known_hosts::StrictHostKeyChecking; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; @@ -41,6 +42,10 @@ fn test_interactive_command_builder() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert!(!cmd.single_node); @@ -66,6 +71,10 @@ fn test_history_file_handling() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert_eq!(cmd.history_file, history_path); @@ -154,6 +163,10 @@ async fn test_interactive_with_unreachable_nodes() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; // This should fail to connect @@ -179,6 +192,10 @@ async fn test_interactive_with_no_nodes() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; let result = cmd.execute().await; @@ -214,6 +231,10 @@ fn test_mode_configuration() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert!(single_cmd.single_node); @@ -230,6 +251,10 @@ fn test_mode_configuration() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert!(!multi_cmd.single_node); @@ -249,6 +274,10 @@ fn test_working_directory_config() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert_eq!(cmd_with_dir.work_dir, Some("/var/www".to_string())); @@ -263,6 +292,10 @@ fn test_working_directory_config() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert_eq!(cmd_without_dir.work_dir, None); @@ -289,6 +322,10 @@ fn test_prompt_format() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert_eq!(cmd.prompt_format, format); diff --git a/tests/interactive_test.rs b/tests/interactive_test.rs index 9d456d0b..7465d86d 100644 --- a/tests/interactive_test.rs +++ b/tests/interactive_test.rs @@ -15,6 +15,7 @@ use bssh::commands::interactive::InteractiveCommand; use bssh::config::{Config, InteractiveConfig}; use bssh::node::Node; +use bssh::ssh::known_hosts::StrictHostKeyChecking; use std::path::PathBuf; #[tokio::test] @@ -29,6 +30,10 @@ async fn test_interactive_command_creation() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; assert!(!cmd.single_node); @@ -48,6 +53,10 @@ async fn test_interactive_with_no_nodes() { config: Config::default(), interactive_config: InteractiveConfig::default(), cluster_name: None, + key_path: None, + use_agent: false, + use_password: false, + strict_mode: StrictHostKeyChecking::AcceptNew, }; let result = cmd.execute().await;