diff --git a/Cargo.toml b/Cargo.toml index dbe9ec09..566aae2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ repository = "https://github.com/lablup/bssh" readme = "README.md" keywords = ["cli", "rust"] categories = ["command-line-utilities"] -edition = "2021" +edition = "2024" [dependencies] bytes = "1.11.1" diff --git a/benches/large_output_benchmark.rs b/benches/large_output_benchmark.rs index 5cb49b80..787f3728 100644 --- a/benches/large_output_benchmark.rs +++ b/benches/large_output_benchmark.rs @@ -25,9 +25,9 @@ use bssh::node::Node; use bssh::ssh::tokio_client::CommandOutput; use bssh::ui::tui::app::TuiApp; use bytes::Bytes; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use ratatui::backend::TestBackend; +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use ratatui::Terminal; +use ratatui::backend::TestBackend; use std::hint::black_box; use tokio::runtime::Runtime; use tokio::sync::mpsc; diff --git a/crates/bssh-russh/Cargo.toml b/crates/bssh-russh/Cargo.toml index 09078068..ab1433fd 100644 --- a/crates/bssh-russh/Cargo.toml +++ b/crates/bssh-russh/Cargo.toml @@ -4,7 +4,7 @@ version = "0.60.1" authors = ["Jeongkyu Shin "] description = "Temporary fork of russh with high-frequency PTY output fix (Handle::data from spawned tasks)" documentation = "https://docs.rs/bssh-russh" -edition = "2021" +edition = "2024" homepage = "https://github.com/lablup/bssh" keywords = ["ssh"] license = "Apache-2.0" diff --git a/crates/bssh-russh/src/client/encrypted.rs b/crates/bssh-russh/src/client/encrypted.rs index 900847a8..34d22821 100644 --- a/crates/bssh-russh/src/client/encrypted.rs +++ b/crates/bssh-russh/src/client/encrypted.rs @@ -437,15 +437,14 @@ impl Session { let channel_num = map_err!(ChannelId::decode(&mut r))?; let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; - if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, &data, target)? { + if let Some(ref mut enc) = self.common.encrypted + && enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { self.target_window_size = next_window } } - } if let Some(chan) = self.channels.get(&channel_num) { let _ = chan.send(ChannelMsg::Data { data: data.clone() }).await; @@ -459,15 +458,14 @@ impl Session { let extended_code = map_err!(u32::decode(&mut r))?; let data = map_err!(Bytes::decode(&mut r))?; let target = self.common.config.window_size; - if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, &data, target)? { + if let Some(ref mut enc) = self.common.encrypted + && enc.adjust_window_size(channel_num, &data, target)? { let next_window = client.adjust_window(channel_num, self.target_window_size); if next_window > 0 { self.target_window_size = next_window } } - } if let Some(chan) = self.channels.get(&channel_num) { let _ = chan @@ -551,8 +549,8 @@ impl Session { } _ => { let wants_reply = map_err!(u8::decode(&mut r))?; - if wants_reply == 1 { - if let Some(ref mut enc) = self.common.encrypted { + if wants_reply == 1 + && let Some(ref mut enc) = self.common.encrypted { self.common.wants_reply = false; if let Some(ch) = enc.channels.get(&channel_num) { push_packet!(enc.write, { @@ -561,7 +559,6 @@ impl Session { }) } } - } info!("Unknown channel request {req:?} {wants_reply:?}",); Ok(()) } diff --git a/crates/bssh-russh/src/client/mod.rs b/crates/bssh-russh/src/client/mod.rs index 5f2e5088..60df7453 100644 --- a/crates/bssh-russh/src/client/mod.rs +++ b/crates/bssh-russh/src/client/mod.rs @@ -966,11 +966,10 @@ pub async fn connect( handler: H, ) -> Result, H::Error> { let socket = map_err!(tokio::net::TcpStream::connect(addrs).await)?; - if config.as_ref().nodelay { - if let Err(e) = socket.set_nodelay(true) { + if config.as_ref().nodelay + && let Err(e) = socket.set_nodelay(true) { warn!("set_nodelay() failed: {e:?}"); } - } connect_stream(config, socket, handler).await } @@ -1211,8 +1210,8 @@ impl Session { reading.set(start_reading(stream_read, buffer, opening_cipher)); } () = &mut keepalive_timer => { - if let Some(ref mut enc) = self.common.encrypted { - if matches!(enc.state, EncryptedState::Authenticated) { + if let Some(ref mut enc) = self.common.encrypted + && matches!(enc.state, EncryptedState::Authenticated) { self.common.alive_timeouts = self.common.alive_timeouts.saturating_add(1); if self.common.config.keepalive_max != 0 && self.common.alive_timeouts > self.common.config.keepalive_max { debug!("Timeout, server not responding to keepalives"); @@ -1221,7 +1220,6 @@ impl Session { sent_keepalive = true; self.send_keepalive(true)?; } - } } () = &mut inactivity_timer => { debug!("timeout"); @@ -1263,15 +1261,14 @@ impl Session { self.flush()?; map_err!(self.common.packet_writer.flush_into(stream_write).await)?; - if let Some(ref mut enc) = self.common.encrypted { - if let EncryptedState::InitCompression = enc.state { + if let Some(ref mut enc) = self.common.encrypted + && let EncryptedState::InitCompression = enc.state { if enc.client_compression.is_deferred() { enc.client_compression .init_compress(self.common.packet_writer.compress()); } enc.state = EncryptedState::Authenticated; } - } if self.common.received_data { // Reset the number of failed keepalive attempts. We don't @@ -1281,22 +1278,20 @@ impl Session { // data from it. self.common.alive_timeouts = 0; } - if self.common.received_data || sent_keepalive { - if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + if (self.common.received_data || sent_keepalive) + && let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( keepalive_timer.as_mut().as_pin_mut(), self.common.config.keepalive_interval, ) { sleep.as_mut().reset(tokio::time::Instant::now() + d); } - } - if !sent_keepalive { - if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + if !sent_keepalive + && let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( inactivity_timer.as_mut().as_pin_mut(), self.common.config.inactivity_timeout, ) { sleep.as_mut().reset(tokio::time::Instant::now() + d); } - } } result @@ -1528,15 +1523,14 @@ impl Session { /// Flush the temporary cleartext buffer into the encryption /// buffer. This does *not* flush to the socket. fn flush(&mut self) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if enc.flush( + if let Some(ref mut enc) = self.common.encrypted + && enc.flush( &self.common.config.as_ref().limits, &mut self.common.packet_writer, )? && !self.kex.active() { self.begin_rekey()?; } - } Ok(()) } @@ -1581,8 +1575,8 @@ async fn reply( let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false); - if is_kex_msg { - if let SessionKexState::InProgress(kex) = session.kex.take() { + if is_kex_msg + && let SessionKexState::InProgress(kex) = session.kex.take() { let progress = kex.step(Some(pkt), &mut session.common.packet_writer)?; match progress { @@ -1652,7 +1646,6 @@ async fn reply( return Ok(()); } - } session.client_read_encrypted(handler, pkt).await } diff --git a/crates/bssh-russh/src/client/session.rs b/crates/bssh-russh/src/client/session.rs index 3a5ed1a1..4cdd2e5c 100644 --- a/crates/bssh-russh/src/client/session.rs +++ b/crates/bssh-russh/src/client/session.rs @@ -111,8 +111,8 @@ impl Session { pix_height: u32, terminal_modes: &[(Pty, u32)], ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { map_err!(msg::CHANNEL_REQUEST.encode(&mut enc.write))?; @@ -137,7 +137,6 @@ impl Session { (Pty::TTY_OP_END as u8).encode(&mut enc.write)?; }); } - } Ok(()) } @@ -150,8 +149,8 @@ impl Session { x11_authentication_cookie: &str, x11_screen_number: u32, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -164,7 +163,6 @@ impl Session { x11_screen_number.encode(&mut enc.write)?; }); } - } Ok(()) } @@ -175,8 +173,8 @@ impl Session { variable_name: &str, variable_value: &str, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -187,7 +185,6 @@ impl Session { variable_value.encode(&mut enc.write)?; }); } - } Ok(()) } @@ -196,8 +193,8 @@ impl Session { want_reply: bool, channel: ChannelId, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -206,7 +203,6 @@ impl Session { (want_reply as u8).encode(&mut enc.write)?; }); } - } Ok(()) } @@ -216,8 +212,8 @@ impl Session { want_reply: bool, command: &[u8], ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -228,14 +224,13 @@ impl Session { }); return Ok(()); } - } error!("exec"); Ok(()) } pub fn signal(&mut self, channel: ChannelId, signal: Sig) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; channel.recipient_channel.encode(&mut enc.write)?; @@ -244,7 +239,6 @@ impl Session { signal.name().encode(&mut enc.write)?; }); } - } Ok(()) } @@ -254,8 +248,8 @@ impl Session { channel: ChannelId, name: &str, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -265,7 +259,6 @@ impl Session { name.encode(&mut enc.write)?; }); } - } Ok(()) } @@ -277,8 +270,8 @@ impl Session { pix_width: u32, pix_height: u32, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -291,7 +284,6 @@ impl Session { pix_height.encode(&mut enc.write)?; }); } - } Ok(()) } @@ -484,8 +476,8 @@ impl Session { channel: ChannelId, want_reply: bool, ) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; channel.recipient_channel.encode(&mut enc.write)?; @@ -493,7 +485,6 @@ impl Session { (want_reply as u8).encode(&mut enc.write)?; }); } - } Ok(()) } diff --git a/crates/bssh-russh/src/kex/dh/mod.rs b/crates/bssh-russh/src/kex/dh/mod.rs index 730a2f2e..723a3740 100644 --- a/crates/bssh-russh/src/kex/dh/mod.rs +++ b/crates/bssh-russh/src/kex/dh/mod.rs @@ -121,11 +121,10 @@ impl std::fmt::Debug for DhGroupKex { pub(crate) fn biguint_to_mpint(biguint: &BigUint) -> Vec { let mut mpint = Vec::new(); let bytes = biguint.to_bytes_be(); - if let Some(b) = bytes.first() { - if b > &0x7f { + if let Some(b) = bytes.first() + && b > &0x7f { mpint.push(0); } - } mpint.extend(&bytes); mpint } diff --git a/crates/bssh-russh/src/keys/known_hosts.rs b/crates/bssh-russh/src/keys/known_hosts.rs index 058f36f5..e356e6bb 100644 --- a/crates/bssh-russh/src/keys/known_hosts.rs +++ b/crates/bssh-russh/src/keys/known_hosts.rs @@ -116,11 +116,10 @@ fn match_hostname(host: &str, pattern: &str) -> bool { let Some(Ok(hash)) = parts.next().map(|p| BASE64_MIME.decode(p.as_bytes())) else { continue; }; - if let Ok(hmac) = Hmac::::new_from_slice(&salt) { - if hmac.chain_update(host).verify_slice(&hash).is_ok() { + if let Ok(hmac) = Hmac::::new_from_slice(&salt) + && hmac.chain_update(host).verify_slice(&hash).is_ok() { return true; } - } } else if host == entry { return true; } diff --git a/crates/bssh-russh/src/server/encrypted.rs b/crates/bssh-russh/src/server/encrypted.rs index d4dcdb82..2efe9d01 100644 --- a/crates/bssh-russh/src/server/encrypted.rs +++ b/crates/bssh-russh/src/server/encrypted.rs @@ -632,14 +632,13 @@ impl Session { let data = map_err!(Bytes::decode(r))?; let target = self.target_window_size; - if let Some(ref mut enc) = self.common.encrypted { - if enc.adjust_window_size(channel_num, &data, target)? { + if let Some(ref mut enc) = self.common.encrypted + && enc.adjust_window_size(channel_num, &data, target)? { let window = handler.adjust_window(channel_num, self.target_window_size); if window > 0 { self.target_window_size = window } } - } self.flush()?; if let Some(ext) = ext { if let Some(chan) = self.channels.get(&channel_num) { @@ -731,11 +730,10 @@ impl Session { let channel_num = map_err!(ChannelId::decode(r))?; let req_type = map_err!(String::decode(r))?; let wants_reply = map_err!(u8::decode(r))?; - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get_mut(&channel_num) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get_mut(&channel_num) { channel.wants_reply = wants_reply != 0; } - } match req_type.as_str() { "pty-req" => { let term = map_err!(String::decode(r))?; diff --git a/crates/bssh-russh/src/server/mod.rs b/crates/bssh-russh/src/server/mod.rs index b57cd074..ee6f66e2 100644 --- a/crates/bssh-russh/src/server/mod.rs +++ b/crates/bssh-russh/src/server/mod.rs @@ -879,11 +879,10 @@ pub trait Server { let error_tx = error_tx.clone(); russh_util::runtime::spawn(async move { - if config.nodelay { - if let Err(e) = socket.set_nodelay(true) { + if config.nodelay + && let Err(e) = socket.set_nodelay(true) { warn!("set_nodelay() failed: {e:?}"); } - } let session = match run_stream(config, socket, handler).await { Ok(s) => s, @@ -1096,8 +1095,8 @@ async fn reply( let is_kex_msg = pkt.buffer.first().cloned().map(is_kex_msg).unwrap_or(false); - if is_kex_msg { - if let SessionKexState::InProgress(kex) = session.kex.take() { + if is_kex_msg + && let SessionKexState::InProgress(kex) = session.kex.take() { let progress = kex .step(Some(pkt), &mut session.common.packet_writer, handler) .await?; @@ -1158,7 +1157,6 @@ async fn reply( return Ok(()); } - } // Handle key exchange/re-exchange. session.server_read_encrypted(handler, pkt).await diff --git a/crates/bssh-russh/src/server/session.rs b/crates/bssh-russh/src/server/session.rs index ea7b84f6..9a35c495 100644 --- a/crates/bssh-russh/src/server/session.rs +++ b/crates/bssh-russh/src/server/session.rs @@ -789,22 +789,20 @@ impl Session { // data from it. self.common.alive_timeouts = 0; } - if self.common.received_data || sent_keepalive { - if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + if (self.common.received_data || sent_keepalive) + && let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( keepalive_timer.as_mut().as_pin_mut(), self.common.config.keepalive_interval, ) { sleep.as_mut().reset(tokio::time::Instant::now() + d); } - } - if !sent_keepalive { - if let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( + if !sent_keepalive + && let (futures::future::Either::Right(ref mut sleep), Some(d)) = ( inactivity_timer.as_mut().as_pin_mut(), self.common.config.inactivity_timeout, ) { sleep.as_mut().reset(tokio::time::Instant::now() + d); } - } } debug!("disconnected"); // Shutdown @@ -833,38 +831,35 @@ impl Session { } pub fn writable_packet_size(&self, channel: &ChannelId) -> u32 { - if let Some(ref enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(channel) { + if let Some(ref enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(channel) { return channel .sender_window_size .min(channel.sender_maximum_packet_size); } - } 0 } pub fn window_size(&self, channel: &ChannelId) -> u32 { - if let Some(ref enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(channel) { + if let Some(ref enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(channel) { return channel.sender_window_size; } - } 0 } pub fn max_packet_size(&self, channel: &ChannelId) -> u32 { - if let Some(ref enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(channel) { + if let Some(ref enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(channel) { return channel.sender_maximum_packet_size; } - } 0 } /// Flush the session, i.e. encrypt the pending buffer. pub fn flush(&mut self) -> Result<(), Error> { - if let Some(ref mut enc) = self.common.encrypted { - if enc.flush( + if let Some(ref mut enc) = self.common.encrypted + && enc.flush( &self.common.config.as_ref().limits, &mut self.common.packet_writer, )? && self.kex == SessionKexState::Idle @@ -874,7 +869,6 @@ impl Session { self.begin_rekey()?; } } - } Ok(()) } @@ -949,12 +943,11 @@ impl Session { /// cancelling). Always call this function if the request was /// successful (it checks whether the client expects an answer). pub fn request_success(&mut self) { - if self.common.wants_reply { - if let Some(ref mut enc) = self.common.encrypted { + if self.common.wants_reply + && let Some(ref mut enc) = self.common.encrypted { self.common.wants_reply = false; push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS)) } - } } /// Send a "failure" reply to a global request. @@ -969,8 +962,8 @@ impl Session { /// function if the request was successful (it checks whether the /// client expects an answer). pub fn channel_success(&mut self, channel: ChannelId) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get_mut(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); if channel.wants_reply { channel.wants_reply = false; @@ -981,14 +974,13 @@ impl Session { }) } } - } Ok(()) } /// Send a "failure" reply to a global request. pub fn channel_failure(&mut self, channel: ChannelId) -> Result<(), crate::Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get_mut(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get_mut(&channel) { assert!(channel.confirmed); if channel.wants_reply { channel.wants_reply = false; @@ -998,7 +990,6 @@ impl Session { }) } } - } Ok(()) } @@ -1081,8 +1072,8 @@ impl Session { channel: ChannelId, client_can_do: bool, ) -> Result<(), Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -1093,7 +1084,6 @@ impl Session { (client_can_do as u8).encode(&mut enc.write)?; }) } - } Ok(()) } @@ -1133,8 +1123,8 @@ impl Session { channel: ChannelId, exit_status: u32, ) -> Result<(), Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -1145,7 +1135,6 @@ impl Session { exit_status.encode(&mut enc.write)?; }) } - } Ok(()) } @@ -1158,8 +1147,8 @@ impl Session { error_message: &str, language_tag: &str, ) -> Result<(), Error> { - if let Some(ref mut enc) = self.common.encrypted { - if let Some(channel) = enc.channels.get(&channel) { + if let Some(ref mut enc) = self.common.encrypted + && let Some(channel) = enc.channels.get(&channel) { assert!(channel.confirmed); push_packet!(enc.write, { msg::CHANNEL_REQUEST.encode(&mut enc.write)?; @@ -1173,7 +1162,6 @@ impl Session { language_tag.encode(&mut enc.write)?; }) } - } Ok(()) } diff --git a/crates/bssh-russh/src/ssh_read.rs b/crates/bssh-russh/src/ssh_read.rs index 3656a6af..548b6eee 100644 --- a/crates/bssh-russh/src/ssh_read.rs +++ b/crates/bssh-russh/src/ssh_read.rs @@ -156,12 +156,11 @@ impl SshRead { if i >= 8 { // Check if we have a valid SSH protocol identifier #[allow(clippy::indexing_slicing)] - if let Ok(s) = std::str::from_utf8(&ssh_id.buf[..i]) { - if s.starts_with("SSH-1.99-") || s.starts_with("SSH-2.0-") { + if let Ok(s) = std::str::from_utf8(&ssh_id.buf[..i]) + && (s.starts_with("SSH-1.99-") || s.starts_with("SSH-2.0-")) { ssh_id.sshid_len = i; return Ok(ssh_id.id()); } - } } // Else, it is a "preliminary" (see // https://tools.ietf.org/html/rfc4253#section-4.2), diff --git a/src/app/dispatcher.rs b/src/app/dispatcher.rs index 0475d127..1c1d14a6 100644 --- a/src/app/dispatcher.rs +++ b/src/app/dispatcher.rs @@ -19,23 +19,23 @@ use bssh::{ cli::{Cli, Commands}, commands::{ download::download_file, - exec::{execute_command, ExecuteCommandParams}, + exec::{ExecuteCommandParams, execute_command}, interactive::InteractiveCommand, list::list_clusters, ping::ping_nodes, - upload::{upload_file, FileTransferParams}, + upload::{FileTransferParams, upload_file}, }, config::InteractiveMode, pty::PtyConfig, security::get_sudo_password, - ssh::tokio_client::{SshConnectionConfig, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX}, + ssh::tokio_client::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX, SshConnectionConfig}, }; use std::path::{Path, PathBuf}; use std::sync::Arc; #[cfg(target_os = "macos")] use super::initialization::determine_use_keychain; -use super::initialization::{determine_ssh_key_path, AppContext}; +use super::initialization::{AppContext, determine_ssh_key_path}; use super::utils::format_duration; /// Build SSH connection config with keepalive settings. diff --git a/src/app/initialization.rs b/src/app/initialization.rs index e8b006de..f259158a 100644 --- a/src/app/initialization.rs +++ b/src/app/initialization.rs @@ -20,7 +20,7 @@ use bssh::{ config::Config, jump::parse_jump_hosts, node::Node, - ssh::{known_hosts::StrictHostKeyChecking, SshConfig}, + ssh::{SshConfig, known_hosts::StrictHostKeyChecking}, utils::init_logging, }; use std::path::PathBuf; @@ -275,15 +275,15 @@ pub fn determine_strict_host_key_checking( } // SSH config value for specific hostname - if let Some(host) = hostname { - if let Some(ssh_config_value) = ssh_config.get_strict_host_key_checking(host) { - return match ssh_config_value.to_lowercase().as_str() { - "yes" => StrictHostKeyChecking::Yes, - "no" => StrictHostKeyChecking::No, - "ask" | "accept-new" => StrictHostKeyChecking::AcceptNew, - _ => StrictHostKeyChecking::AcceptNew, - }; - } + if let Some(host) = hostname + && let Some(ssh_config_value) = ssh_config.get_strict_host_key_checking(host) + { + return match ssh_config_value.to_lowercase().as_str() { + "yes" => StrictHostKeyChecking::Yes, + "no" => StrictHostKeyChecking::No, + "ask" | "accept-new" => StrictHostKeyChecking::AcceptNew, + _ => StrictHostKeyChecking::AcceptNew, + }; } // Default from CLI (already parsed) diff --git a/src/app/nodes.rs b/src/app/nodes.rs index 6e248ae8..1789a755 100644 --- a/src/app/nodes.rs +++ b/src/app/nodes.rs @@ -562,10 +562,12 @@ mod tests { let result = exclude_nodes(nodes, &patterns); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("too many wildcards")); + assert!( + result + .unwrap_err() + .to_string() + .contains("too many wildcards") + ); } #[test] @@ -575,10 +577,12 @@ mod tests { let result = exclude_nodes(nodes, &patterns); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("invalid characters")); + assert!( + result + .unwrap_err() + .to_string() + .contains("invalid characters") + ); } #[test] @@ -588,10 +592,12 @@ mod tests { let result = exclude_nodes(nodes, &patterns); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("invalid sequences")); + assert!( + result + .unwrap_err() + .to_string() + .contains("invalid sequences") + ); } #[test] diff --git a/src/bin/bssh_keygen.rs b/src/bin/bssh_keygen.rs index 4ecd7474..8bcbd05a 100644 --- a/src/bin/bssh_keygen.rs +++ b/src/bin/bssh_keygen.rs @@ -123,11 +123,11 @@ fn main() -> Result<()> { }; // Ensure parent directory exists - if let Some(parent) = output.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent) - .with_context(|| format!("Failed to create directory: {}", parent.display()))?; - } + if let Some(parent) = output.parent() + && !parent.exists() + { + std::fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory: {}", parent.display()))?; } // Check if file exists and prompt for overwrite diff --git a/src/bin/bssh_server.rs b/src/bin/bssh_server.rs index 98776219..5c700e30 100644 --- a/src/bin/bssh_server.rs +++ b/src/bin/bssh_server.rs @@ -17,8 +17,8 @@ //! This binary provides a command-line interface for managing the bssh SSH server. use anyhow::{Context, Result}; -use bssh::server::config::{generate_config_template, load_config, ServerFileConfig}; use bssh::server::BsshServer; +use bssh::server::config::{ServerFileConfig, generate_config_template, load_config}; use bssh::utils::logging; use clap::{ArgAction, Parser, Subcommand}; use std::fs; @@ -471,37 +471,36 @@ fn setup_signal_handlers() -> Result> { /// Write the current process ID to a PID file fn write_pid_file(path: &PathBuf) -> Result<()> { // Check if PID file already exists and refers to a running process - if path.exists() { - if let Ok(existing_pid_str) = fs::read_to_string(path) { - if let Ok(existing_pid) = existing_pid_str.trim().parse::() { - // Check if process is still running - #[cfg(unix)] - { - use nix::sys::signal::kill; - use nix::unistd::Pid; - - let pid = Pid::from_raw(existing_pid); - // Use signal 0 (None) to check if process exists without sending actual signal - if kill(pid, None).is_ok() { - anyhow::bail!( - "Another instance is already running with PID {}. \ + if path.exists() + && let Ok(existing_pid_str) = fs::read_to_string(path) + && let Ok(existing_pid) = existing_pid_str.trim().parse::() + { + // Check if process is still running + #[cfg(unix)] + { + use nix::sys::signal::kill; + use nix::unistd::Pid; + + let pid = Pid::from_raw(existing_pid); + // Use signal 0 (None) to check if process exists without sending actual signal + if kill(pid, None).is_ok() { + anyhow::bail!( + "Another instance is already running with PID {}. \ If this is incorrect, remove {} and try again.", - existing_pid, - path.display() - ); - } - } - - #[cfg(not(unix))] - { - // On non-Unix systems, warn but allow overwrite - tracing::warn!( - "PID file exists with PID {}. Overwriting (process check not available on this platform).", - existing_pid - ); - } + existing_pid, + path.display() + ); } } + + #[cfg(not(unix))] + { + // On non-Unix systems, warn but allow overwrite + tracing::warn!( + "PID file exists with PID {}. Overwriting (process check not available on this platform).", + existing_pid + ); + } } let pid = std::process::id(); @@ -798,10 +797,12 @@ mod tests { let result = gen_host_key("rsa", &key_path, 1024); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("RSA key size must be at least 2048")); + assert!( + result + .unwrap_err() + .to_string() + .contains("RSA key size must be at least 2048") + ); } #[test] diff --git a/src/cli/bssh.rs b/src/cli/bssh.rs index 2eac6c86..008f1e61 100644 --- a/src/cli/bssh.rs +++ b/src/cli/bssh.rs @@ -444,12 +444,12 @@ pub enum Commands { impl Cli { pub fn get_command(&self) -> String { // In multi-server mode with destination, treat destination as first command arg - if self.is_multi_server_mode() { - if let Some(dest) = &self.destination { - let mut all_args = vec![dest.clone()]; - all_args.extend(self.command_args.clone()); - return all_args.join(" "); - } + if self.is_multi_server_mode() + && let Some(dest) = &self.destination + { + let mut all_args = vec![dest.clone()]; + all_args.extend(self.command_args.clone()); + return all_args.join(" "); } if !self.command_args.is_empty() { self.command_args.join(" ") @@ -567,10 +567,10 @@ impl Cli { // Check SSH options for Port= for opt in &self.ssh_options { - if let Some(port_str) = opt.strip_prefix("Port=") { - if let Ok(port) = port_str.parse::() { - return Some(port); - } + if let Some(port_str) = opt.strip_prefix("Port=") + && let Ok(port) = port_str.parse::() + { + return Some(port); } } diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 70e15ba1..a4981807 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -44,6 +44,6 @@ pub use bssh::{Cli, Commands}; // Re-export pdsh compatibility utilities pub use pdsh::{ - has_pdsh_compat_flag, is_pdsh_compat_mode, remove_pdsh_compat_flag, PdshCli, QueryResult, - PDSH_COMPAT_ENV_VAR, + PDSH_COMPAT_ENV_VAR, PdshCli, QueryResult, has_pdsh_compat_flag, is_pdsh_compat_mode, + remove_pdsh_compat_flag, }; diff --git a/src/commands/exec.rs b/src/commands/exec.rs index 41844616..c9204b93 100644 --- a/src/commands/exec.rs +++ b/src/commands/exec.rs @@ -20,9 +20,9 @@ use crate::executor::{ExitCodeStrategy, OutputMode, ParallelExecutor, RankDetect use crate::forwarding::ForwardingType; use crate::node::Node; use crate::security::SudoPassword; +use crate::ssh::SshConfig; use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ssh::tokio_client::SshConnectionConfig; -use crate::ssh::SshConfig; use crate::ui::OutputFormatter; use crate::utils::output::save_outputs_to_files; @@ -62,10 +62,10 @@ pub async fn execute_command(params: ExecuteCommandParams<'_>) -> Result<()> { ); // Handle port forwarding if specified - if let Some(ref forwards) = params.port_forwards { - if !forwards.is_empty() { - return execute_command_with_forwarding(params).await; - } + if let Some(ref forwards) = params.port_forwards + && !forwards.is_empty() + { + return execute_command_with_forwarding(params).await; } // Execute command without port forwarding (original behavior) @@ -247,11 +247,11 @@ async fn execute_command_without_forwarding(params: ExecuteCommandParams<'_>) -> // Save outputs to files if output_dir is specified and not already handled by file mode // (File mode already saves outputs, so only save for normal mode with output_dir) - if let Some(dir) = params.output_dir { - if !params.stream { - // Only save if not in stream mode (file mode saves automatically) - save_outputs_to_files(&results, dir, params.command).await?; - } + if let Some(dir) = params.output_dir + && !params.stream + { + // Only save if not in stream mode (file mode saves automatically) + save_outputs_to_files(&results, dir, params.command).await?; } // Print results (skip if already printed in stream mode) diff --git a/src/commands/interactive/connection.rs b/src/commands/interactive/connection.rs index 7b4318c5..70312d80 100644 --- a/src/commands/interactive/connection.rs +++ b/src/commands/interactive/connection.rs @@ -16,13 +16,13 @@ use anyhow::{Context, Result}; use crossterm::terminal; -use russh::client::Msg; use russh::Channel; +use russh::client::Msg; use std::io::{self, Write}; -use tokio::time::{timeout, Duration}; +use tokio::time::{Duration, timeout}; use zeroize::Zeroizing; -use crate::jump::{parse_jump_hosts, JumpHostChain}; +use crate::jump::{JumpHostChain, parse_jump_hosts}; use crate::node::Node; use crate::ssh::{ known_hosts::get_check_method, diff --git a/src/commands/interactive/execution.rs b/src/commands/interactive/execution.rs index 887d57ba..5646d3c7 100644 --- a/src/commands/interactive/execution.rs +++ b/src/commands/interactive/execution.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::pty::PtyManager; use super::super::interactive_signal::{ - reset_interrupt, setup_async_signal_handlers, setup_signal_handlers, TerminalGuard, + TerminalGuard, reset_interrupt, setup_async_signal_handlers, setup_signal_handlers, }; use super::types::{InteractiveCommand, InteractiveResult}; diff --git a/src/commands/interactive/multiplex.rs b/src/commands/interactive/multiplex.rs index 73ed7b4a..6ca5b2a7 100644 --- a/src/commands/interactive/multiplex.rs +++ b/src/commands/interactive/multiplex.rs @@ -17,14 +17,14 @@ use anyhow::Result; use chrono; use owo_colors::OwoColorize; +use rustyline::DefaultEditor; use rustyline::config::Configurer; use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; use tokio::time::Duration; use super::super::interactive_signal::is_interrupted; use super::types::{ - InteractiveCommand, NodeSession, NODES_TO_SHOW_IN_COMPACT, SSH_OUTPUT_POLL_INTERVAL_MS, + InteractiveCommand, NODES_TO_SHOW_IN_COMPACT, NodeSession, SSH_OUTPUT_POLL_INTERVAL_MS, }; impl InteractiveCommand { @@ -254,8 +254,8 @@ impl InteractiveCommand { // Try to read output from each active session _ = async { for session in &mut sessions { - if session.is_connected && session.is_active { - if let Ok(Some(output)) = session.read_output().await { + if session.is_connected && session.is_active + && let Ok(Some(output)) = session.read_output().await { has_output = true; // Print output with node prefix and optional timestamp for line in output.lines() { @@ -284,7 +284,6 @@ impl InteractiveCommand { } } } - } } // If no output was found, sleep briefly to avoid busy waiting diff --git a/src/commands/interactive/single_node.rs b/src/commands/interactive/single_node.rs index bfae7466..1235c787 100644 --- a/src/commands/interactive/single_node.rs +++ b/src/commands/interactive/single_node.rs @@ -15,14 +15,14 @@ //! Single node interactive session handling use anyhow::Result; +use rustyline::DefaultEditor; use rustyline::config::Configurer; use rustyline::error::ReadlineError; -use rustyline::DefaultEditor; use std::io::{self, Write}; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tokio::sync::mpsc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::sync::Mutex; +use tokio::sync::mpsc; use tokio::time::Duration; use super::super::interactive_signal::is_interrupted; diff --git a/src/commands/interactive/types.rs b/src/commands/interactive/types.rs index ab2d9a18..5631422f 100644 --- a/src/commands/interactive/types.rs +++ b/src/commands/interactive/types.rs @@ -15,8 +15,8 @@ //! Core types and structures for interactive mode use anyhow::Result; -use russh::client::Msg; use russh::Channel; +use russh::client::Msg; use std::path::PathBuf; use tokio::time::Duration; diff --git a/src/commands/interactive/utils.rs b/src/commands/interactive/utils.rs index eabe0b75..3748bdc7 100644 --- a/src/commands/interactive/utils.rs +++ b/src/commands/interactive/utils.rs @@ -49,16 +49,15 @@ impl InteractiveCommand { /// Expand ~ in path to home directory pub(super) fn expand_path(&self, path: &std::path::Path) -> Result { - if let Some(path_str) = path.to_str() { - if path_str.starts_with('~') { - if let Some(home) = dirs::home_dir() { - // Handle ~ alone or ~/path - if path_str == "~" { - return Ok(home); - } else if let Some(rest) = path_str.strip_prefix("~/") { - return Ok(home.join(rest)); - } - } + if let Some(path_str) = path.to_str() + && path_str.starts_with('~') + && let Some(home) = dirs::home_dir() + { + // Handle ~ alone or ~/path + if path_str == "~" { + return Ok(home); + } else if let Some(rest) = path_str.strip_prefix("~/") { + return Ok(home.join(rest)); } } Ok(path.to_path_buf()) diff --git a/src/commands/interactive_signal.rs b/src/commands/interactive_signal.rs index 855a4f16..78b086a7 100644 --- a/src/commands/interactive_signal.rs +++ b/src/commands/interactive_signal.rs @@ -15,8 +15,8 @@ //! Signal handling for interactive mode use anyhow::Result; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use tokio::signal; use tracing::{debug, info}; diff --git a/src/commands/upload.rs b/src/commands/upload.rs index 3e468ba5..bf35e10f 100644 --- a/src/commands/upload.rs +++ b/src/commands/upload.rs @@ -18,8 +18,8 @@ use std::path::Path; use crate::executor::ParallelExecutor; use crate::node::Node; -use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ssh::SshConfig; +use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ui::OutputFormatter; use crate::utils::fs::{format_bytes, resolve_source_files}; @@ -117,17 +117,17 @@ pub async fn upload_file( let remote_relative = relative_path.to_string_lossy(); // Create remote directory structure if needed - if let Some(parent) = relative_path.parent() { - if !parent.as_os_str().is_empty() { - let remote_dir = if validated_destination.ends_with('/') { - format!("{}{}", validated_destination, parent.display()) - } else { - format!("{}/{}", validated_destination, parent.display()) - }; - // Create remote directory using SSH command - let mkdir_cmd = format!("mkdir -p '{remote_dir}'"); - let _ = executor.execute(&mkdir_cmd).await; - } + if let Some(parent) = relative_path.parent() + && !parent.as_os_str().is_empty() + { + let remote_dir = if validated_destination.ends_with('/') { + format!("{}{}", validated_destination, parent.display()) + } else { + format!("{}/{}", validated_destination, parent.display()) + }; + // Create remote directory using SSH command + let mkdir_cmd = format!("mkdir -p '{remote_dir}'"); + let _ = executor.execute(&mkdir_cmd).await; } if validated_destination.ends_with('/') { diff --git a/src/config/interactive.rs b/src/config/interactive.rs index ea255db1..6b43665f 100644 --- a/src/config/interactive.rs +++ b/src/config/interactive.rs @@ -23,58 +23,57 @@ impl Config { pub fn get_interactive_config(&self, cluster_name: Option<&str>) -> InteractiveConfig { let mut config = self.interactive.clone(); - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(ref cluster_interactive) = cluster.interactive { - // Merge cluster-specific overrides with global config - // Cluster settings take precedence where specified - config.default_mode = cluster_interactive.default_mode.clone(); - - if !cluster_interactive.prompt_format.is_empty() { - config.prompt_format = cluster_interactive.prompt_format.clone(); - } - - if cluster_interactive.history_file.is_some() { - config.history_file = cluster_interactive.history_file.clone(); - } - - if cluster_interactive.work_dir.is_some() { - config.work_dir = cluster_interactive.work_dir.clone(); - } - - if cluster_interactive.broadcast_prefix.is_some() { - config.broadcast_prefix = cluster_interactive.broadcast_prefix.clone(); - } - - if cluster_interactive.node_switch_prefix.is_some() { - config.node_switch_prefix = cluster_interactive.node_switch_prefix.clone(); - } - - // Note: For booleans, we always use the cluster value since there's no "unset" state - config.show_timestamps = cluster_interactive.show_timestamps; - - // Merge colors (cluster colors override global ones) - for (k, v) in &cluster_interactive.colors { - config.colors.insert(k.clone(), v.clone()); - } - - // Merge keybindings - if !cluster_interactive.keybindings.switch_node.is_empty() { - config.keybindings.switch_node = - cluster_interactive.keybindings.switch_node.clone(); - } - if !cluster_interactive.keybindings.broadcast_toggle.is_empty() { - config.keybindings.broadcast_toggle = - cluster_interactive.keybindings.broadcast_toggle.clone(); - } - if !cluster_interactive.keybindings.quit.is_empty() { - config.keybindings.quit = cluster_interactive.keybindings.quit.clone(); - } - if cluster_interactive.keybindings.clear_screen.is_some() { - config.keybindings.clear_screen = - cluster_interactive.keybindings.clear_screen.clone(); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(ref cluster_interactive) = cluster.interactive + { + // Merge cluster-specific overrides with global config + // Cluster settings take precedence where specified + config.default_mode = cluster_interactive.default_mode.clone(); + + if !cluster_interactive.prompt_format.is_empty() { + config.prompt_format = cluster_interactive.prompt_format.clone(); + } + + if cluster_interactive.history_file.is_some() { + config.history_file = cluster_interactive.history_file.clone(); + } + + if cluster_interactive.work_dir.is_some() { + config.work_dir = cluster_interactive.work_dir.clone(); + } + + if cluster_interactive.broadcast_prefix.is_some() { + config.broadcast_prefix = cluster_interactive.broadcast_prefix.clone(); + } + + if cluster_interactive.node_switch_prefix.is_some() { + config.node_switch_prefix = cluster_interactive.node_switch_prefix.clone(); + } + + // Note: For booleans, we always use the cluster value since there's no "unset" state + config.show_timestamps = cluster_interactive.show_timestamps; + + // Merge colors (cluster colors override global ones) + for (k, v) in &cluster_interactive.colors { + config.colors.insert(k.clone(), v.clone()); + } + + // Merge keybindings + if !cluster_interactive.keybindings.switch_node.is_empty() { + config.keybindings.switch_node = + cluster_interactive.keybindings.switch_node.clone(); + } + if !cluster_interactive.keybindings.broadcast_toggle.is_empty() { + config.keybindings.broadcast_toggle = + cluster_interactive.keybindings.broadcast_toggle.clone(); + } + if !cluster_interactive.keybindings.quit.is_empty() { + config.keybindings.quit = cluster_interactive.keybindings.quit.clone(); + } + if cluster_interactive.keybindings.clear_screen.is_some() { + config.keybindings.clear_screen = + cluster_interactive.keybindings.clear_screen.clone(); } } diff --git a/src/config/resolver.rs b/src/config/resolver.rs index 55833aca..ea61ed2e 100644 --- a/src/config/resolver.rs +++ b/src/config/resolver.rs @@ -88,12 +88,11 @@ impl Config { /// Get SSH key for a cluster. pub fn get_ssh_key(&self, cluster_name: Option<&str>) -> Option { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(key) = &cluster.defaults.ssh_key { - return Some(key.clone()); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(key) = &cluster.defaults.ssh_key + { + return Some(key.clone()); } self.defaults.ssh_key.clone() @@ -101,12 +100,11 @@ impl Config { /// Get timeout for a cluster. pub fn get_timeout(&self, cluster_name: Option<&str>) -> Option { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(timeout) = cluster.defaults.timeout { - return Some(timeout); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(timeout) = cluster.defaults.timeout + { + return Some(timeout); } self.defaults.timeout @@ -114,12 +112,11 @@ impl Config { /// Get parallelism level for a cluster. pub fn get_parallel(&self, cluster_name: Option<&str>) -> Option { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(parallel) = cluster.defaults.parallel { - return Some(parallel); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(parallel) = cluster.defaults.parallel + { + return Some(parallel); } self.defaults.parallel @@ -315,12 +312,11 @@ impl Config { cluster_name: Option<&str>, ssh_config: Option<&SshConfig>, ) -> Option<(String, Option)> { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(jh) = &cluster.defaults.jump_host { - return self.process_jump_host_config(jh, ssh_config); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(jh) = &cluster.defaults.jump_host + { + return self.process_jump_host_config(jh, ssh_config); } // Fall back to global default self.defaults @@ -337,12 +333,11 @@ impl Config { /// /// Returns None if not specified (defaults will be applied at connection time). pub fn get_server_alive_interval(&self, cluster_name: Option<&str>) -> Option { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(interval) = cluster.defaults.server_alive_interval { - return Some(interval); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(interval) = cluster.defaults.server_alive_interval + { + return Some(interval); } self.defaults.server_alive_interval } @@ -355,12 +350,11 @@ impl Config { /// /// Returns None if not specified (defaults will be applied at connection time). pub fn get_server_alive_count_max(&self, cluster_name: Option<&str>) -> Option { - if let Some(cluster_name) = cluster_name { - if let Some(cluster) = self.get_cluster(cluster_name) { - if let Some(count) = cluster.defaults.server_alive_count_max { - return Some(count); - } - } + if let Some(cluster_name) = cluster_name + && let Some(cluster) = self.get_cluster(cluster_name) + && let Some(count) = cluster.defaults.server_alive_count_max + { + return Some(count); } self.defaults.server_alive_count_max } diff --git a/src/config/utils.rs b/src/config/utils.rs index 2a18d998..8b1b0bc3 100644 --- a/src/config/utils.rs +++ b/src/config/utils.rs @@ -18,12 +18,11 @@ use std::path::{Path, PathBuf}; /// Expand tilde (~) in path to home directory. pub fn expand_tilde(path: &Path) -> PathBuf { - if let Some(path_str) = path.to_str() { - if path_str.starts_with("~/") { - if let Ok(home) = std::env::var("HOME") { - return PathBuf::from(path_str.replacen("~", &home, 1)); - } - } + if let Some(path_str) = path.to_str() + && path_str.starts_with("~/") + && let Ok(home) = std::env::var("HOME") + { + return PathBuf::from(path_str.replacen("~", &home, 1)); } path.to_path_buf() } diff --git a/src/executor/connection_manager.rs b/src/executor/connection_manager.rs index 9eb02f15..900bcd77 100644 --- a/src/executor/connection_manager.rs +++ b/src/executor/connection_manager.rs @@ -21,10 +21,10 @@ use std::sync::Arc; use crate::node::Node; use crate::security::SudoPassword; use crate::ssh::{ + SshClient, SshConfig, client::{CommandResult, ConnectionConfig}, known_hosts::StrictHostKeyChecking, tokio_client::SshConnectionConfig, - SshClient, SshConfig, }; /// Configuration for node execution. diff --git a/src/executor/execution_strategy.rs b/src/executor/execution_strategy.rs index 41f06b13..0c666e0e 100644 --- a/src/executor/execution_strategy.rs +++ b/src/executor/execution_strategy.rs @@ -24,7 +24,7 @@ use tokio::sync::Semaphore; use crate::node::Node; use super::connection_manager::{ - download_from_node, execute_on_node_with_jump_hosts, upload_to_node, ExecutionConfig, + ExecutionConfig, download_from_node, execute_on_node_with_jump_hosts, upload_to_node, }; use super::result_types::{DownloadResult, ExecutionResult, UploadResult}; diff --git a/src/executor/mod.rs b/src/executor/mod.rs index 164b52f7..e52803d2 100644 --- a/src/executor/mod.rs +++ b/src/executor/mod.rs @@ -28,7 +28,7 @@ pub mod rank_detector; // Re-export public types pub use connection_manager::download_dir_from_node; pub use exit_strategy::ExitCodeStrategy; -pub use output_mode::{is_tty, should_use_colors, OutputMode}; +pub use output_mode::{OutputMode, is_tty, should_use_colors}; pub use parallel::ParallelExecutor; pub use rank_detector::RankDetector; pub use result_types::{DownloadResult, ExecutionResult, UploadResult}; diff --git a/src/executor/output_mode.rs b/src/executor/output_mode.rs index 6f010651..90a113f1 100644 --- a/src/executor/output_mode.rs +++ b/src/executor/output_mode.rs @@ -198,10 +198,10 @@ pub fn should_use_colors() -> bool { } // Check TERM - if let Ok(term) = std::env::var("TERM") { - if term == "dumb" { - return false; - } + if let Ok(term) = std::env::var("TERM") + && term == "dumb" + { + return false; } true diff --git a/src/executor/parallel.rs b/src/executor/parallel.rs index 07eea415..6193b775 100644 --- a/src/executor/parallel.rs +++ b/src/executor/parallel.rs @@ -23,11 +23,11 @@ use tokio::sync::Semaphore; use crate::node::Node; use crate::security::SudoPassword; +use crate::ssh::SshConfig; use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ssh::tokio_client::SshConnectionConfig; -use crate::ssh::SshConfig; -use super::connection_manager::{download_from_node, ExecutionConfig}; +use super::connection_manager::{ExecutionConfig, download_from_node}; use super::execution_strategy::{ create_progress_style, download_file_task, execute_command_task, setup_download_progress_bar, setup_progress_bar, upload_file_task, @@ -1106,8 +1106,8 @@ impl ParallelExecutor { } use super::stream_manager::MultiNodeStreamManager; - use crate::ssh::client::ConnectionConfig; use crate::ssh::SshClient; + use crate::ssh::client::ConnectionConfig; use tokio::sync::mpsc; let semaphore = Arc::new(Semaphore::new(self.max_parallel)); @@ -1230,7 +1230,8 @@ impl ParallelExecutor { // Execute based on mode and ensure cleanup let no_prefix = output_mode.is_no_prefix(); - let result = if output_mode.is_tui() { + + if output_mode.is_tui() { // TUI mode: interactive terminal UI self.handle_tui_mode(&mut manager, handles, command).await } else if output_mode.is_stream() { @@ -1244,9 +1245,7 @@ impl ParallelExecutor { } else { // Fallback to normal mode self.execute(command).await - }; - - result + } } /// Handle stream mode output with optional [node] prefixes diff --git a/src/forwarding/dynamic/forwarder.rs b/src/forwarding/dynamic/forwarder.rs index a0ebab9b..81df23b0 100644 --- a/src/forwarding/dynamic/forwarder.rs +++ b/src/forwarding/dynamic/forwarder.rs @@ -13,7 +13,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::net::TcpListener; -use tokio::sync::{mpsc, Semaphore}; +use tokio::sync::{Semaphore, mpsc}; use tokio_util::sync::CancellationToken; use tracing::{error, info, warn}; use uuid::Uuid; @@ -54,7 +54,7 @@ impl DynamicForwarder { _ => { return Err(anyhow::anyhow!( "Invalid forwarding type for DynamicForwarder" - )) + )); } }; diff --git a/src/forwarding/local.rs b/src/forwarding/local.rs index 2038fdb6..4c6c4225 100644 --- a/src/forwarding/local.rs +++ b/src/forwarding/local.rs @@ -36,11 +36,11 @@ use super::{ use crate::ssh::tokio_client::Client; use anyhow::{Context, Result}; use std::net::SocketAddr; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{mpsc, Semaphore}; +use tokio::sync::{Semaphore, mpsc}; use tokio::time::sleep; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn}; @@ -101,7 +101,7 @@ impl LocalForwarder { _ => { return Err(anyhow::anyhow!( "Invalid forwarding type for LocalForwarder" - )) + )); } }; @@ -535,9 +535,11 @@ mod tests { ); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Invalid forwarding type")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid forwarding type") + ); } } diff --git a/src/forwarding/manager.rs b/src/forwarding/manager.rs index 58ce82e6..9ae27d74 100644 --- a/src/forwarding/manager.rs +++ b/src/forwarding/manager.rs @@ -56,7 +56,7 @@ use anyhow::{Context, Result}; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock, mpsc}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use uuid::Uuid; diff --git a/src/forwarding/remote.rs b/src/forwarding/remote.rs index bec5ccaf..54ca6763 100644 --- a/src/forwarding/remote.rs +++ b/src/forwarding/remote.rs @@ -34,8 +34,8 @@ use super::{ use crate::ssh::tokio_client::Client; use anyhow::Result; use std::net::SocketAddr; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use tokio::sync::mpsc; use tokio::time::sleep; @@ -99,7 +99,7 @@ impl RemoteForwarder { _ => { return Err(anyhow::anyhow!( "Invalid forwarding type for RemoteForwarder" - )) + )); } }; @@ -534,9 +534,11 @@ mod tests { ); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Invalid forwarding type")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid forwarding type") + ); } } diff --git a/src/forwarding/spec.rs b/src/forwarding/spec.rs index 2ded9655..bd642850 100644 --- a/src/forwarding/spec.rs +++ b/src/forwarding/spec.rs @@ -21,7 +21,7 @@ //! let spec = ForwardingSpec::parse_dynamic("1080").unwrap(); //! ``` -use super::{parse_bind_spec, ForwardingType, SocksVersion}; +use super::{ForwardingType, SocksVersion, parse_bind_spec}; use anyhow::{Context, Result}; use std::net::{IpAddr, Ipv4Addr}; @@ -43,10 +43,12 @@ impl ForwardingSpec { match parts.len() { 3 => { // Format: port:host:hostport - let bind_port = parts[0].parse::() + let bind_port = parts[0] + .parse::() .with_context(|| format!("Invalid local port: {}", parts[0]))?; let remote_host = parts[1].to_string(); - let remote_port = parts[2].parse::() + let remote_port = parts[2] + .parse::() .with_context(|| format!("Invalid remote port: {}", parts[2]))?; Ok(ForwardingType::Local { @@ -61,7 +63,8 @@ impl ForwardingSpec { let bind_spec = format!("{}:{}", parts[0], parts[1]); let bind_addr = parse_bind_spec(&bind_spec)?; let remote_host = parts[2].to_string(); - let remote_port = parts[3].parse::() + let remote_port = parts[3] + .parse::() .with_context(|| format!("Invalid remote port: {}", parts[3]))?; Ok(ForwardingType::Local { @@ -90,10 +93,12 @@ impl ForwardingSpec { match parts.len() { 3 => { // Format: port:host:hostport - let bind_port = parts[0].parse::() + let bind_port = parts[0] + .parse::() .with_context(|| format!("Invalid remote port: {}", parts[0]))?; let local_host = parts[1].to_string(); - let local_port = parts[2].parse::() + let local_port = parts[2] + .parse::() .with_context(|| format!("Invalid local port: {}", parts[2]))?; Ok(ForwardingType::Remote { @@ -108,7 +113,8 @@ impl ForwardingSpec { let bind_spec = format!("{}:{}", parts[0], parts[1]); let bind_addr = parse_bind_spec(&bind_spec)?; let local_host = parts[2].to_string(); - let local_port = parts[3].parse::() + let local_port = parts[3] + .parse::() .with_context(|| format!("Invalid local port: {}", parts[3]))?; Ok(ForwardingType::Remote { diff --git a/src/forwarding/tunnel.rs b/src/forwarding/tunnel.rs index 4a519ee3..e6a2d10a 100644 --- a/src/forwarding/tunnel.rs +++ b/src/forwarding/tunnel.rs @@ -25,8 +25,8 @@ use crate::utils::buffer_pool::global; use anyhow::Result; use russh::Channel; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Instant; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; diff --git a/src/hostlist/expander.rs b/src/hostlist/expander.rs index ffc0c6fe..46022d98 100644 --- a/src/hostlist/expander.rs +++ b/src/hostlist/expander.rs @@ -18,7 +18,7 @@ //! using cartesian product for multiple range expressions. use super::error::HostlistError; -use super::parser::{parse_host_pattern, PatternSegment}; +use super::parser::{PatternSegment, parse_host_pattern}; /// Maximum number of hosts that can be generated from a single pattern const MAX_EXPANSION_SIZE: usize = 100_000; diff --git a/src/hostlist/mod.rs b/src/hostlist/mod.rs index c530bd61..3a4f298b 100644 --- a/src/hostlist/mod.rs +++ b/src/hostlist/mod.rs @@ -52,7 +52,7 @@ mod parser; pub use error::HostlistError; pub use expander::{expand_host_spec, expand_host_specs, expand_hostlist}; -pub use parser::{parse_host_pattern, parse_hostfile, HostPattern}; +pub use parser::{HostPattern, parse_host_pattern, parse_hostfile}; /// Check if a pattern is a hostlist expression (contains numeric range brackets) /// diff --git a/src/hostlist/parser.rs b/src/hostlist/parser.rs index 364acce9..f2e0554a 100644 --- a/src/hostlist/parser.rs +++ b/src/hostlist/parser.rs @@ -163,12 +163,12 @@ pub fn parse_host_pattern(pattern: &str) -> Result { } // Check for IPv6 literal (starts with digit or colon after [) - if let Some(&next_ch) = chars.peek() { - if is_ipv6_start(next_ch, &chars) { - // This might be an IPv6 literal, collect until matching ] - current_literal.push(ch); - continue; - } + if let Some(&next_ch) = chars.peek() + && is_ipv6_start(next_ch, &chars) + { + // This might be an IPv6 literal, collect until matching ] + current_literal.push(ch); + continue; } // Save any accumulated literal diff --git a/src/jump/chain.rs b/src/jump/chain.rs index 200d7b13..4d960b81 100644 --- a/src/jump/chain.rs +++ b/src/jump/chain.rs @@ -22,7 +22,7 @@ mod types; pub use types::{JumpConnection, JumpInfo}; use super::connection::JumpHostConnection; -use super::parser::{get_max_jump_hosts, JumpHost}; +use super::parser::{JumpHost, get_max_jump_hosts}; use super::rate_limiter::ConnectionRateLimiter; use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ssh::tokio_client::{AuthMethod, SshConnectionConfig}; diff --git a/src/jump/chain/auth.rs b/src/jump/chain/auth.rs index 66bc5780..4cfbf97d 100644 --- a/src/jump/chain/auth.rs +++ b/src/jump/chain/auth.rs @@ -419,11 +419,11 @@ pub(super) async fn authenticate_connection( ) .await; - if let Ok(auth_result) = result { - if auth_result.success() { - auth_success = true; - break; - } + if let Ok(auth_result) = result + && auth_result.success() + { + auth_success = true; + break; } } diff --git a/src/jump/mod.rs b/src/jump/mod.rs index b411ebf8..26d16331 100644 --- a/src/jump/mod.rs +++ b/src/jump/mod.rs @@ -42,4 +42,4 @@ pub mod rate_limiter; pub use chain::{JumpConnection, JumpHostChain}; pub use connection::JumpHostConnection; -pub use parser::{parse_jump_hosts, JumpHost}; +pub use parser::{JumpHost, parse_jump_hosts}; diff --git a/src/jump/parser/mod.rs b/src/jump/parser/mod.rs index 725e3351..8b3971c4 100644 --- a/src/jump/parser/mod.rs +++ b/src/jump/parser/mod.rs @@ -19,7 +19,7 @@ mod host; mod host_parser; mod main_parser; -pub use config::{get_max_jump_hosts, ABSOLUTE_MAX_JUMP_HOSTS, DEFAULT_MAX_JUMP_HOSTS}; +pub use config::{ABSOLUTE_MAX_JUMP_HOSTS, DEFAULT_MAX_JUMP_HOSTS, get_max_jump_hosts}; pub use host::JumpHost; pub use main_parser::parse_jump_hosts; diff --git a/src/keygen/ed25519.rs b/src/keygen/ed25519.rs index 93bc063a..c9859cf2 100644 --- a/src/keygen/ed25519.rs +++ b/src/keygen/ed25519.rs @@ -128,12 +128,14 @@ mod tests { let key = result.unwrap(); // Verify private key format - assert!(key - .private_key_pem - .contains("-----BEGIN OPENSSH PRIVATE KEY-----")); - assert!(key - .private_key_pem - .contains("-----END OPENSSH PRIVATE KEY-----")); + assert!( + key.private_key_pem + .contains("-----BEGIN OPENSSH PRIVATE KEY-----") + ); + assert!( + key.private_key_pem + .contains("-----END OPENSSH PRIVATE KEY-----") + ); // Verify public key format assert!(key.public_key_openssh.starts_with("ssh-ed25519 ")); diff --git a/src/keygen/mod.rs b/src/keygen/mod.rs index aac6b194..19e01f76 100644 --- a/src/keygen/mod.rs +++ b/src/keygen/mod.rs @@ -110,9 +110,10 @@ mod tests { assert!(result.is_ok()); let key = result.unwrap(); - assert!(key - .private_key_pem - .contains("-----BEGIN OPENSSH PRIVATE KEY-----")); + assert!( + key.private_key_pem + .contains("-----BEGIN OPENSSH PRIVATE KEY-----") + ); assert!(key.public_key_openssh.starts_with("ssh-ed25519 ")); assert!(key.public_key_openssh.contains("test@example.com")); assert!(key.fingerprint.starts_with("SHA256:")); @@ -133,9 +134,10 @@ mod tests { assert!(result.is_ok()); let key = result.unwrap(); - assert!(key - .private_key_pem - .contains("-----BEGIN OPENSSH PRIVATE KEY-----")); + assert!( + key.private_key_pem + .contains("-----BEGIN OPENSSH PRIVATE KEY-----") + ); assert!(key.public_key_openssh.starts_with("ssh-rsa ")); assert!(key.public_key_openssh.contains("test@example.com")); assert!(key.fingerprint.starts_with("SHA256:")); diff --git a/src/keygen/rsa.rs b/src/keygen/rsa.rs index 42f1fc07..55897d1a 100644 --- a/src/keygen/rsa.rs +++ b/src/keygen/rsa.rs @@ -24,7 +24,7 @@ //! RSA key generation is provided for compatibility with legacy systems. use super::GeneratedKey; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result, bail}; use russh::keys::{Algorithm, HashAlg, PrivateKey}; use ssh_key::LineEnding; use std::io::Write; @@ -167,12 +167,14 @@ mod tests { let key = result.unwrap(); // Verify private key format - assert!(key - .private_key_pem - .contains("-----BEGIN OPENSSH PRIVATE KEY-----")); - assert!(key - .private_key_pem - .contains("-----END OPENSSH PRIVATE KEY-----")); + assert!( + key.private_key_pem + .contains("-----BEGIN OPENSSH PRIVATE KEY-----") + ); + assert!( + key.private_key_pem + .contains("-----END OPENSSH PRIVATE KEY-----") + ); // Verify public key format assert!(key.public_key_openssh.starts_with("ssh-rsa ")); diff --git a/src/main.rs b/src/main.rs index 59804814..34e56fff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,7 +14,7 @@ use anyhow::Result; use bssh::cli::{ - has_pdsh_compat_flag, is_pdsh_compat_mode, remove_pdsh_compat_flag, Cli, Commands, PdshCli, + Cli, Commands, PdshCli, has_pdsh_compat_flag, is_pdsh_compat_mode, remove_pdsh_compat_flag, }; use bssh::hostlist; use clap::Parser; @@ -103,56 +103,57 @@ async fn handle_pdsh_query_mode(pdsh_cli: &PdshCli) -> Result<()> { .map_err(|e| anyhow::anyhow!("Failed to expand host expression: {e}"))?; // Process exclusion patterns (supports both glob patterns and hostlist expressions) - let (expanded_exclusions, glob_exclusions): (Vec, Vec) = - if let Some(ref exclude_str) = pdsh_cli.exclude { - let mut expanded = Vec::new(); - let mut globs = Vec::new(); - - for pattern in exclude_str.split(',').map(|s| s.trim()) { - // Security: Validate pattern length - const MAX_PATTERN_LENGTH: usize = 256; - if pattern.len() > MAX_PATTERN_LENGTH { - anyhow::bail!( - "Exclusion pattern too long (max {MAX_PATTERN_LENGTH} characters)" - ); - } + let (expanded_exclusions, glob_exclusions): (Vec, Vec) = if let Some( + ref exclude_str, + ) = + pdsh_cli.exclude + { + let mut expanded = Vec::new(); + let mut globs = Vec::new(); + + for pattern in exclude_str.split(',').map(|s| s.trim()) { + // Security: Validate pattern length + const MAX_PATTERN_LENGTH: usize = 256; + if pattern.len() > MAX_PATTERN_LENGTH { + anyhow::bail!( + "Exclusion pattern too long (max {MAX_PATTERN_LENGTH} characters)" + ); + } - // Security: Skip empty patterns - if pattern.is_empty() { - continue; - } + // Security: Skip empty patterns + if pattern.is_empty() { + continue; + } - // Check if it's a hostlist expression (contains numeric range brackets) - if hostlist::is_hostlist_expression(pattern) { - // Expand hostlist expression - let expanded_hosts = hostlist::expand_host_specs(pattern).map_err(|e| { - anyhow::anyhow!("Failed to expand exclusion pattern: {e}") - })?; - expanded.extend(expanded_hosts); - } else { - // Security: Prevent excessive wildcards for glob patterns - let wildcard_count = - pattern.chars().filter(|c| *c == '*' || *c == '?').count(); - const MAX_WILDCARDS: usize = 10; - if wildcard_count > MAX_WILDCARDS { - anyhow::bail!( + // Check if it's a hostlist expression (contains numeric range brackets) + if hostlist::is_hostlist_expression(pattern) { + // Expand hostlist expression + let expanded_hosts = hostlist::expand_host_specs(pattern) + .map_err(|e| anyhow::anyhow!("Failed to expand exclusion pattern: {e}"))?; + expanded.extend(expanded_hosts); + } else { + // Security: Prevent excessive wildcards for glob patterns + let wildcard_count = pattern.chars().filter(|c| *c == '*' || *c == '?').count(); + const MAX_WILDCARDS: usize = 10; + if wildcard_count > MAX_WILDCARDS { + anyhow::bail!( "Exclusion pattern contains too many wildcards (max {MAX_WILDCARDS})" ); - } + } - // Compile the glob pattern - match Pattern::new(pattern) { - Ok(p) => globs.push(p), - Err(_) => { - anyhow::bail!("Invalid exclusion pattern: {pattern}"); - } + // Compile the glob pattern + match Pattern::new(pattern) { + Ok(p) => globs.push(p), + Err(_) => { + anyhow::bail!("Invalid exclusion pattern: {pattern}"); } } } - (expanded, globs) - } else { - (Vec::new(), Vec::new()) - }; + } + (expanded, globs) + } else { + (Vec::new(), Vec::new()) + }; // Create a set for O(1) lookup of expanded exclusions let exclusion_set: std::collections::HashSet<&str> = diff --git a/src/pty/mod.rs b/src/pty/mod.rs index f13080f9..0561d525 100644 --- a/src/pty/mod.rs +++ b/src/pty/mod.rs @@ -19,10 +19,10 @@ //! and special keys. use anyhow::{Context, Result}; -use russh::{client::Msg, Channel}; +use russh::{Channel, client::Msg}; use signal_hook::{consts::SIGWINCH, iterator::Signals}; use smallvec::SmallVec; -use terminal_size::{terminal_size, Height, Width}; +use terminal_size::{Height, Width, terminal_size}; use tokio::sync::{mpsc, watch}; use tokio::time::Duration; @@ -30,7 +30,7 @@ pub mod session; pub mod terminal; pub use session::PtySession; -pub use terminal::{force_terminal_cleanup, TerminalState, TerminalStateGuard}; +pub use terminal::{TerminalState, TerminalStateGuard, force_terminal_cleanup}; /// Session processing interval for multiplex mode /// - 100ms provides reasonable time-slicing for multiplex mode diff --git a/src/pty/session/escape_filter.rs b/src/pty/session/escape_filter.rs index d6fdfe28..d6a9d058 100644 --- a/src/pty/session/escape_filter.rs +++ b/src/pty/session/escape_filter.rs @@ -134,20 +134,19 @@ impl EscapeSequenceFilter { let mut output = Vec::with_capacity(data.len()); // Check for timed-out incomplete sequences at the start of each filter call - if self.state != FilterState::Normal { - if let Some(start) = self.sequence_start { - if start.elapsed() > SEQUENCE_TIMEOUT { - tracing::trace!( - "Flushing timed-out escape sequence ({:?}): {:?}", - start.elapsed(), - String::from_utf8_lossy(&self.pending_buffer) - ); - output.extend_from_slice(&self.pending_buffer); - self.pending_buffer.clear(); - self.state = FilterState::Normal; - self.sequence_start = None; - } - } + if self.state != FilterState::Normal + && let Some(start) = self.sequence_start + && start.elapsed() > SEQUENCE_TIMEOUT + { + tracing::trace!( + "Flushing timed-out escape sequence ({:?}): {:?}", + start.elapsed(), + String::from_utf8_lossy(&self.pending_buffer) + ); + output.extend_from_slice(&self.pending_buffer); + self.pending_buffer.clear(); + self.state = FilterState::Normal; + self.sequence_start = None; } let mut i = 0; @@ -489,11 +488,7 @@ impl EscapeSequenceFilter { } // Return Some only if we parsed at least one digit - if idx > start { - Some(value) - } else { - None - } + if idx > start { Some(value) } else { None } } /// Reset the filter state. @@ -508,10 +503,10 @@ impl EscapeSequenceFilter { /// Returns true if there's an incomplete sequence that has timed out. #[allow(dead_code)] pub fn has_timed_out_sequence(&self) -> bool { - if self.state != FilterState::Normal { - if let Some(start) = self.sequence_start { - return start.elapsed() > SEQUENCE_TIMEOUT; - } + if self.state != FilterState::Normal + && let Some(start) = self.sequence_start + { + return start.elapsed() > SEQUENCE_TIMEOUT; } false } @@ -657,7 +652,7 @@ mod tests { // Create a malformed CSI sequence that exceeds MAX_CSI_SEQUENCE_SIZE (256 bytes) // without a proper terminator (no alphabetic character or ~) let mut malformed = vec![0x1b, b'[']; // ESC [ - // Add enough non-terminating bytes to exceed the limit + // Add enough non-terminating bytes to exceed the limit malformed.extend(std::iter::repeat_n(b';', 300)); // Keep adding parameter separators malformed.push(b'X'); // Finally add a terminator @@ -677,7 +672,7 @@ mod tests { // Create a DCS sequence that exceeds MAX_PENDING_SIZE (4096 bytes) // DCS sequences don't have the early termination, only the global limit applies let mut large_dcs = vec![0x1b, b'P']; // ESC P (DCS start) - // Add enough bytes to exceed the 4096 byte limit + // Add enough bytes to exceed the 4096 byte limit for i in 0..5000 { large_dcs.push(b'A' + (i % 26) as u8); } @@ -702,7 +697,7 @@ mod tests { let mut filter = EscapeSequenceFilter::new(); // Create a malformed CSI ? sequence that exceeds MAX_CSI_SEQUENCE_SIZE let mut malformed = vec![0x1b, b'[', b'?']; // ESC [ ? - // Add enough non-terminating bytes + // Add enough non-terminating bytes malformed.extend(std::iter::repeat_n(b'0', 300)); // Keep adding digits malformed.push(b'h'); // Finally add a terminator diff --git a/src/pty/session/raw_input.rs b/src/pty/session/raw_input.rs index 925167ec..3705a3c0 100644 --- a/src/pty/session/raw_input.rs +++ b/src/pty/session/raw_input.rs @@ -92,7 +92,7 @@ impl RawInputReader { /// } /// ``` pub fn poll(&self, timeout: Duration) -> io::Result { - use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; + use nix::poll::{PollFd, PollFlags, PollTimeout, poll}; use std::os::unix::io::BorrowedFd; let fd = self.stdin.as_raw_fd(); diff --git a/src/pty/session/session_manager.rs b/src/pty/session/session_manager.rs index 2b59a9a8..e1b5de34 100644 --- a/src/pty/session/session_manager.rs +++ b/src/pty/session/session_manager.rs @@ -20,11 +20,11 @@ use super::local_escape::{LocalAction, LocalEscapeDetector}; use super::raw_input::RawInputReader; use super::terminal_modes::configure_terminal_modes; use crate::pty::{ - terminal::{TerminalOps, TerminalStateGuard}, PtyConfig, PtyMessage, PtyState, + terminal::{TerminalOps, TerminalStateGuard}, }; use anyhow::{Context, Result}; -use russh::{client::Msg, Channel, ChannelMsg}; +use russh::{Channel, ChannelMsg, client::Msg}; use std::io::{self, Write}; use tokio::sync::{mpsc, watch}; use tokio::time::Duration; @@ -194,15 +194,14 @@ impl PtySession { } signal_hook::consts::SIGWINCH // fallback, won't be reached } => { - if signal == signal_hook::consts::SIGWINCH { - if let Ok((width, height)) = crate::pty::utils::get_terminal_size() { + if signal == signal_hook::consts::SIGWINCH + && let Ok((width, height)) = crate::pty::utils::get_terminal_size() { // Try to send resize message, but don't block if channel is full if resize_tx.try_send(PtyMessage::Resize { width, height }).is_err() { // Channel full or closed, exit gracefully break; } } - } } // Handle cancellation diff --git a/src/pty/terminal.rs b/src/pty/terminal.rs index ba13f6a5..db94d853 100644 --- a/src/pty/terminal.rs +++ b/src/pty/terminal.rs @@ -22,8 +22,8 @@ use crossterm::{ }; use once_cell::sync::Lazy; use std::sync::{ - atomic::{AtomicBool, Ordering}, Arc, Mutex, + atomic::{AtomicBool, Ordering}, }; /// Global terminal cleanup synchronization diff --git a/src/security/mod.rs b/src/security/mod.rs index 7adf3f24..6cb140bc 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -35,6 +35,6 @@ pub use crate::shared::validation::{ // Re-export sudo password handling pub use sudo::{ - contains_sudo_failure, contains_sudo_prompt, get_sudo_password, prompt_sudo_password, - SudoPassword, SUDO_FAILURE_PATTERNS, SUDO_PROMPT_PATTERNS, + SUDO_FAILURE_PATTERNS, SUDO_PROMPT_PATTERNS, SudoPassword, contains_sudo_failure, + contains_sudo_prompt, get_sudo_password, prompt_sudo_password, }; diff --git a/src/server/audit/logstash.rs b/src/server/audit/logstash.rs index f8214b5c..525d4bad 100644 --- a/src/server/audit/logstash.rs +++ b/src/server/audit/logstash.rs @@ -64,9 +64,9 @@ use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; use tokio::sync::Mutex; +use tokio_rustls::TlsConnector; use tokio_rustls::rustls::pki_types::ServerName; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; -use tokio_rustls::TlsConnector; /// Represents a connection to the Logstash server, either plain TCP or TLS-encrypted. enum Connection { diff --git a/src/server/audit/otel.rs b/src/server/audit/otel.rs index ecbc15d9..2e6a951a 100644 --- a/src/server/audit/otel.rs +++ b/src/server/audit/otel.rs @@ -23,11 +23,11 @@ use super::exporter::AuditExporter; use anyhow::{Context, Result}; use async_trait::async_trait; use opentelemetry::{ - logs::{AnyValue, LogRecord as _, Logger, LoggerProvider as _, Severity}, KeyValue, + logs::{AnyValue, LogRecord as _, Logger, LoggerProvider as _, Severity}, }; use opentelemetry_otlp::{LogExporter, WithExportConfig}; -use opentelemetry_sdk::{logs::SdkLoggerProvider, Resource}; +use opentelemetry_sdk::{Resource, logs::SdkLoggerProvider}; use std::sync::Arc; use tokio::sync::RwLock; use tokio::sync::RwLockReadGuard; diff --git a/src/server/auth/composite.rs b/src/server/auth/composite.rs index 01b6f109..2e09a964 100644 --- a/src/server/auth/composite.rs +++ b/src/server/auth/composite.rs @@ -160,10 +160,10 @@ impl AuthProvider for CompositeAuthProvider { async fn get_user_info(&self, username: &str) -> Result> { // Try to get user info from password verifier first (has more detailed info) - if let Some(ref verifier) = self.password_verifier { - if let Some(info) = verifier.get_user_info(username).await? { - return Ok(Some(info)); - } + if let Some(ref verifier) = self.password_verifier + && let Some(info) = verifier.get_user_info(username).await? + { + return Ok(Some(info)); } // Fall back to public key verifier @@ -176,17 +176,17 @@ impl AuthProvider for CompositeAuthProvider { async fn user_exists(&self, username: &str) -> Result { // Check password verifier first - if let Some(ref verifier) = self.password_verifier { - if verifier.user_exists(username).await? { - return Ok(true); - } + if let Some(ref verifier) = self.password_verifier + && verifier.user_exists(username).await? + { + return Ok(true); } // Check public key verifier - if let Some(ref verifier) = self.publickey_verifier { - if verifier.user_exists(username).await? { - return Ok(true); - } + if let Some(ref verifier) = self.publickey_verifier + && verifier.user_exists(username).await? + { + return Ok(true); } Ok(false) diff --git a/src/server/auth/mod.rs b/src/server/auth/mod.rs index 511d86e9..285e9e93 100644 --- a/src/server/auth/mod.rs +++ b/src/server/auth/mod.rs @@ -103,7 +103,7 @@ pub mod provider; pub mod publickey; pub use composite::CompositeAuthProvider; -pub use password::{hash_password, verify_password_hash, PasswordAuthConfig, PasswordVerifier}; +pub use password::{PasswordAuthConfig, PasswordVerifier, hash_password, verify_password_hash}; pub use provider::AuthProvider; pub use publickey::{AuthKeyOptions, AuthorizedKey, PublicKeyAuthConfig, PublicKeyVerifier}; diff --git a/src/server/auth/password.rs b/src/server/auth/password.rs index 4b78a020..71bff4b8 100644 --- a/src/server/auth/password.rs +++ b/src/server/auth/password.rs @@ -55,8 +55,8 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use argon2::{ - password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier as _}, Algorithm, Argon2, Params, Version, + password_hash::{PasswordHash, PasswordHasher, PasswordVerifier as _, rand_core::OsRng}, }; use async_trait::async_trait; use russh::keys::ssh_key::PublicKey; @@ -576,10 +576,12 @@ mod tests { let verifier = PasswordVerifier::new(config).await.unwrap(); // Correct password should verify - assert!(verifier - .verify("testuser", "correct_password") - .await - .unwrap()); + assert!( + verifier + .verify("testuser", "correct_password") + .await + .unwrap() + ); // Incorrect password should not verify assert!(!verifier.verify("testuser", "wrong_password").await.unwrap()); @@ -604,10 +606,12 @@ mod tests { let verifier = PasswordVerifier::new(config).await.unwrap(); // bcrypt password should verify - assert!(verifier - .verify("bcryptuser", "bcrypt_password") - .await - .unwrap()); + assert!( + verifier + .verify("bcryptuser", "bcrypt_password") + .await + .unwrap() + ); // Wrong password should not verify assert!(!verifier.verify("bcryptuser", "wrong").await.unwrap()); diff --git a/src/server/auth/publickey.rs b/src/server/auth/publickey.rs index 43e9ac5d..ea3d3724 100644 --- a/src/server/auth/publickey.rs +++ b/src/server/auth/publickey.rs @@ -397,40 +397,40 @@ impl PublicKeyVerifier { } // SECURITY: Validate parent directory permissions - if let Some(parent) = path.parent() { - if let Ok(parent_metadata) = std::fs::symlink_metadata(parent) { - let parent_mode = parent_metadata.mode(); - - // Parent directory should not be world-writable or group-writable - if parent_mode & 0o002 != 0 { - anyhow::bail!( - "Parent directory {} of authorized_keys is world-writable (mode {:o})", - parent.display(), - parent_mode & 0o777 - ); - } - - if parent_mode & 0o020 != 0 { - tracing::warn!( - "Parent directory {} of authorized_keys is group-writable (mode {:o}). This is a potential security risk.", - parent.display(), - parent_mode & 0o777 - ); - } + if let Some(parent) = path.parent() + && let Ok(parent_metadata) = std::fs::symlink_metadata(parent) + { + let parent_mode = parent_metadata.mode(); + + // Parent directory should not be world-writable or group-writable + if parent_mode & 0o002 != 0 { + anyhow::bail!( + "Parent directory {} of authorized_keys is world-writable (mode {:o})", + parent.display(), + parent_mode & 0o777 + ); + } - // Check ownership - parent directory should be owned by same user as file - let file_uid = metadata.uid(); - let parent_uid = parent_metadata.uid(); + if parent_mode & 0o020 != 0 { + tracing::warn!( + "Parent directory {} of authorized_keys is group-writable (mode {:o}). This is a potential security risk.", + parent.display(), + parent_mode & 0o777 + ); + } - if file_uid != parent_uid { - tracing::warn!( - "authorized_keys file {} (uid: {}) and parent directory {} (uid: {}) have different owners", - path.display(), - file_uid, - parent.display(), - parent_uid - ); - } + // Check ownership - parent directory should be owned by same user as file + let file_uid = metadata.uid(); + let parent_uid = parent_metadata.uid(); + + if file_uid != parent_uid { + tracing::warn!( + "authorized_keys file {} (uid: {}) and parent directory {} (uid: {}) have different owners", + path.display(), + file_uid, + parent.display(), + parent_uid + ); } } @@ -744,8 +744,7 @@ mod tests { let verifier = PublicKeyVerifier::new(PublicKeyAuthConfig::default()); // Valid ed25519 key - let line = - "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl test@example"; + let line = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl test@example"; let result = verifier.parse_authorized_key_line(line); assert!(result.is_ok()); let key = result.unwrap(); diff --git a/src/server/config/loader.rs b/src/server/config/loader.rs index ca5e6f64..03421a4f 100644 --- a/src/server/config/loader.rs +++ b/src/server/config/loader.rs @@ -511,10 +511,12 @@ auth: let result = validate_config(&config); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("At least one host key")); + assert!( + result + .unwrap_err() + .to_string() + .contains("At least one host key") + ); } #[test] @@ -529,10 +531,12 @@ auth: let result = validate_config(&config); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("authentication method")); + assert!( + result + .unwrap_err() + .to_string() + .contains("authentication method") + ); } #[test] @@ -596,10 +600,12 @@ auth: let result = validate_config(&config); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("max_connections must be greater than 0")); + assert!( + result + .unwrap_err() + .to_string() + .contains("max_connections must be greater than 0") + ); } #[test] @@ -614,10 +620,12 @@ auth: let result = validate_config(&config); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Invalid bind_address")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Invalid bind_address") + ); } #[test] @@ -706,10 +714,12 @@ auth: let result = validate_config(&config); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Default shell does not exist")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Default shell does not exist") + ); } #[cfg(unix)] diff --git a/src/server/exec.rs b/src/server/exec.rs index 4d771424..4236e3ea 100644 --- a/src/server/exec.rs +++ b/src/server/exec.rs @@ -55,8 +55,8 @@ use std::time::Duration; use anyhow::{Context, Result}; use regex::Regex; -use russh::server::Handle; use russh::ChannelId; +use russh::server::Handle; use serde::{Deserialize, Serialize}; use tokio::io::AsyncReadExt; use tokio::process::Command; @@ -517,10 +517,10 @@ impl CommandExecutor { ]; for (pattern, description) in &dangerous_patterns { - if let Ok(re) = Regex::new(pattern) { - if re.is_match(command) { - anyhow::bail!("Command contains dangerous pattern ({})", description); - } + if let Ok(re) = Regex::new(pattern) + && re.is_match(command) + { + anyhow::bail!("Command contains dangerous pattern ({})", description); } } @@ -534,10 +534,10 @@ impl CommandExecutor { } // Also check if the first word matches (for command names) - if let Some(first_word) = normalized.split_whitespace().next() { - if first_word == blocked_normalized { - anyhow::bail!("Command '{first_word}' is blocked"); - } + if let Some(first_word) = normalized.split_whitespace().next() + && first_word == blocked_normalized + { + anyhow::bail!("Command '{first_word}' is blocked"); } } @@ -602,9 +602,11 @@ mod tests { assert_eq!(config.working_dir, Some(PathBuf::from("/tmp"))); assert_eq!(config.env.get("LANG"), Some(&"en_US.UTF-8".to_string())); assert!(config.allowed_commands.is_some()); - assert!(config - .blocked_commands - .contains(&"dangerous_cmd".to_string())); + assert!( + config + .blocked_commands + .contains(&"dangerous_cmd".to_string()) + ); } #[test] @@ -628,9 +630,11 @@ mod tests { assert!(executor.validate_command("rm -rf /").is_err()); assert!(executor.validate_command("rm -fr /home").is_err()); assert!(executor.validate_command("sudo mkfs /dev/sda").is_err()); - assert!(executor - .validate_command("dd if=/dev/zero of=/dev/sda") - .is_err()); + assert!( + executor + .validate_command("dd if=/dev/zero of=/dev/sda") + .is_err() + ); // Test command chaining attempts assert!(executor.validate_command("ls; rm -rf /").is_err()); @@ -661,18 +665,24 @@ mod tests { // Test disallowed commands assert!(executor.validate_command("rm -rf /").is_err()); - assert!(executor - .validate_command("wget http://example.com") - .is_err()); - assert!(executor - .validate_command("curl http://example.com") - .is_err()); + assert!( + executor + .validate_command("wget http://example.com") + .is_err() + ); + assert!( + executor + .validate_command("curl http://example.com") + .is_err() + ); // Test that command chaining is blocked even with allowed commands assert!(executor.validate_command("ls; rm -rf /").is_err()); - assert!(executor - .validate_command("cat /etc/passwd && rm -rf /") - .is_err()); + assert!( + executor + .validate_command("cat /etc/passwd && rm -rf /") + .is_err() + ); } #[test] diff --git a/src/server/filter/mod.rs b/src/server/filter/mod.rs index 9b38c997..44d327ca 100644 --- a/src/server/filter/mod.rs +++ b/src/server/filter/mod.rs @@ -61,8 +61,8 @@ use std::fmt; use std::path::Path; pub use self::path::{ - normalize_path, ComponentMatcher, ExactMatcher, ExtensionMatcher, MultiExtensionMatcher, - PrefixMatcher, SizeMatcher, + ComponentMatcher, ExactMatcher, ExtensionMatcher, MultiExtensionMatcher, PrefixMatcher, + SizeMatcher, normalize_path, }; pub use self::pattern::{ AllMatcher, CombinedMatcher, CompositeMatcher, GlobMatcher, NotMatcher, RegexMatcher, diff --git a/src/server/filter/path.rs b/src/server/filter/path.rs index 117ce0e3..9cb26f97 100644 --- a/src/server/filter/path.rs +++ b/src/server/filter/path.rs @@ -417,15 +417,15 @@ impl MultiExtensionMatcher { impl Matcher for MultiExtensionMatcher { fn matches(&self, path: &Path) -> bool { - if let Some(ext) = path.extension() { - if let Some(ext_str) = ext.to_str() { - let ext_cmp = if self.case_sensitive { - ext_str.to_string() - } else { - ext_str.to_lowercase() - }; - return self.extensions.contains(&ext_cmp); - } + if let Some(ext) = path.extension() + && let Some(ext_str) = ext.to_str() + { + let ext_cmp = if self.case_sensitive { + ext_str.to_string() + } else { + ext_str.to_lowercase() + }; + return self.extensions.contains(&ext_cmp); } false } @@ -504,15 +504,15 @@ impl SizeMatcher { /// Check if the given size matches. pub fn matches_size(&self, size: u64) -> bool { - if let Some(min) = self.min_size { - if size < min { - return false; - } + if let Some(min) = self.min_size + && size < min + { + return false; } - if let Some(max) = self.max_size { - if size > max { - return false; - } + if let Some(max) = self.max_size + && size > max + { + return false; } true } @@ -811,9 +811,11 @@ mod tests { assert!(matcher.pattern_description().contains("bat")); let case_sensitive = MultiExtensionMatcher::new(vec!["EXE"], true); - assert!(case_sensitive - .pattern_description() - .contains("case-sensitive")); + assert!( + case_sensitive + .pattern_description() + .contains("case-sensitive") + ); } // Tests for SizeMatcher diff --git a/src/server/filter/pattern.rs b/src/server/filter/pattern.rs index b55f34e1..1c2fb185 100644 --- a/src/server/filter/pattern.rs +++ b/src/server/filter/pattern.rs @@ -156,10 +156,10 @@ impl GlobMatcher { /// Check if the filename matches the pattern. fn matches_filename(&self, path: &Path) -> bool { - if let Some(filename) = path.file_name() { - if let Some(filename_str) = filename.to_str() { - return self.pattern.matches(filename_str); - } + if let Some(filename) = path.file_name() + && let Some(filename_str) = filename.to_str() + { + return self.pattern.matches(filename_str); } false } @@ -800,7 +800,7 @@ mod tests { let matcher = GlobMatcher::with_mode("*.key", GlobMatchMode::FullPathOnly).unwrap(); assert!(matcher.matches(Path::new("secret.key"))); // Direct match - // * in glob matches path separators too, so this actually matches + // * in glob matches path separators too, so this actually matches assert!(matcher.matches(Path::new("/etc/secret.key"))); } diff --git a/src/server/filter/policy.rs b/src/server/filter/policy.rs index 2a544c60..f3bf2519 100644 --- a/src/server/filter/policy.rs +++ b/src/server/filter/policy.rs @@ -28,7 +28,7 @@ use crate::server::config::{ CompositeLogicType, FilterAction, FilterConfig, FilterRule as FilterRuleConfig, MatcherConfig, }; use crate::server::filter::path::{ - normalize_path, ComponentMatcher, MultiExtensionMatcher, PrefixMatcher, + ComponentMatcher, MultiExtensionMatcher, PrefixMatcher, normalize_path, }; use crate::server::filter::pattern::{AllMatcher, CombinedMatcher, GlobMatcher, NotMatcher}; @@ -343,7 +343,9 @@ impl FilterPolicy { } else if let Some(ref directory) = config.directory { Ok(Box::new(ComponentMatcher::new(directory.as_str()))) } else { - anyhow::bail!("Matcher config must have one of: 'pattern', 'path_prefix', 'extensions', 'directory', or 'not'") + anyhow::bail!( + "Matcher config must have one of: 'pattern', 'path_prefix', 'extensions', 'directory', or 'not'" + ) } } } diff --git a/src/server/handler.rs b/src/server/handler.rs index 1cc59140..531694e4 100644 --- a/src/server/handler.rs +++ b/src/server/handler.rs @@ -374,24 +374,23 @@ impl russh::server::Handler for SshHandler { let public_key = public_key.clone(); // Get mutable reference to session_info for authentication update - let session_info = &mut self.session_info; + let mut session_info = &mut self.session_info; async move { // Check if IP is banned (fail2ban-like check) - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - if limiter.is_banned(&ip).await { - tracing::warn!( - user = %user, - peer = ?peer_addr, - "Rejected auth from banned IP" - ); - return Ok(Auth::Reject { - proceed_with_methods: None, - partial_success: false, - }); - } - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + && limiter.is_banned(&ip).await + { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "Rejected auth from banned IP" + ); + return Ok(Auth::Reject { + proceed_with_methods: None, + partial_success: false, + }); } if exceeded { @@ -447,13 +446,13 @@ impl russh::server::Handler for SshHandler { ); // Try to authenticate session with per-user limits - if let Some(ref info) = session_info { + if let Some(info) = &session_info { let mut sessions_guard = sessions.write().await; match sessions_guard.authenticate_session(info.id, &user) { Ok(()) => { // Also update local session info drop(sessions_guard); - if let Some(ref mut local_info) = session_info { + if let Some(local_info) = &mut session_info { local_info.authenticate(&user); } } @@ -484,10 +483,10 @@ impl russh::server::Handler for SshHandler { } // Record success to reset failure counter - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - limiter.record_success(&ip).await; - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + limiter.record_success(&ip).await; } Ok(Auth::Accept) @@ -501,16 +500,16 @@ impl russh::server::Handler for SshHandler { ); // Record failure for ban tracking - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - let banned = limiter.record_failure(ip).await; - if banned { - tracing::warn!( - user = %user, - peer = ?peer_addr, - "IP banned due to too many failed auth attempts" - ); - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + let banned = limiter.record_failure(ip).await; + if banned { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "IP banned due to too many failed auth attempts" + ); } } @@ -534,10 +533,10 @@ impl russh::server::Handler for SshHandler { ); // Record failure for ban tracking - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - limiter.record_failure(ip).await; - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + limiter.record_failure(ip).await; } let proceed = if methods.is_empty() { @@ -592,24 +591,23 @@ impl russh::server::Handler for SshHandler { let allow_password = self.config.allow_password_auth; // Get mutable reference to session_info for authentication update - let session_info = &mut self.session_info; + let mut session_info = &mut self.session_info; async move { // Check if IP is banned (fail2ban-like check) - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - if limiter.is_banned(&ip).await { - tracing::warn!( - user = %user, - peer = ?peer_addr, - "Rejected password auth from banned IP" - ); - return Ok(Auth::Reject { - proceed_with_methods: None, - partial_success: false, - }); - } - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + && limiter.is_banned(&ip).await + { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "Rejected password auth from banned IP" + ); + return Ok(Auth::Reject { + proceed_with_methods: None, + partial_success: false, + }); } // Check if password auth is enabled @@ -681,13 +679,13 @@ impl russh::server::Handler for SshHandler { ); // Try to authenticate session with per-user limits - if let Some(ref info) = session_info { + if let Some(info) = &session_info { let mut sessions_guard = sessions.write().await; match sessions_guard.authenticate_session(info.id, &user) { Ok(()) => { // Also update local session info drop(sessions_guard); - if let Some(ref mut local_info) = session_info { + if let Some(local_info) = &mut session_info { local_info.authenticate(&user); } } @@ -718,10 +716,10 @@ impl russh::server::Handler for SshHandler { } // Record success to reset failure counter - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - limiter.record_success(&ip).await; - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + limiter.record_success(&ip).await; } Ok(Auth::Accept) @@ -734,16 +732,16 @@ impl russh::server::Handler for SshHandler { ); // Record failure for ban tracking - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - let banned = limiter.record_failure(ip).await; - if banned { - tracing::warn!( - user = %user, - peer = ?peer_addr, - "IP banned due to too many failed password auth attempts" - ); - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + let banned = limiter.record_failure(ip).await; + if banned { + tracing::warn!( + user = %user, + peer = ?peer_addr, + "IP banned due to too many failed password auth attempts" + ); } } @@ -767,10 +765,10 @@ impl russh::server::Handler for SshHandler { ); // Record failure for ban tracking - if let Some(ref limiter) = auth_rate_limiter { - if let Some(ip) = peer_addr.map(|a| a.ip()) { - limiter.record_failure(ip).await; - } + if let Some(ref limiter) = auth_rate_limiter + && let Some(ip) = peer_addr.map(|a| a.ip()) + { + limiter.record_failure(ip).await; } let proceed = if methods.is_empty() { @@ -1396,13 +1394,13 @@ impl russh::server::Handler for SshHandler { ); // Update stored PTY config - if let Some(state) = self.channels.get_mut(&channel_id) { - if let Some(ref mut pty) = state.pty { - pty.col_width = col_width; - pty.row_height = row_height; - pty.pix_width = pix_width; - pty.pix_height = pix_height; - } + if let Some(state) = self.channels.get_mut(&channel_id) + && let Some(ref mut pty) = state.pty + { + pty.col_width = col_width; + pty.row_height = row_height; + pty.pix_width = pix_width; + pty.pix_height = pix_height; } // Get the PTY mutex if there's an active shell session diff --git a/src/server/mod.rs b/src/server/mod.rs index 8a2d1ba1..90ec64bb 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -341,18 +341,17 @@ impl russh::server::Server for BsshServerRunner { // Use try_read to avoid blocking in sync context if let Ok(is_banned) = tokio::runtime::Handle::try_current() .map(|h| h.block_on(self.auth_rate_limiter.is_banned(&ip))) + && is_banned { - if is_banned { - tracing::info!( - ip = %ip, - "Connection rejected from banned IP" - ); - return SshHandler::rejected( - peer_addr, - Arc::clone(&self.config), - Arc::clone(&self.sessions), - ); - } + tracing::info!( + ip = %ip, + "Connection rejected from banned IP" + ); + return SshHandler::rejected( + peer_addr, + Arc::clone(&self.config), + Arc::clone(&self.sessions), + ); } } diff --git a/src/server/pty.rs b/src/server/pty.rs index f58665c3..a61a8378 100644 --- a/src/server/pty.rs +++ b/src/server/pty.rs @@ -38,7 +38,7 @@ use std::path::PathBuf; use anyhow::{Context, Result}; use nix::libc; -use nix::pty::{openpty, OpenptyResult, Winsize}; +use nix::pty::{OpenptyResult, Winsize, openpty}; use nix::unistd; use tokio::io::unix::AsyncFd; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; diff --git a/src/server/scp.rs b/src/server/scp.rs index cf71476d..5ff02949 100644 --- a/src/server/scp.rs +++ b/src/server/scp.rs @@ -49,8 +49,8 @@ use std::os::unix::fs::PermissionsExt; use std::path::{Component, Path, PathBuf}; use anyhow::{Context, Result}; -use russh::server::Handle; use russh::ChannelId; +use russh::server::Handle; use tokio::fs::{self, File, OpenOptions}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::mpsc; @@ -504,10 +504,10 @@ impl ScpHandler { } b'T' => { // Preserve times: T 0 0 - if self.preserve_times { - if let Err(e) = self.parse_times(&line) { - tracing::warn!("Error parsing times: {}", e); - } + if self.preserve_times + && let Err(e) = self.parse_times(&line) + { + tracing::warn!("Error parsing times: {}", e); } self.send_ok(channel_id, &handle).await?; } diff --git a/src/server/security/rate_limit.rs b/src/server/security/rate_limit.rs index aea6bfff..5ea4997c 100644 --- a/src/server/security/rate_limit.rs +++ b/src/server/security/rate_limit.rs @@ -155,10 +155,10 @@ impl AuthRateLimiter { } let bans = self.bans.read().await; - if let Some(expiry) = bans.get(ip) { - if Instant::now() < *expiry { - return true; - } + if let Some(expiry) = bans.get(ip) + && Instant::now() < *expiry + { + return true; } false } diff --git a/src/server/session.rs b/src/server/session.rs index f9e8cbd7..016be202 100644 --- a/src/server/session.rs +++ b/src/server/session.rs @@ -59,14 +59,14 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; use russh::server::Msg; use russh::{Channel, ChannelId}; use thiserror::Error; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{RwLock, mpsc}; use super::pty::PtyMaster; @@ -170,13 +170,13 @@ impl SessionConfig { ); } - if let Some(session_timeout) = self.session_timeout { - if session_timeout < self.idle_timeout { - warnings.push(format!( + if let Some(session_timeout) = self.session_timeout + && session_timeout < self.idle_timeout + { + warnings.push(format!( "session_timeout ({:?}) < idle_timeout ({:?}) - sessions may be terminated before idle check", session_timeout, self.idle_timeout )); - } } warnings @@ -692,12 +692,12 @@ impl SessionManager { // Remove from user sessions tracking if let Some(ref session) = session { - if let Some(ref username) = session.user { - if let Some(user_sessions) = self.user_sessions.get_mut(username) { - user_sessions.retain(|&sid| sid != id); - if user_sessions.is_empty() { - self.user_sessions.remove(username); - } + if let Some(ref username) = session.user + && let Some(user_sessions) = self.user_sessions.get_mut(username) + { + user_sessions.retain(|&sid| sid != id); + if user_sessions.is_empty() { + self.user_sessions.remove(username); } } @@ -759,10 +759,10 @@ impl SessionManager { return Some(*id); } // Check session timeout - if let Some(max_duration) = self.config.session_timeout { - if info.is_expired(max_duration) { - return Some(*id); - } + if let Some(max_duration) = self.config.session_timeout + && info.is_expired(max_duration) + { + return Some(*id); } None }) diff --git a/src/server/sftp.rs b/src/server/sftp.rs index 16cad4ae..ea80e278 100644 --- a/src/server/sftp.rs +++ b/src/server/sftp.rs @@ -400,30 +400,30 @@ impl russh_sftp::server::Handler for SftpHandler { // Check if the path is a symlink and validate the target let metadata = fs::symlink_metadata(&path).await; - if let Ok(meta) = metadata { - if meta.is_symlink() { - // Follow the symlink and ensure target is within root - let target = fs::read_link(&path).await?; - let resolved_target = if target.is_absolute() { - target - } else { - // Resolve relative symlink from the symlink's directory - let base = path.parent().unwrap_or(&root_dir); - let joined = base.join(&target); - // Use tokio's canonicalize for async operation - tokio::fs::canonicalize(&joined).await.unwrap_or(target) - }; - - if !resolved_target.starts_with(&root_dir) { - tracing::warn!( - path = %path.display(), - target = %resolved_target.display(), - "Symlink target outside root directory" - ); - return Err(SftpError::permission_denied( - "Symlink target outside allowed directory", - )); - } + if let Ok(meta) = metadata + && meta.is_symlink() + { + // Follow the symlink and ensure target is within root + let target = fs::read_link(&path).await?; + let resolved_target = if target.is_absolute() { + target + } else { + // Resolve relative symlink from the symlink's directory + let base = path.parent().unwrap_or(&root_dir); + let joined = base.join(&target); + // Use tokio's canonicalize for async operation + tokio::fs::canonicalize(&joined).await.unwrap_or(target) + }; + + if !resolved_target.starts_with(&root_dir) { + tracing::warn!( + path = %path.display(), + target = %resolved_target.display(), + "Symlink target outside root directory" + ); + return Err(SftpError::permission_denied( + "Symlink target outside allowed directory", + )); } } diff --git a/src/server/shell.rs b/src/server/shell.rs index d8877b99..86c21354 100644 --- a/src/server/shell.rs +++ b/src/server/shell.rs @@ -41,7 +41,7 @@ use russh::server::{Handle, Msg}; use russh::{ChannelId, ChannelStream}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Child; -use tokio::sync::{mpsc, RwLock}; +use tokio::sync::{RwLock, mpsc}; use super::pty::{PtyConfig, PtyMaster}; use crate::shared::auth_types::UserInfo; @@ -428,7 +428,7 @@ async fn drain_pty_output_to_stream( /// Wait for child process to exit and return exit code. async fn wait_for_child(child: &mut Option) -> i32 { - if let Some(ref mut c) = child { + if let Some(c) = child { match c.wait().await { Ok(status) => status.code().unwrap_or(1), Err(e) => { diff --git a/src/shared/rate_limit.rs b/src/shared/rate_limit.rs index acff82ae..1c5a4790 100644 --- a/src/shared/rate_limit.rs +++ b/src/shared/rate_limit.rs @@ -41,7 +41,7 @@ //! // For user IDs: RateLimiter //! ``` -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; diff --git a/src/shared/validation.rs b/src/shared/validation.rs index d8124c34..005a69f9 100644 --- a/src/shared/validation.rs +++ b/src/shared/validation.rs @@ -412,12 +412,12 @@ pub fn sanitize_error_message(message: &str) -> String { let mut sanitized = message.to_string(); // Remove specific usernames (format: user 'username') - if let Some(start) = sanitized.find("user '") { - if let Some(end) = sanitized[start + 6..].find('\'') { - let before = &sanitized[..start + 5]; - let after = &sanitized[start + 6 + end + 1..]; - sanitized = format!("{before}{after}"); - } + if let Some(start) = sanitized.find("user '") + && let Some(end) = sanitized[start + 6..].find('\'') + { + let before = &sanitized[..start + 5]; + let after = &sanitized[start + 6 + end + 1..]; + sanitized = format!("{before}{after}"); } // Remove hostname:port combinations in common patterns diff --git a/src/ssh/auth.rs b/src/ssh/auth.rs index 4aa1259a..d2777aff 100644 --- a/src/ssh/auth.rs +++ b/src/ssh/auth.rs @@ -256,10 +256,10 @@ impl AuthContext { } // Priority 2: SSH agent (explicit request) - if self.use_agent { - if let Some(auth) = self.agent_auth()? { - return Ok(auth); - } + if self.use_agent + && let Some(auth) = self.agent_auth()? + { + return Ok(auth); } // Priority 3: Key file authentication (explicit -i flag) @@ -290,7 +290,9 @@ impl AuthContext { // If allow_password_fallback is set (interactive mode), skip consent prompt // Otherwise, ask for explicit user consent for security let should_attempt_password = if self.allow_password_fallback { - tracing::info!("SSH key authentication failed, falling back to password authentication"); + tracing::info!( + "SSH key authentication failed, falling back to password authentication" + ); // SECURITY: Add rate limiting before password fallback to prevent rapid attempts const FALLBACK_DELAY: Duration = Duration::from_secs(1); diff --git a/src/ssh/client/connection.rs b/src/ssh/client/connection.rs index c5ae8ace..b5abb5a1 100644 --- a/src/ssh/client/connection.rs +++ b/src/ssh/client/connection.rs @@ -13,7 +13,7 @@ // limitations under the License. use super::core::SshClient; -use crate::jump::{parse_jump_hosts, JumpHostChain}; +use crate::jump::{JumpHostChain, parse_jump_hosts}; use crate::ssh::known_hosts::StrictHostKeyChecking; use crate::ssh::tokio_client::{AuthMethod, Client, SshConnectionConfig}; use anyhow::{Context, Result}; diff --git a/src/ssh/config_cache/maintenance.rs b/src/ssh/config_cache/maintenance.rs index 098c9fe9..3b9ce6f0 100644 --- a/src/ssh/config_cache/maintenance.rs +++ b/src/ssh/config_cache/maintenance.rs @@ -65,11 +65,11 @@ impl SshConfigCache { // Wait for all file checks to complete for task in check_tasks { - if let Ok((path, is_stale, _file_exists)) = task.await { - if is_stale { - to_remove.push(path); - stale_count += 1; - } + if let Ok((path, is_stale, _file_exists)) = task.await + && is_stale + { + to_remove.push(path); + stale_count += 1; } } diff --git a/src/ssh/keychain_macos.rs b/src/ssh/keychain_macos.rs index c279e968..ec0c6ce8 100644 --- a/src/ssh/keychain_macos.rs +++ b/src/ssh/keychain_macos.rs @@ -54,7 +54,7 @@ use anyhow::{Context, Result}; use security_framework::passwords::{ - delete_generic_password, generic_password, set_generic_password, PasswordOptions, + PasswordOptions, delete_generic_password, generic_password, set_generic_password, }; use std::path::Path; use zeroize::Zeroizing; diff --git a/src/ssh/mod.rs b/src/ssh/mod.rs index 67a33f83..adfc4cfb 100644 --- a/src/ssh/mod.rs +++ b/src/ssh/mod.rs @@ -26,7 +26,7 @@ pub mod keychain_macos; pub use auth::AuthContext; pub use client::SshClient; -pub use config_cache::{CacheConfig, CacheStats, SshConfigCache, GLOBAL_CACHE}; +pub use config_cache::{CacheConfig, CacheStats, GLOBAL_CACHE, SshConfigCache}; pub use handler::BsshHandler; pub use pool::ConnectionPool; pub use ssh_config::{SshConfig, SshHostConfig}; diff --git a/src/ssh/ssh_config/env_cache/tests.rs b/src/ssh/ssh_config/env_cache/tests.rs index 24edf93e..c68b5a07 100644 --- a/src/ssh/ssh_config/env_cache/tests.rs +++ b/src/ssh/ssh_config/env_cache/tests.rs @@ -15,8 +15,8 @@ use super::cache::EnvironmentCache; use super::config::EnvCacheConfig; use super::entry::CacheEntry; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; #[test] diff --git a/src/ssh/ssh_config/include/mod.rs b/src/ssh/ssh_config/include/mod.rs index e27c7be8..9176083a 100644 --- a/src/ssh/ssh_config/include/mod.rs +++ b/src/ssh/ssh_config/include/mod.rs @@ -454,27 +454,33 @@ mod tests { assert_eq!(result.len(), 3); // Check lexical ordering of included files - assert!(result[0] - .path - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("01-first")); - assert!(result[1] - .path - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("02-second")); - assert!(result[2] - .path - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("03-third")); + assert!( + result[0] + .path + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("01-first") + ); + assert!( + result[1] + .path + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("02-second") + ); + assert!( + result[2] + .path + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("03-third") + ); } #[tokio::test] diff --git a/src/ssh/ssh_config/include/resolver.rs b/src/ssh/ssh_config/include/resolver.rs index ba737a2b..7c6c5e73 100644 --- a/src/ssh/ssh_config/include/resolver.rs +++ b/src/ssh/ssh_config/include/resolver.rs @@ -231,24 +231,30 @@ mod tests { // Should have 3 files in lexical order assert_eq!(files.len(), 3); - assert!(files[0] - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("01-first")); - assert!(files[1] - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("02-second")); - assert!(files[2] - .file_name() - .unwrap() - .to_str() - .unwrap() - .contains("03-third")); + assert!( + files[0] + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("01-first") + ); + assert!( + files[1] + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("02-second") + ); + assert!( + files[2] + .file_name() + .unwrap() + .to_str() + .unwrap() + .contains("03-third") + ); } #[tokio::test] diff --git a/src/ssh/ssh_config/include/validation.rs b/src/ssh/ssh_config/include/validation.rs index 63e0d03d..ccd9f0f1 100644 --- a/src/ssh/ssh_config/include/validation.rs +++ b/src/ssh/ssh_config/include/validation.rs @@ -160,10 +160,12 @@ mod tests { // Test too many wildcards let result = validate_glob_pattern("a*/b*/c*/d*/e*/f*"); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Too many wildcards")); + assert!( + result + .unwrap_err() + .to_string() + .contains("Too many wildcards") + ); // Test too-long pattern let long_pattern = "a".repeat(600); diff --git a/src/ssh/ssh_config/integration_tests/certificate_forwarding_integration_test.rs b/src/ssh/ssh_config/integration_tests/certificate_forwarding_integration_test.rs index 682c22e9..15a32b16 100644 --- a/src/ssh/ssh_config/integration_tests/certificate_forwarding_integration_test.rs +++ b/src/ssh/ssh_config/integration_tests/certificate_forwarding_integration_test.rs @@ -70,12 +70,16 @@ Host web.prod.example.com // Test resolution for web.prod.example.com (should get certs from included file) let resolved = config.find_host_config("web.prod.example.com"); assert_eq!(resolved.certificate_files.len(), 2); - assert!(resolved.certificate_files[0] - .to_string_lossy() - .contains("prod-user-cert.pub")); - assert!(resolved.certificate_files[1] - .to_string_lossy() - .contains("prod-host-cert.pub")); + assert!( + resolved.certificate_files[0] + .to_string_lossy() + .contains("prod-user-cert.pub") + ); + assert!( + resolved.certificate_files[1] + .to_string_lossy() + .contains("prod-host-cert.pub") + ); assert_eq!(resolved.ca_signature_algorithms.len(), 2); assert_eq!(resolved.user, Some("webuser".to_string())); assert_eq!(resolved.port, Some(22)); @@ -155,9 +159,11 @@ Host web.prod.example.com // Match block should have certificate options assert_eq!(config.hosts[0].certificate_files.len(), 1); - assert!(config.hosts[0].certificate_files[0] - .to_string_lossy() - .contains("admin-cert.pub")); + assert!( + config.hosts[0].certificate_files[0] + .to_string_lossy() + .contains("admin-cert.pub") + ); assert_eq!(config.hosts[0].ca_signature_algorithms.len(), 1); assert_eq!(config.hosts[0].hostbased_authentication, Some(true)); } @@ -325,9 +331,11 @@ Host web1.prod.example.com // - Permit remote open from Match block // - User from specific Host block assert_eq!(resolved.certificate_files.len(), 1); - assert!(resolved.certificate_files[0] - .to_string_lossy() - .contains("default-cert.pub")); + assert!( + resolved.certificate_files[0] + .to_string_lossy() + .contains("default-cert.pub") + ); assert_eq!(resolved.ca_signature_algorithms.len(), 2); assert_eq!(resolved.hostbased_authentication, Some(true)); // Overridden assert_eq!(resolved.gateway_ports, Some("clientspecified".to_string())); diff --git a/src/ssh/ssh_config/match_directive/exec.rs b/src/ssh/ssh_config/match_directive/exec.rs index 349fd22d..e98a2148 100644 --- a/src/ssh/ssh_config/match_directive/exec.rs +++ b/src/ssh/ssh_config/match_directive/exec.rs @@ -226,13 +226,13 @@ pub fn validate_exec_command(command: &str) -> Result<()> { } '$' if !in_single_quote => { // $ is dangerous in double quotes or unquoted - if let Some(next) = command.chars().nth(command.find('$').unwrap() + 1) { - if next == '(' || next == '{' { - anyhow::bail!( - "Match exec command contains potential command or variable substitution. \ + if let Some(next) = command.chars().nth(command.find('$').unwrap() + 1) + && (next == '(' || next == '{') + { + anyhow::bail!( + "Match exec command contains potential command or variable substitution. \ This is blocked for security." - ); - } + ); } } _ => {} diff --git a/src/ssh/ssh_config/mod.rs b/src/ssh/ssh_config/mod.rs index aa6b96e6..779ce032 100644 --- a/src/ssh/ssh_config/mod.rs +++ b/src/ssh/ssh_config/mod.rs @@ -390,9 +390,11 @@ Host web1.secure.example.com // Verify first host (*.secure.example.com) let host1 = &config.hosts[0]; assert_eq!(host1.certificate_files.len(), 1); - assert!(host1.certificate_files[0] - .to_string_lossy() - .contains("id_rsa-cert.pub")); + assert!( + host1.certificate_files[0] + .to_string_lossy() + .contains("id_rsa-cert.pub") + ); assert_eq!(host1.ca_signature_algorithms.len(), 2); assert_eq!(host1.ca_signature_algorithms[0], "ssh-ed25519"); assert_eq!(host1.ca_signature_algorithms[1], "rsa-sha2-512"); @@ -404,9 +406,11 @@ Host web1.secure.example.com // Verify second host (web1.secure.example.com) let host2 = &config.hosts[1]; assert_eq!(host2.certificate_files.len(), 1); - assert!(host2.certificate_files[0] - .to_string_lossy() - .contains("host-cert.pub")); + assert!( + host2.certificate_files[0] + .to_string_lossy() + .contains("host-cert.pub") + ); assert_eq!(host2.permit_remote_open.len(), 2); assert_eq!(host2.permit_remote_open[0], "localhost:8080"); assert_eq!(host2.permit_remote_open[1], "db.internal:5432"); diff --git a/src/ssh/ssh_config/parser/options/authentication.rs b/src/ssh/ssh_config/parser/options/authentication.rs index de7625ed..dc964ef6 100644 --- a/src/ssh/ssh_config/parser/options/authentication.rs +++ b/src/ssh/ssh_config/parser/options/authentication.rs @@ -133,7 +133,8 @@ pub(super) fn parse_authentication_option( if trimmed.len() > MAX_ALGORITHM_NAME_LENGTH { tracing::warn!( "Algorithm name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_ALGORITHM_NAME_LENGTH + line_number, + MAX_ALGORITHM_NAME_LENGTH ); continue; } @@ -165,7 +166,9 @@ pub(super) fn parse_authentication_option( if truncated { tracing::warn!( "PubkeyAcceptedAlgorithms at line {} contains {} algorithms, truncated to first {}", - line_number, total_count, MAX_ALGORITHMS + line_number, + total_count, + MAX_ALGORITHMS ); } @@ -263,7 +266,8 @@ pub(super) fn parse_authentication_option( if trimmed.len() > MAX_ALGORITHM_NAME_LENGTH { tracing::warn!( "Algorithm name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_ALGORITHM_NAME_LENGTH + line_number, + MAX_ALGORITHM_NAME_LENGTH ); continue; } @@ -295,7 +299,9 @@ pub(super) fn parse_authentication_option( if truncated { tracing::warn!( "HostbasedAcceptedAlgorithms at line {} contains {} algorithms, truncated to first {}", - line_number, total_count, MAX_ALGORITHMS + line_number, + total_count, + MAX_ALGORITHMS ); } diff --git a/src/ssh/ssh_config/parser/options/command.rs b/src/ssh/ssh_config/parser/options/command.rs index 2c06d9d7..d4e25e6f 100644 --- a/src/ssh/ssh_config/parser/options/command.rs +++ b/src/ssh/ssh_config/parser/options/command.rs @@ -290,34 +290,34 @@ mod tests { #[test] fn test_validate_command_with_tokens_valid() { // Valid commands with tokens - assert!(validate_command_with_tokens( - "rsync -av ~/project/ %h:~/project/", - "LocalCommand", - 1 - ) - .is_ok()); - - assert!(validate_command_with_tokens( - "notify-send \"Connected to %h on port %p\"", - "LocalCommand", - 1 - ) - .is_ok()); - - assert!(validate_command_with_tokens( - "/usr/local/bin/fetch-host-key %H", - "KnownHostsCommand", - 1 - ) - .is_ok()); + assert!( + validate_command_with_tokens("rsync -av ~/project/ %h:~/project/", "LocalCommand", 1) + .is_ok() + ); + + assert!( + validate_command_with_tokens( + "notify-send \"Connected to %h on port %p\"", + "LocalCommand", + 1 + ) + .is_ok() + ); + + assert!( + validate_command_with_tokens( + "/usr/local/bin/fetch-host-key %H", + "KnownHostsCommand", + 1 + ) + .is_ok() + ); // Command with escaped percent - assert!(validate_command_with_tokens( - "echo \"Progress: 50%% complete\"", - "LocalCommand", - 1 - ) - .is_ok()); + assert!( + validate_command_with_tokens("echo \"Progress: 50%% complete\"", "LocalCommand", 1) + .is_ok() + ); } #[test] @@ -418,31 +418,35 @@ mod tests { let mut config = SshHostConfig::default(); // Simple command - assert!(parse_command_option( - &mut config, - "remotecommand", - &["ls".to_string(), "-la".to_string()], - 1 - ) - .is_ok()); + assert!( + parse_command_option( + &mut config, + "remotecommand", + &["ls".to_string(), "-la".to_string()], + 1 + ) + .is_ok() + ); assert_eq!(config.remote_command, Some("ls -la".to_string())); // Complex command (no validation for remote commands) - assert!(parse_command_option( - &mut config, - "remotecommand", - &[ - "tmux".to_string(), - "attach".to_string(), - "-t".to_string(), - "dev".to_string(), - "||".to_string(), - "tmux".to_string(), - "new".to_string() - ], - 1 - ) - .is_ok()); + assert!( + parse_command_option( + &mut config, + "remotecommand", + &[ + "tmux".to_string(), + "attach".to_string(), + "-t".to_string(), + "dev".to_string(), + "||".to_string(), + "tmux".to_string(), + "new".to_string() + ], + 1 + ) + .is_ok() + ); assert_eq!( config.remote_command, Some("tmux attach -t dev || tmux new".to_string()) diff --git a/src/ssh/ssh_config/parser/options/connection.rs b/src/ssh/ssh_config/parser/options/connection.rs index 0693e0c3..17e1291a 100644 --- a/src/ssh/ssh_config/parser/options/connection.rs +++ b/src/ssh/ssh_config/parser/options/connection.rs @@ -231,7 +231,9 @@ pub(super) fn parse_connection_option( if !valid_tos.contains(&num) { tracing::warn!( "IPQoS value '{}' ({:#04x}) at line {} is not a standard DSCP (0-63) or ToS value", - value, num, line_number + value, + num, + line_number ); } } @@ -243,7 +245,8 @@ pub(super) fn parse_connection_option( if num > 63 && ![0x00, 0x04, 0x08, 0x10, 0xff].contains(&num) { tracing::warn!( "IPQoS hex value '{}' at line {} is outside standard ranges", - value, line_number + value, + line_number ); } } else { @@ -325,7 +328,9 @@ pub(super) fn parse_connection_option( tracing::warn!( "RekeyLimit data limit '{}' at line {} is very small ({} bytes). \ This may cause frequent rekeying", - data_limit, line_number, total + data_limit, + line_number, + total ); } } else { @@ -381,7 +386,9 @@ pub(super) fn parse_connection_option( tracing::warn!( "RekeyLimit time limit '{}' at line {} is very long ({} days). \ This may reduce security", - time_limit, line_number, total_seconds / 86400 + time_limit, + line_number, + total_seconds / 86400 ); } // Warn if rekey time is very short (< 60 seconds) @@ -389,7 +396,9 @@ pub(super) fn parse_connection_option( tracing::warn!( "RekeyLimit time limit '{}' at line {} is very short ({} seconds). \ This may cause frequent rekeying", - time_limit, line_number, total_seconds + time_limit, + line_number, + total_seconds ); } } else { diff --git a/src/ssh/ssh_config/parser/options/security.rs b/src/ssh/ssh_config/parser/options/security.rs index fc7cad12..10412600 100644 --- a/src/ssh/ssh_config/parser/options/security.rs +++ b/src/ssh/ssh_config/parser/options/security.rs @@ -90,7 +90,8 @@ pub(super) fn parse_security_option( if trimmed.len() > MAX_ALGORITHM_NAME_LENGTH { tracing::warn!( "HostKeyAlgorithm name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_ALGORITHM_NAME_LENGTH + line_number, + MAX_ALGORITHM_NAME_LENGTH ); continue; } @@ -162,7 +163,8 @@ pub(super) fn parse_security_option( if trimmed.len() > MAX_ALGORITHM_NAME_LENGTH { tracing::warn!( "KexAlgorithm name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_ALGORITHM_NAME_LENGTH + line_number, + MAX_ALGORITHM_NAME_LENGTH ); continue; } @@ -234,7 +236,8 @@ pub(super) fn parse_security_option( if trimmed.len() > MAX_CIPHER_NAME_LENGTH { tracing::warn!( "Cipher name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_CIPHER_NAME_LENGTH + line_number, + MAX_CIPHER_NAME_LENGTH ); continue; } @@ -384,7 +387,8 @@ pub(super) fn parse_security_option( if trimmed.len() > MAX_ALGORITHM_NAME_LENGTH { tracing::warn!( "Algorithm name at line {} exceeds maximum length of {} characters, skipping", - line_number, MAX_ALGORITHM_NAME_LENGTH + line_number, + MAX_ALGORITHM_NAME_LENGTH ); continue; } @@ -416,7 +420,9 @@ pub(super) fn parse_security_option( if truncated { tracing::warn!( "CASignatureAlgorithms at line {} contains {} algorithms, truncated to first {}", - line_number, total_count, MAX_ALGORITHMS + line_number, + total_count, + MAX_ALGORITHMS ); } @@ -499,7 +505,8 @@ pub(super) fn parse_security_option( } tracing::debug!( "Setting HostKeyAlias to '{}' at line {} (security-sensitive: affects host key verification)", - alias, line_number + alias, + line_number ); host.host_key_alias = Some(alias.clone()); } diff --git a/src/ssh/ssh_config/parser/tests.rs b/src/ssh/ssh_config/parser/tests.rs index 9f5fe407..5ddd67d7 100644 --- a/src/ssh/ssh_config/parser/tests.rs +++ b/src/ssh/ssh_config/parser/tests.rs @@ -255,10 +255,12 @@ fn test_parse_very_long_line() { let content = format!("Host example.com\n {long_line}"); let result = parse(&content); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("exceeds maximum length")); + assert!( + result + .unwrap_err() + .to_string() + .contains("exceeds maximum length") + ); } #[test] @@ -268,10 +270,12 @@ fn test_parse_very_long_value() { let content = format!("Host example.com\n User={long_value}"); let result = parse(&content); assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("exceeds maximum length")); + assert!( + result + .unwrap_err() + .to_string() + .contains("exceeds maximum length") + ); } // Integration tests for Include + Match scenarios @@ -485,12 +489,16 @@ Host example.com assert_eq!(hosts.len(), 1); assert_eq!(hosts[0].certificate_files.len(), 2); // Paths should be validated and stored - assert!(hosts[0].certificate_files[0] - .to_string_lossy() - .contains("id_rsa-cert.pub")); - assert!(hosts[0].certificate_files[1] - .to_string_lossy() - .contains("host-cert.pub")); + assert!( + hosts[0].certificate_files[0] + .to_string_lossy() + .contains("id_rsa-cert.pub") + ); + assert!( + hosts[0].certificate_files[1] + .to_string_lossy() + .contains("host-cert.pub") + ); } #[test] @@ -503,9 +511,11 @@ Host example.com let hosts = parse(content).unwrap(); assert_eq!(hosts.len(), 1); assert_eq!(hosts[0].certificate_files.len(), 1); - assert!(hosts[0].certificate_files[0] - .to_string_lossy() - .contains("id_ed25519-cert.pub")); + assert!( + hosts[0].certificate_files[0] + .to_string_lossy() + .contains("id_ed25519-cert.pub") + ); } #[test] diff --git a/src/ssh/ssh_config/resolver_tests.rs b/src/ssh/ssh_config/resolver_tests.rs index 5216fb0e..f2c76e21 100644 --- a/src/ssh/ssh_config/resolver_tests.rs +++ b/src/ssh/ssh_config/resolver_tests.rs @@ -33,12 +33,16 @@ Host example.com // Should have both certificate files (appending behavior) assert_eq!(config.certificate_files.len(), 2); - assert!(config.certificate_files[0] - .to_string_lossy() - .contains("global-cert.pub")); - assert!(config.certificate_files[1] - .to_string_lossy() - .contains("example-cert.pub")); + assert!( + config.certificate_files[0] + .to_string_lossy() + .contains("global-cert.pub") + ); + assert!( + config.certificate_files[1] + .to_string_lossy() + .contains("example-cert.pub") + ); } #[test] @@ -57,13 +61,17 @@ Host example.com // Should deduplicate the shared cert assert_eq!(config.certificate_files.len(), 2); // First should be the shared cert (from first Host *) - assert!(config.certificate_files[0] - .to_string_lossy() - .contains("shared-cert.pub")); + assert!( + config.certificate_files[0] + .to_string_lossy() + .contains("shared-cert.pub") + ); // Second should be the example-specific cert - assert!(config.certificate_files[1] - .to_string_lossy() - .contains("example-cert.pub")); + assert!( + config.certificate_files[1] + .to_string_lossy() + .contains("example-cert.pub") + ); } #[test] @@ -235,15 +243,21 @@ Host * // CertificateFile: all three accumulate assert_eq!(config.certificate_files.len(), 3); - assert!(config.certificate_files[0] - .to_string_lossy() - .contains("first-cert.pub")); - assert!(config.certificate_files[1] - .to_string_lossy() - .contains("second-cert.pub")); - assert!(config.certificate_files[2] - .to_string_lossy() - .contains("third-cert.pub")); + assert!( + config.certificate_files[0] + .to_string_lossy() + .contains("first-cert.pub") + ); + assert!( + config.certificate_files[1] + .to_string_lossy() + .contains("second-cert.pub") + ); + assert!( + config.certificate_files[2] + .to_string_lossy() + .contains("third-cert.pub") + ); } #[test] @@ -433,12 +447,16 @@ Host example.com // Should merge identity files and inherit UseKeychain assert_eq!(config.identity_files.len(), 2); - assert!(config.identity_files[0] - .to_string_lossy() - .contains("id_rsa")); - assert!(config.identity_files[1] - .to_string_lossy() - .contains("id_ed25519")); + assert!( + config.identity_files[0] + .to_string_lossy() + .contains("id_rsa") + ); + assert!( + config.identity_files[1] + .to_string_lossy() + .contains("id_ed25519") + ); assert_eq!(config.use_keychain, Some(true)); } diff --git a/src/ssh/ssh_config/security/checks.rs b/src/ssh/ssh_config/security/checks.rs index e7e69428..98c6b040 100644 --- a/src/ssh/ssh_config/security/checks.rs +++ b/src/ssh/ssh_config/security/checks.rs @@ -52,38 +52,39 @@ pub fn validate_identity_file_security(path: &Path, line_number: usize) -> Resul // On Unix systems, check file permissions if the file exists #[cfg(unix)] - if path.exists() && path.is_file() { - if let Ok(metadata) = std::fs::metadata(path) { - let permissions = metadata.permissions(); - let mode = permissions.mode(); + if path.exists() + && path.is_file() + && let Ok(metadata) = std::fs::metadata(path) + { + let permissions = metadata.permissions(); + let mode = permissions.mode(); - // Check if file is world-readable (dangerous for private keys) - if mode & 0o004 != 0 { - tracing::warn!( - "Security warning: Identity file '{}' at line {} is world-readable. \ + // Check if file is world-readable (dangerous for private keys) + if mode & 0o004 != 0 { + tracing::warn!( + "Security warning: Identity file '{}' at line {} is world-readable. \ Private SSH keys should not be readable by other users (chmod 600 recommended).", - path_str, - line_number - ); - } + path_str, + line_number + ); + } - // Check if file is group-readable (also not ideal for private keys) - if mode & 0o040 != 0 { - tracing::warn!( - "Security warning: Identity file '{}' at line {} is group-readable. \ + // Check if file is group-readable (also not ideal for private keys) + if mode & 0o040 != 0 { + tracing::warn!( + "Security warning: Identity file '{}' at line {} is group-readable. \ Private SSH keys should only be readable by the owner (chmod 600 recommended).", - path_str, - line_number - ); - } + path_str, + line_number + ); + } - // Check if file is world-writable (very dangerous) - if mode & 0o002 != 0 { - anyhow::bail!( - "Security violation: Identity file '{path_str}' at line {line_number} is world-writable. \ + // Check if file is world-writable (very dangerous) + if mode & 0o002 != 0 { + anyhow::bail!( + "Security violation: Identity file '{path_str}' at line {line_number} is world-writable. \ This is extremely dangerous and must be fixed immediately." - ); - } + ); } } diff --git a/src/ssh/ssh_config/security/path_validation.rs b/src/ssh/ssh_config/security/path_validation.rs index 5b5d3a64..946377c8 100644 --- a/src/ssh/ssh_config/security/path_validation.rs +++ b/src/ssh/ssh_config/security/path_validation.rs @@ -69,7 +69,10 @@ pub fn secure_validate_path(path: &str, path_type: &str, line_number: usize) -> Err(e) => { tracing::debug!( "Could not canonicalize {} path '{}' at line {}: {}. Using expanded path as-is.", - path_type, path_str, line_number, e + path_type, + path_str, + line_number, + e ); expanded_path.clone() } diff --git a/src/ssh/tokio_client/authentication.rs b/src/ssh/tokio_client/authentication.rs index a60a5ece..ad1f7868 100644 --- a/src/ssh/tokio_client/authentication.rs +++ b/src/ssh/tokio_client/authentication.rs @@ -328,11 +328,11 @@ pub(super) async fn authenticate( ) .await; - if let Ok(auth_result) = result { - if auth_result.success() { - auth_success = true; - break; - } + if let Ok(auth_result) = result + && auth_result.success() + { + auth_success = true; + break; } } diff --git a/src/ssh/tokio_client/channel_manager.rs b/src/ssh/tokio_client/channel_manager.rs index dfa472d5..cad2e98a 100644 --- a/src/ssh/tokio_client/channel_manager.rs +++ b/src/ssh/tokio_client/channel_manager.rs @@ -21,16 +21,16 @@ //! - Port forwarding channels use bytes::Bytes; -use russh::client::Msg; use russh::Channel; +use russh::client::Msg; use std::io; use std::net::SocketAddr; -use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::mpsc::{Receiver, Sender, channel}; use tokio::task::JoinHandle; -use super::connection::Client; use super::ToSocketAddrsWithHostname; -use crate::security::{contains_sudo_failure, contains_sudo_prompt, SudoPassword}; +use super::connection::Client; +use crate::security::{SudoPassword, contains_sudo_failure, contains_sudo_prompt}; // Buffer size constants for SSH operations /// SSH I/O buffer size constants - optimized for different operation types diff --git a/src/ssh/tokio_client/mod.rs b/src/ssh/tokio_client/mod.rs index b4c121c8..832da4f3 100644 --- a/src/ssh/tokio_client/mod.rs +++ b/src/ssh/tokio_client/mod.rs @@ -25,7 +25,7 @@ mod to_socket_addrs_with_hostname; pub use authentication::{AuthKeyboardInteractive, AuthMethod, ServerCheckMethod}; pub use channel_manager::{CommandExecutedResult, CommandOutput}; pub use connection::{ - Client, ClientHandler, SshConnectionConfig, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX, + Client, ClientHandler, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX, SshConnectionConfig, }; pub use error::Error; pub use to_socket_addrs_with_hostname::ToSocketAddrsWithHostname; diff --git a/src/ui/tui/app.rs b/src/ui/tui/app.rs index d9ec2e34..82007e92 100644 --- a/src/ui/tui/app.rs +++ b/src/ui/tui/app.rs @@ -306,11 +306,11 @@ impl TuiApp { /// Check if there are new log entries and trigger redraw if needed pub fn check_log_updates(&mut self) -> bool { - if let Ok(mut buffer) = self.log_buffer.lock() { - if buffer.take_has_new_entries() { - self.needs_redraw = true; - return true; - } + if let Ok(mut buffer) = self.log_buffer.lock() + && buffer.take_has_new_entries() + { + self.needs_redraw = true; + return true; } false } diff --git a/src/ui/tui/event.rs b/src/ui/tui/event.rs index cbf6686d..19f391f0 100644 --- a/src/ui/tui/event.rs +++ b/src/ui/tui/event.rs @@ -22,10 +22,10 @@ use std::time::Duration; /// /// Returns Some(KeyEvent) if a key was pressed, None if timeout occurred pub fn poll_event(timeout: Duration) -> anyhow::Result> { - if event::poll(timeout)? { - if let Event::Key(key) = event::read()? { - return Ok(Some(key)); - } + if event::poll(timeout)? + && let Event::Key(key) = event::read()? + { + return Ok(Some(key)); } Ok(None) } diff --git a/src/ui/tui/log_layer.rs b/src/ui/tui/log_layer.rs index 203d95e2..f9cc5310 100644 --- a/src/ui/tui/log_layer.rs +++ b/src/ui/tui/log_layer.rs @@ -22,8 +22,8 @@ use super::log_buffer::{LogBuffer, LogEntry}; use std::sync::{Arc, Mutex}; use tracing::field::{Field, Visit}; use tracing::{Event, Level, Subscriber}; -use tracing_subscriber::layer::Context; use tracing_subscriber::Layer; +use tracing_subscriber::layer::Context; /// A tracing layer that captures log events for TUI display /// @@ -128,8 +128,8 @@ where #[cfg(test)] mod tests { use super::*; - use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::Registry; + use tracing_subscriber::layer::SubscriberExt; #[test] fn test_tui_log_layer_captures_events() { diff --git a/src/ui/tui/mod.rs b/src/ui/tui/mod.rs index c23f3bdd..70b7ea29 100644 --- a/src/ui/tui/mod.rs +++ b/src/ui/tui/mod.rs @@ -31,7 +31,7 @@ use crate::utils::get_log_buffer; use anyhow::Result; use app::{TuiApp, ViewMode}; use log_buffer::LogBuffer; -use ratatui::{backend::CrosstermBackend, Terminal}; +use ratatui::{Terminal, backend::CrosstermBackend}; use std::io; use std::sync::{Arc, Mutex}; use std::time::Duration; diff --git a/src/ui/tui/progress.rs b/src/ui/tui/progress.rs index 7c285fe5..caae1c95 100644 --- a/src/ui/tui/progress.rs +++ b/src/ui/tui/progress.rs @@ -59,26 +59,25 @@ pub fn parse_progress(text: &str) -> Option { return None; } // Try apt-specific pattern first (more specific) - if let Some(cap) = APT_PROGRESS.captures(text) { - if let Ok(percent) = cap[1].parse::() { - return Some(percent.min(100.0)); - } + if let Some(cap) = APT_PROGRESS.captures(text) + && let Ok(percent) = cap[1].parse::() + { + return Some(percent.min(100.0)); } // Try general percent pattern: "78%" - if let Some(cap) = PERCENT_PATTERN.captures(text) { - if let Ok(percent) = cap[1].parse::() { - return Some(percent.min(100.0)); - } + if let Some(cap) = PERCENT_PATTERN.captures(text) + && let Ok(percent) = cap[1].parse::() + { + return Some(percent.min(100.0)); } // Try fraction pattern: "23/100" - if let Some(cap) = FRACTION_PATTERN.captures(text) { - if let (Ok(current), Ok(total)) = (cap[1].parse::(), cap[2].parse::()) { - if total > 0.0 { - return Some((current / total * 100.0).min(100.0)); - } - } + if let Some(cap) = FRACTION_PATTERN.captures(text) + && let (Ok(current), Ok(total)) = (cap[1].parse::(), cap[2].parse::()) + && total > 0.0 + { + return Some((current / total * 100.0).min(100.0)); } None diff --git a/src/ui/tui/terminal_guard.rs b/src/ui/tui/terminal_guard.rs index 216161f6..d3202850 100644 --- a/src/ui/tui/terminal_guard.rs +++ b/src/ui/tui/terminal_guard.rs @@ -20,7 +20,7 @@ use crossterm::{ execute, - terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, + terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode}, }; use std::io::{self, Write}; use tracing::{error, warn}; diff --git a/src/ui/tui/views/detail.rs b/src/ui/tui/views/detail.rs index 58d88861..bb2b0045 100644 --- a/src/ui/tui/views/detail.rs +++ b/src/ui/tui/views/detail.rs @@ -16,11 +16,11 @@ use crate::executor::{ExecutionStatus, NodeStream}; use ratatui::{ + Frame, layout::{Constraint, Layout, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, - Frame, }; /// Render the detail view for a single node diff --git a/src/ui/tui/views/diff.rs b/src/ui/tui/views/diff.rs index 02bd4b05..673b1016 100644 --- a/src/ui/tui/views/diff.rs +++ b/src/ui/tui/views/diff.rs @@ -16,11 +16,11 @@ use crate::executor::NodeStream; use ratatui::{ + Frame, layout::{Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, - Frame, }; /// Render the diff view comparing two nodes diff --git a/src/ui/tui/views/log_panel.rs b/src/ui/tui/views/log_panel.rs index 5d0c58d8..6ad03380 100644 --- a/src/ui/tui/views/log_panel.rs +++ b/src/ui/tui/views/log_panel.rs @@ -19,11 +19,11 @@ use crate::ui::tui::log_buffer::LogBuffer; use ratatui::{ + Frame, layout::{Alignment, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Paragraph}, - Frame, }; use std::sync::{Arc, Mutex}; use tracing::Level; diff --git a/src/ui/tui/views/split.rs b/src/ui/tui/views/split.rs index 070c2ad3..c454fa7d 100644 --- a/src/ui/tui/views/split.rs +++ b/src/ui/tui/views/split.rs @@ -16,11 +16,11 @@ use crate::executor::{ExecutionStatus, MultiNodeStreamManager}; use ratatui::{ + Frame, layout::{Constraint, Direction, Layout, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, - Frame, }; /// Render the split view diff --git a/src/ui/tui/views/summary.rs b/src/ui/tui/views/summary.rs index 7b9579a8..284d1ffb 100644 --- a/src/ui/tui/views/summary.rs +++ b/src/ui/tui/views/summary.rs @@ -17,11 +17,11 @@ use crate::executor::{ExecutionStatus, MultiNodeStreamManager}; use crate::ui::tui::progress::{extract_status_message, parse_progress_from_output}; use ratatui::{ + Frame, layout::{Constraint, Layout, Rect}, style::{Color, Modifier, Style}, text::{Line, Span}, widgets::{Block, Borders, Paragraph, Wrap}, - Frame, }; /// Render the summary view diff --git a/src/utils/buffer_pool.rs b/src/utils/buffer_pool.rs index 88922eb4..5a8a1a13 100644 --- a/src/utils/buffer_pool.rs +++ b/src/utils/buffer_pool.rs @@ -96,10 +96,10 @@ impl Drop for PooledBuffer { // Clear the buffer and return it to the pool self.buffer.clear(); - if let Ok(mut pool) = self.pool.lock() { - if pool.len() < MAX_POOL_SIZE { - pool.push(std::mem::take(&mut self.buffer)); - } + if let Ok(mut pool) = self.pool.lock() + && pool.len() < MAX_POOL_SIZE + { + pool.push(std::mem::take(&mut self.buffer)); } } } diff --git a/src/utils/logging.rs b/src/utils/logging.rs index 684f1ae1..51bc42eb 100644 --- a/src/utils/logging.rs +++ b/src/utils/logging.rs @@ -16,7 +16,7 @@ use crate::ui::tui::log_buffer::LogBuffer; use crate::ui::tui::log_layer::TuiLogLayer; use once_cell::sync::OnceCell; use std::sync::{Arc, Mutex}; -use tracing_subscriber::{prelude::*, EnvFilter}; +use tracing_subscriber::{EnvFilter, prelude::*}; /// Global log buffer for TUI mode static LOG_BUFFER: OnceCell>> = OnceCell::new(); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8e2dcdfb..68657f9d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -18,7 +18,7 @@ pub mod logging; pub mod output; pub mod sanitize; -pub use buffer_pool::{global_buffer_pool, BufferPool, PooledBuffer}; +pub use buffer_pool::{BufferPool, PooledBuffer, global_buffer_pool}; pub use fs::{format_bytes, resolve_source_files, walk_directory}; pub use logging::{disable_fmt_logging, enable_fmt_logging, get_log_buffer, init_logging}; pub use output::save_outputs_to_files; diff --git a/src/utils/sanitize.rs b/src/utils/sanitize.rs index 60ffcde3..79821e63 100644 --- a/src/utils/sanitize.rs +++ b/src/utils/sanitize.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::{bail, Result}; +use anyhow::{Result, bail}; use tracing::warn; /// Sanitize and validate SSH commands to prevent injection attacks @@ -169,10 +169,11 @@ pub fn sanitize_username(username: &str) -> Result { } // Username should start with letter or underscore (Unix convention) - if let Some(first_char) = username.chars().next() { - if !first_char.is_ascii_alphabetic() && first_char != '_' { - bail!("Username must start with letter or underscore"); - } + if let Some(first_char) = username.chars().next() + && !first_char.is_ascii_alphabetic() + && first_char != '_' + { + bail!("Username must start with letter or underscore"); } Ok(username.to_string()) diff --git a/tests/interactive_integration_test.rs b/tests/interactive_integration_test.rs index 61b83746..fe6c72db 100644 --- a/tests/interactive_integration_test.rs +++ b/tests/interactive_integration_test.rs @@ -21,8 +21,8 @@ use bssh::pty::PtyConfig; use bssh::ssh::known_hosts::StrictHostKeyChecking; use bssh::ssh::tokio_client::SshConnectionConfig; use std::path::PathBuf; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::time::Duration; use tempfile::tempdir; diff --git a/tests/interactive_signal_test.rs b/tests/interactive_signal_test.rs index 3bf260c8..29bd1d32 100644 --- a/tests/interactive_signal_test.rs +++ b/tests/interactive_signal_test.rs @@ -13,10 +13,10 @@ // limitations under the License. use bssh::commands::interactive_signal::{ - is_interrupted, reset_interrupt, setup_signal_handlers, TerminalGuard, + TerminalGuard, is_interrupted, reset_interrupt, setup_signal_handlers, }; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; #[test] diff --git a/tests/jump_host_auth_mutex_test.rs b/tests/jump_host_auth_mutex_test.rs index b6d4740c..7e4b62f0 100644 --- a/tests/jump_host_auth_mutex_test.rs +++ b/tests/jump_host_auth_mutex_test.rs @@ -19,10 +19,10 @@ //! - Prevents race conditions when multiple jump hosts need credentials //! - Ensures prompts don't overlap or interfere with each other -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::sync::Mutex; -use tokio::time::{sleep, Duration}; +use tokio::time::{Duration, sleep}; /// Simulates an authentication prompt that takes some time async fn simulate_auth_prompt( diff --git a/tests/jump_host_config_test.rs b/tests/jump_host_config_test.rs index 8b0db156..ee5169f6 100644 --- a/tests/jump_host_config_test.rs +++ b/tests/jump_host_config_test.rs @@ -683,12 +683,14 @@ clusters: match &cluster.defaults.jump_host { Some(JumpHostConfig::Simple(s)) => { assert_eq!(s, "@bastion"); - assert!(cluster - .defaults - .jump_host - .as_ref() - .unwrap() - .is_ssh_config_ref()); + assert!( + cluster + .defaults + .jump_host + .as_ref() + .unwrap() + .is_ssh_config_ref() + ); assert_eq!( cluster .defaults diff --git a/tests/pdsh_compat_test.rs b/tests/pdsh_compat_test.rs index ca70d36c..8ae730a0 100644 --- a/tests/pdsh_compat_test.rs +++ b/tests/pdsh_compat_test.rs @@ -17,7 +17,7 @@ //! These tests verify that bssh correctly handles pdsh-style arguments //! and behaves as expected in pdsh compatibility mode. -use bssh::cli::{has_pdsh_compat_flag, remove_pdsh_compat_flag, PdshCli, PDSH_COMPAT_ENV_VAR}; +use bssh::cli::{PDSH_COMPAT_ENV_VAR, PdshCli, has_pdsh_compat_flag, remove_pdsh_compat_flag}; use serial_test::serial; use std::env; diff --git a/tests/pty_stress_test.rs b/tests/pty_stress_test.rs index 08308295..1c61f277 100644 --- a/tests/pty_stress_test.rs +++ b/tests/pty_stress_test.rs @@ -26,7 +26,7 @@ use bssh::pty::PtyMessage; use smallvec::SmallVec; use std::time::Duration; use tokio::sync::mpsc; -use tokio::time::{timeout, Instant}; +use tokio::time::{Instant, timeout}; // Helper to generate random data #[allow(dead_code)] diff --git a/tests/pty_utils_test.rs b/tests/pty_utils_test.rs index 94d6459a..4429c4d4 100644 --- a/tests/pty_utils_test.rs +++ b/tests/pty_utils_test.rs @@ -21,7 +21,7 @@ //! - Terminal detection utilities //! - Cross-platform compatibility -use bssh::pty::{utils::*, PtyConfig}; +use bssh::pty::{PtyConfig, utils::*}; use signal_hook::consts::SIGWINCH; use std::time::Duration; diff --git a/tests/ssh_keepalive_test.rs b/tests/ssh_keepalive_test.rs index 74764fd7..376a462f 100644 --- a/tests/ssh_keepalive_test.rs +++ b/tests/ssh_keepalive_test.rs @@ -23,7 +23,7 @@ use bssh::ssh::ssh_config::SshConfig; use bssh::ssh::tokio_client::{ - SshConnectionConfig, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX, + DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_MAX, SshConnectionConfig, }; // ============================================================================= @@ -813,8 +813,8 @@ fn test_jump_host_chain_with_ssh_connection_config() { #[test] fn test_jump_host_chain_with_custom_keepalive_for_long_running_sessions() { // Test real-world use case: long-running sessions need longer keepalive - use bssh::jump::parser::JumpHost; use bssh::jump::JumpHostChain; + use bssh::jump::parser::JumpHost; use std::time::Duration; // For long-running interactive sessions, use longer keepalive intervals diff --git a/tests/tui_snapshot_tests.rs b/tests/tui_snapshot_tests.rs index 90714ea1..ea4aee06 100644 --- a/tests/tui_snapshot_tests.rs +++ b/tests/tui_snapshot_tests.rs @@ -21,7 +21,7 @@ use bssh::executor::{MultiNodeStreamManager, NodeStream}; use bssh::node::Node; use bssh::ssh::tokio_client::CommandOutput; use bssh::ui::tui::app::{TuiApp, ViewMode}; -use ratatui::{backend::TestBackend, buffer::Buffer, Terminal}; +use ratatui::{Terminal, backend::TestBackend, buffer::Buffer}; use tokio::sync::mpsc; /// Helper to convert buffer to a displayable string for snapshot comparison