diff --git a/deployer/src/aws/mod.rs b/deployer/src/aws/mod.rs index 5ceb72cc717..ba242565c17 100644 --- a/deployer/src/aws/mod.rs +++ b/deployer/src/aws/mod.rs @@ -481,6 +481,12 @@ cfg_if::cfg_if! { Reqwest(#[from] reqwest::Error), #[error("SSH failed")] SshFailed, + #[error("command timeout({ip}): {program} after {seconds}s")] + CommandTimeout { + program: String, + ip: String, + seconds: u64, + }, #[error("keygen failed")] KeygenFailed, #[error("service timeout({0}): {1}")] diff --git a/deployer/src/aws/profile.rs b/deployer/src/aws/profile.rs index c07234e6ef3..b34006a4b2c 100644 --- a/deployer/src/aws/profile.rs +++ b/deployer/src/aws/profile.rs @@ -5,14 +5,14 @@ use crate::aws::{ ec2::{self, *}, s3::{self, *}, services::*, - utils::{download_file, scp_download, ssh_execute}, + utils::{download_file, scp_download, ssh_execute_with_timeout}, Config, Error, CREATED_FILE_NAME, DESTROYED_FILE_NAME, MONITORING_REGION, }; use aws_sdk_ec2::types::Filter; use std::{ fs::File, path::{Path, PathBuf}, - time::{SystemTime, UNIX_EPOCH}, + time::{Duration, SystemTime, UNIX_EPOCH}, }; use tokio::process::Command; use tracing::info; @@ -172,7 +172,13 @@ echo "Profile captured successfully" duration = duration, "starting profile capture" ); - ssh_execute(private_key, &instance_ip, &profile_script).await?; + ssh_execute_with_timeout( + private_key, + &instance_ip, + &profile_script, + Duration::from_secs(duration).saturating_add(Duration::from_secs(5 * 60)), + ) + .await?; info!("profile capture complete"); // Download the profile locally via scp diff --git a/deployer/src/aws/utils.rs b/deployer/src/aws/utils.rs index f4cdb560277..72dd0cf12f9 100644 --- a/deployer/src/aws/utils.rs +++ b/deployer/src/aws/utils.rs @@ -1,12 +1,12 @@ //! Utility functions for interacting with EC2 instances use crate::aws::Error; -use std::path::Path; +use std::{path::Path, process::Output}; use tokio::{ fs::File, io::AsyncWriteExt, process::Command, - time::{sleep, Duration}, + time::{sleep, timeout, Duration}, }; use tracing::{info, warn}; @@ -19,6 +19,15 @@ pub const MAX_POLL_ATTEMPTS: usize = 30; /// Interval between retries pub const RETRY_INTERVAL: Duration = Duration::from_secs(15); +/// Maximum time to wait for a non-polling SSH command to complete +pub const SSH_COMMAND_TIMEOUT: Duration = Duration::from_secs(30 * 60); + +/// Maximum time to wait for a service status poll to complete +pub const SSH_POLL_TIMEOUT: Duration = Duration::from_secs(30); + +/// Maximum time to wait for an SCP download to complete +pub const SCP_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(30 * 60); + /// Protocol for deployer ingress pub const DEPLOYER_PROTOCOL: &str = "tcp"; @@ -42,9 +51,19 @@ pub async fn get_public_ip() -> Result { /// Executes a command on a remote instance via SSH with retries pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), Error> { + ssh_execute_with_timeout(key_file, ip, command, SSH_COMMAND_TIMEOUT).await +} + +/// Executes a command on a remote instance via SSH with retries and a per-attempt timeout +pub async fn ssh_execute_with_timeout( + key_file: &str, + ip: &str, + command: &str, + command_timeout: Duration, +) -> Result<(), Error> { for _ in 0..MAX_SSH_ATTEMPTS { - let output = Command::new("ssh") - .arg("-i") + let mut cmd = Command::new("ssh"); + cmd.arg("-i") .arg(key_file) .arg("-o") .arg("IdentitiesOnly=yes") @@ -53,9 +72,16 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), .arg("-o") .arg("StrictHostKeyChecking=no") .arg(format!("ubuntu@{ip}")) - .arg(command) - .output() - .await?; + .arg(command); + let output = match command_output(cmd, "ssh", ip, command_timeout).await { + Ok(output) => output, + Err(err @ Error::CommandTimeout { .. }) => { + warn!(ip, error = ?err, "SSH command timed out"); + sleep(RETRY_INTERVAL).await; + continue; + } + Err(err) => return Err(err), + }; if output.status.success() { return Ok(()); } @@ -68,8 +94,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), /// Polls the status of a systemd service on a remote instance until active pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Result<(), Error> { for _ in 0..MAX_POLL_ATTEMPTS { - let output = Command::new("ssh") - .arg("-i") + let mut cmd = Command::new("ssh"); + cmd.arg("-i") .arg(key_file) .arg("-o") .arg("IdentitiesOnly=yes") @@ -78,9 +104,16 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res .arg("-o") .arg("StrictHostKeyChecking=no") .arg(format!("ubuntu@{ip}")) - .arg(format!("systemctl is-active {service}")) - .output() - .await?; + .arg(format!("systemctl is-active {service}")); + let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await { + Ok(output) => output, + Err(err @ Error::CommandTimeout { .. }) => { + warn!(service, error = ?err, "service status poll timed out"); + sleep(RETRY_INTERVAL).await; + continue; + } + Err(err) => return Err(err), + }; let parsed = String::from_utf8_lossy(&output.stdout); let parsed = parsed.trim(); if parsed == "active" { @@ -99,8 +132,8 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res /// Polls the status of a systemd service on a remote instance until it becomes inactive pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> Result<(), Error> { for _ in 0..MAX_POLL_ATTEMPTS { - let output = Command::new("ssh") - .arg("-i") + let mut cmd = Command::new("ssh"); + cmd.arg("-i") .arg(key_file) .arg("-o") .arg("IdentitiesOnly=yes") @@ -109,9 +142,16 @@ pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> R .arg("-o") .arg("StrictHostKeyChecking=no") .arg(format!("ubuntu@{ip}")) - .arg(format!("systemctl is-active {service}")) - .output() - .await?; + .arg(format!("systemctl is-active {service}")); + let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await { + Ok(output) => output, + Err(err @ Error::CommandTimeout { .. }) => { + warn!(service, error = ?err, "service status poll timed out"); + sleep(RETRY_INTERVAL).await; + continue; + } + Err(err) => return Err(err), + }; let parsed = String::from_utf8_lossy(&output.stdout); let parsed = parsed.trim(); if parsed == "inactive" { @@ -135,8 +175,8 @@ pub async fn scp_download( local_path: &str, ) -> Result<(), Error> { for _ in 0..MAX_SSH_ATTEMPTS { - let output = Command::new("scp") - .arg("-i") + let mut cmd = Command::new("scp"); + cmd.arg("-i") .arg(key_file) .arg("-o") .arg("IdentitiesOnly=yes") @@ -145,9 +185,16 @@ pub async fn scp_download( .arg("-o") .arg("StrictHostKeyChecking=no") .arg(format!("ubuntu@{ip}:{remote_path}")) - .arg(local_path) - .output() - .await?; + .arg(local_path); + let output = match command_output(cmd, "scp", ip, SCP_DOWNLOAD_TIMEOUT).await { + Ok(output) => output, + Err(err @ Error::CommandTimeout { .. }) => { + warn!(ip, error = ?err, "SCP timed out"); + sleep(RETRY_INTERVAL).await; + continue; + } + Err(err) => return Err(err), + }; if output.status.success() { return Ok(()); } @@ -157,6 +204,23 @@ pub async fn scp_download( Err(Error::SshFailed) } +async fn command_output( + mut command: Command, + program: &str, + ip: &str, + command_timeout: Duration, +) -> Result { + command.kill_on_drop(true); + match timeout(command_timeout, command.output()).await { + Ok(output) => Ok(output?), + Err(_) => Err(Error::CommandTimeout { + program: program.to_string(), + ip: ip.to_string(), + seconds: command_timeout.as_secs(), + }), + } +} + /// Converts an IP address to a CIDR block pub fn exact_cidr(ip: &str) -> String { format!("{ip}/32")