|
12 | 12 | // See the License for the specific language governing permissions and |
13 | 13 | // limitations under the License. |
14 | 14 |
|
| 15 | +use super::tokio_client::{AuthMethod, Client}; |
15 | 16 | use anyhow::{Context, Result}; |
16 | | -use async_ssh2_tokio::{AuthMethod, Client}; |
17 | 17 | use std::path::Path; |
18 | 18 | use std::time::Duration; |
19 | 19 |
|
@@ -336,3 +336,123 @@ impl CommandResult { |
336 | 336 | self.exit_status == 0 |
337 | 337 | } |
338 | 338 | } |
| 339 | + |
| 340 | +#[cfg(test)] |
| 341 | +mod tests { |
| 342 | + use super::*; |
| 343 | + use tempfile::TempDir; |
| 344 | + |
| 345 | + #[test] |
| 346 | + fn test_ssh_client_creation() { |
| 347 | + let client = SshClient::new("example.com".to_string(), 22, "user".to_string()); |
| 348 | + assert_eq!(client.host, "example.com"); |
| 349 | + assert_eq!(client.port, 22); |
| 350 | + assert_eq!(client.username, "user"); |
| 351 | + } |
| 352 | + |
| 353 | + #[test] |
| 354 | + fn test_command_result_success() { |
| 355 | + let result = CommandResult { |
| 356 | + host: "test.com".to_string(), |
| 357 | + output: b"Hello World\n".to_vec(), |
| 358 | + stderr: Vec::new(), |
| 359 | + exit_status: 0, |
| 360 | + }; |
| 361 | + |
| 362 | + assert!(result.is_success()); |
| 363 | + assert_eq!(result.stdout_string(), "Hello World\n"); |
| 364 | + assert_eq!(result.stderr_string(), ""); |
| 365 | + } |
| 366 | + |
| 367 | + #[test] |
| 368 | + fn test_command_result_failure() { |
| 369 | + let result = CommandResult { |
| 370 | + host: "test.com".to_string(), |
| 371 | + output: Vec::new(), |
| 372 | + stderr: b"Command not found\n".to_vec(), |
| 373 | + exit_status: 127, |
| 374 | + }; |
| 375 | + |
| 376 | + assert!(!result.is_success()); |
| 377 | + assert_eq!(result.stdout_string(), ""); |
| 378 | + assert_eq!(result.stderr_string(), "Command not found\n"); |
| 379 | + } |
| 380 | + |
| 381 | + #[test] |
| 382 | + fn test_command_result_with_utf8() { |
| 383 | + let result = CommandResult { |
| 384 | + host: "test.com".to_string(), |
| 385 | + output: "한글 테스트\n".as_bytes().to_vec(), |
| 386 | + stderr: "エラー\n".as_bytes().to_vec(), |
| 387 | + exit_status: 1, |
| 388 | + }; |
| 389 | + |
| 390 | + assert!(!result.is_success()); |
| 391 | + assert_eq!(result.stdout_string(), "한글 테스트\n"); |
| 392 | + assert_eq!(result.stderr_string(), "エラー\n"); |
| 393 | + } |
| 394 | + |
| 395 | + #[test] |
| 396 | + fn test_determine_auth_method_with_key() { |
| 397 | + let temp_dir = TempDir::new().unwrap(); |
| 398 | + let key_path = temp_dir.path().join("test_key"); |
| 399 | + std::fs::write(&key_path, "fake key content").unwrap(); |
| 400 | + |
| 401 | + let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); |
| 402 | + let auth = client |
| 403 | + .determine_auth_method(Some(&key_path), false) |
| 404 | + .unwrap(); |
| 405 | + |
| 406 | + match auth { |
| 407 | + AuthMethod::PrivateKeyFile { key_file_path, .. } => { |
| 408 | + assert_eq!(key_file_path, key_path); |
| 409 | + } |
| 410 | + _ => panic!("Expected PrivateKeyFile auth method"), |
| 411 | + } |
| 412 | + } |
| 413 | + |
| 414 | + #[cfg(not(target_os = "windows"))] |
| 415 | + #[test] |
| 416 | + fn test_determine_auth_method_with_agent() { |
| 417 | + unsafe { |
| 418 | + std::env::set_var("SSH_AUTH_SOCK", "/tmp/ssh-agent.sock"); |
| 419 | + } |
| 420 | + |
| 421 | + let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); |
| 422 | + let auth = client.determine_auth_method(None, true).unwrap(); |
| 423 | + |
| 424 | + match auth { |
| 425 | + AuthMethod::Agent => {} |
| 426 | + _ => panic!("Expected Agent auth method"), |
| 427 | + } |
| 428 | + |
| 429 | + unsafe { |
| 430 | + std::env::remove_var("SSH_AUTH_SOCK"); |
| 431 | + } |
| 432 | + } |
| 433 | + |
| 434 | + #[test] |
| 435 | + fn test_determine_auth_method_fallback_to_default() { |
| 436 | + // Create a fake home directory with default key |
| 437 | + let temp_dir = TempDir::new().unwrap(); |
| 438 | + let ssh_dir = temp_dir.path().join(".ssh"); |
| 439 | + std::fs::create_dir_all(&ssh_dir).unwrap(); |
| 440 | + let default_key = ssh_dir.join("id_rsa"); |
| 441 | + std::fs::write(&default_key, "fake key").unwrap(); |
| 442 | + |
| 443 | + unsafe { |
| 444 | + std::env::set_var("HOME", temp_dir.path().to_str().unwrap()); |
| 445 | + std::env::remove_var("SSH_AUTH_SOCK"); |
| 446 | + } |
| 447 | + |
| 448 | + let client = SshClient::new("test.com".to_string(), 22, "user".to_string()); |
| 449 | + let auth = client.determine_auth_method(None, false).unwrap(); |
| 450 | + |
| 451 | + match auth { |
| 452 | + AuthMethod::PrivateKeyFile { key_file_path, .. } => { |
| 453 | + assert_eq!(key_file_path, default_key); |
| 454 | + } |
| 455 | + _ => panic!("Expected PrivateKeyFile auth method"), |
| 456 | + } |
| 457 | + } |
| 458 | +} |
0 commit comments