Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/interactive_demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use bssh::commands::interactive::InteractiveCommand;
use bssh::config::{Config, InteractiveConfig};
use bssh::node::Node;
use bssh::ssh::known_hosts::StrictHostKeyChecking;
use std::path::PathBuf;

#[tokio::main]
Expand Down Expand Up @@ -48,6 +49,10 @@ async fn main() -> anyhow::Result<()> {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

println!("Starting interactive session...");
Expand Down
142 changes: 124 additions & 18 deletions src/commands/interactive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use rustyline::config::Configurer;
use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use std::io::{self, Write};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::mpsc;
Expand Down Expand Up @@ -52,6 +52,11 @@ pub struct InteractiveCommand {
pub config: Config,
pub interactive_config: InteractiveConfig,
pub cluster_name: Option<String>,
// Authentication parameters (consistent with exec mode)
pub key_path: Option<PathBuf>,
pub use_agent: bool,
pub use_password: bool,
pub strict_mode: StrictHostKeyChecking,
}

/// Result of an interactive session
Expand Down Expand Up @@ -194,11 +199,11 @@ impl InteractiveCommand {

/// Connect to a single node and establish an interactive shell
async fn connect_to_node(&self, node: Node) -> Result<NodeSession> {
// Determine authentication method
// Determine authentication method using the same logic as exec mode
let auth_method = self.determine_auth_method(&node)?;

// Set up host key checking
let check_method = get_check_method(StrictHostKeyChecking::AcceptNew);
// Set up host key checking using the configured strict mode
let check_method = get_check_method(self.strict_mode);

// Connect with timeout
let addr = (node.host.as_str(), node.port);
Expand Down Expand Up @@ -252,29 +257,122 @@ impl InteractiveCommand {
})
}

/// Determine authentication method based on node and config
/// Determine authentication method based on node and config (same logic as exec mode)
fn determine_auth_method(&self, node: &Node) -> Result<AuthMethod> {
// Check if SSH agent is available
if std::env::var("SSH_AUTH_SOCK").is_ok() {
// If password authentication is explicitly requested
if self.use_password {
tracing::debug!("Using password authentication");
let password = rpassword::prompt_password(format!(
"Enter password for {}@{}: ",
node.username, node.host
))
.with_context(|| "Failed to read password")?;
return Ok(AuthMethod::with_password(&password));
}

// If SSH agent is explicitly requested, try that first
if self.use_agent {
#[cfg(not(target_os = "windows"))]
{
// Check if SSH_AUTH_SOCK is available
if std::env::var("SSH_AUTH_SOCK").is_ok() {
tracing::debug!("Using SSH agent for authentication");
return Ok(AuthMethod::Agent);
}
tracing::warn!(
"SSH agent requested but SSH_AUTH_SOCK environment variable not set"
);
// Fall through to key file authentication
}
#[cfg(target_os = "windows")]
{
anyhow::bail!("SSH agent authentication is not supported on Windows");
}
}

// Try key file authentication
if let Some(ref key_path) = self.key_path {
tracing::debug!("Authenticating with key: {:?}", key_path);

// Check if the key is encrypted by attempting to read it
let key_contents = std::fs::read_to_string(key_path)
.with_context(|| format!("Failed to read SSH key file: {key_path:?}"))?;

let passphrase = if key_contents.contains("ENCRYPTED")
|| key_contents.contains("Proc-Type: 4,ENCRYPTED")
{
tracing::debug!("Detected encrypted SSH key, prompting for passphrase");
let pass =
rpassword::prompt_password(format!("Enter passphrase for key {key_path:?}: "))
.with_context(|| "Failed to read passphrase")?;
Some(pass)
} else {
None
};

return Ok(AuthMethod::with_key_file(key_path, passphrase.as_deref()));
}

// If no explicit key path, try SSH agent if available (auto-detect)
#[cfg(not(target_os = "windows"))]
if !self.use_agent && std::env::var("SSH_AUTH_SOCK").is_ok() {
tracing::debug!("SSH agent detected, attempting agent authentication");
return Ok(AuthMethod::Agent);
}

// Try to find SSH key
let ssh_key_paths = vec![
dirs::home_dir().map(|h| h.join(".ssh/id_rsa")),
dirs::home_dir().map(|h| h.join(".ssh/id_ed25519")),
// Fallback to default key locations
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
let home_path = Path::new(&home).join(".ssh");

// Try common key files in order of preference
let default_keys = [
home_path.join("id_ed25519"),
home_path.join("id_rsa"),
home_path.join("id_ecdsa"),
home_path.join("id_dsa"),
];

for key_path in ssh_key_paths.into_iter().flatten() {
if key_path.exists() {
return Ok(AuthMethod::with_key_file(key_path, None));
for default_key in &default_keys {
if default_key.exists() {
tracing::debug!("Using default key: {:?}", default_key);

// Check if the key is encrypted
let key_contents = std::fs::read_to_string(default_key)
.with_context(|| format!("Failed to read SSH key file: {default_key:?}"))?;

let passphrase = if key_contents.contains("ENCRYPTED")
|| key_contents.contains("Proc-Type: 4,ENCRYPTED")
{
tracing::debug!("Detected encrypted SSH key, prompting for passphrase");
let pass = rpassword::prompt_password(format!(
"Enter passphrase for key {default_key:?}: "
))
.with_context(|| "Failed to read passphrase")?;
Some(pass)
} else {
None
};

return Ok(AuthMethod::with_key_file(
default_key,
passphrase.as_deref(),
));
}
}

// If no key found, prompt for password
let password =
rpassword::prompt_password(format!("Password for {}@{}: ", node.username, node.host))?;
Ok(AuthMethod::with_password(&password))
anyhow::bail!(
"SSH authentication failed: No authentication method available.\n\
Tried:\n\
- SSH agent: {}\n\
- SSH keys: {:?}\n\
Please ensure you have a valid SSH key or SSH agent running.",
if std::env::var("SSH_AUTH_SOCK").is_ok() {
"Available but no identities"
} else {
"Not available (SSH_AUTH_SOCK not set)"
},
default_keys
)
}

/// Run interactive mode with a single node
Expand Down Expand Up @@ -832,6 +930,10 @@ mod tests {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

let path = PathBuf::from("~/test/file.txt");
Expand All @@ -856,6 +958,10 @@ mod tests {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

let node = Node::new(String::from("example.com"), 22, String::from("alice"));
Expand Down
13 changes: 13 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ async fn main() -> Result<()> {

let merged_work_dir = work_dir.or(interactive_config.work_dir.clone());

// Determine SSH key path: CLI argument takes precedence over config
let key_path = if let Some(identity) = &cli.identity {
Some(identity.clone())
} else {
config
.get_ssh_key(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
};

let interactive_cmd = InteractiveCommand {
single_node: merged_mode.0,
multiplex: merged_mode.1,
Expand All @@ -251,6 +260,10 @@ async fn main() -> Result<()> {
config: config.clone(),
interactive_config,
cluster_name: cluster_name.map(String::from),
key_path,
use_agent: cli.use_agent,
use_password: cli.password,
strict_mode,
};
let result = interactive_cmd.execute().await?;
println!("\nInteractive session ended.");
Expand Down
37 changes: 37 additions & 0 deletions tests/interactive_integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use bssh::commands::interactive::InteractiveCommand;
use bssh::config::{Config, InteractiveConfig};
use bssh::node::Node;
use bssh::ssh::known_hosts::StrictHostKeyChecking;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
Expand All @@ -41,6 +42,10 @@ fn test_interactive_command_builder() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert!(!cmd.single_node);
Expand All @@ -66,6 +71,10 @@ fn test_history_file_handling() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert_eq!(cmd.history_file, history_path);
Expand Down Expand Up @@ -154,6 +163,10 @@ async fn test_interactive_with_unreachable_nodes() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

// This should fail to connect
Expand All @@ -179,6 +192,10 @@ async fn test_interactive_with_no_nodes() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

let result = cmd.execute().await;
Expand Down Expand Up @@ -214,6 +231,10 @@ fn test_mode_configuration() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert!(single_cmd.single_node);
Expand All @@ -230,6 +251,10 @@ fn test_mode_configuration() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert!(!multi_cmd.single_node);
Expand All @@ -249,6 +274,10 @@ fn test_working_directory_config() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert_eq!(cmd_with_dir.work_dir, Some("/var/www".to_string()));
Expand All @@ -263,6 +292,10 @@ fn test_working_directory_config() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert_eq!(cmd_without_dir.work_dir, None);
Expand All @@ -289,6 +322,10 @@ fn test_prompt_format() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert_eq!(cmd.prompt_format, format);
Expand Down
9 changes: 9 additions & 0 deletions tests/interactive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use bssh::commands::interactive::InteractiveCommand;
use bssh::config::{Config, InteractiveConfig};
use bssh::node::Node;
use bssh::ssh::known_hosts::StrictHostKeyChecking;
use std::path::PathBuf;

#[tokio::test]
Expand All @@ -29,6 +30,10 @@ async fn test_interactive_command_creation() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

assert!(!cmd.single_node);
Expand All @@ -48,6 +53,10 @@ async fn test_interactive_with_no_nodes() {
config: Config::default(),
interactive_config: InteractiveConfig::default(),
cluster_name: None,
key_path: None,
use_agent: false,
use_password: false,
strict_mode: StrictHostKeyChecking::AcceptNew,
};

let result = cmd.execute().await;
Expand Down