Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deployer/src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,12 @@ cfg_if::cfg_if! {
Reqwest(#[from] reqwest::Error),
#[error("SSH failed")]
SshFailed,
#[error("command timeout({ip}): {program} after {seconds}s")]
CommandTimeout {
program: String,
ip: String,
seconds: u64,
},
#[error("keygen failed")]
KeygenFailed,
#[error("service timeout({0}): {1}")]
Expand Down
12 changes: 9 additions & 3 deletions deployer/src/aws/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use crate::aws::{
ec2::{self, *},
s3::{self, *},
services::*,
utils::{download_file, scp_download, ssh_execute},
utils::{download_file, scp_download, ssh_execute_with_timeout},
Config, Error, CREATED_FILE_NAME, DESTROYED_FILE_NAME, MONITORING_REGION,
};
use aws_sdk_ec2::types::Filter;
use std::{
fs::File,
path::{Path, PathBuf},
time::{SystemTime, UNIX_EPOCH},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tokio::process::Command;
use tracing::info;
Expand Down Expand Up @@ -172,7 +172,13 @@ echo "Profile captured successfully"
duration = duration,
"starting profile capture"
);
ssh_execute(private_key, &instance_ip, &profile_script).await?;
ssh_execute_with_timeout(
private_key,
&instance_ip,
&profile_script,
Duration::from_secs(duration).saturating_add(Duration::from_secs(5 * 60)),
)
.await?;
info!("profile capture complete");

// Download the profile locally via scp
Expand Down
108 changes: 86 additions & 22 deletions deployer/src/aws/utils.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Utility functions for interacting with EC2 instances

use crate::aws::Error;
use std::path::Path;
use std::{path::Path, process::Output};
use tokio::{
fs::File,
io::AsyncWriteExt,
process::Command,
time::{sleep, Duration},
time::{sleep, timeout, Duration},
};
use tracing::{info, warn};

Expand All @@ -19,6 +19,15 @@ pub const MAX_POLL_ATTEMPTS: usize = 30;
/// Interval between retries
pub const RETRY_INTERVAL: Duration = Duration::from_secs(15);

/// Maximum time to wait for a non-polling SSH command to complete
pub const SSH_COMMAND_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Maximum time to wait for a service status poll to complete
pub const SSH_POLL_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum time to wait for an SCP download to complete
pub const SCP_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Protocol for deployer ingress
pub const DEPLOYER_PROTOCOL: &str = "tcp";

Expand All @@ -42,9 +51,19 @@ pub async fn get_public_ip() -> Result<String, Error> {

/// Executes a command on a remote instance via SSH with retries
pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), Error> {
ssh_execute_with_timeout(key_file, ip, command, SSH_COMMAND_TIMEOUT).await
}

/// Executes a command on a remote instance via SSH with retries and a per-attempt timeout
pub async fn ssh_execute_with_timeout(
key_file: &str,
ip: &str,
command: &str,
command_timeout: Duration,
) -> Result<(), Error> {
for _ in 0..MAX_SSH_ATTEMPTS {
let output = Command::new("ssh")
.arg("-i")
let mut cmd = Command::new("ssh");
cmd.arg("-i")
.arg(key_file)
.arg("-o")
.arg("IdentitiesOnly=yes")
Expand All @@ -53,9 +72,16 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
.arg("-o")
.arg("StrictHostKeyChecking=no")
.arg(format!("ubuntu@{ip}"))
.arg(command)
.output()
.await?;
.arg(command);
let output = match command_output(cmd, "ssh", ip, command_timeout).await {
Ok(output) => output,
Err(err @ Error::CommandTimeout { .. }) => {
warn!(ip, error = ?err, "SSH command timed out");
sleep(RETRY_INTERVAL).await;
continue;
}
Err(err) => return Err(err),
};
if output.status.success() {
return Ok(());
}
Expand All @@ -68,8 +94,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
/// Polls the status of a systemd service on a remote instance until active
pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
for _ in 0..MAX_POLL_ATTEMPTS {
let output = Command::new("ssh")
.arg("-i")
let mut cmd = Command::new("ssh");
cmd.arg("-i")
.arg(key_file)
.arg("-o")
.arg("IdentitiesOnly=yes")
Expand All @@ -78,9 +104,16 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
.arg("-o")
.arg("StrictHostKeyChecking=no")
.arg(format!("ubuntu@{ip}"))
.arg(format!("systemctl is-active {service}"))
.output()
.await?;
.arg(format!("systemctl is-active {service}"));
let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await {
Ok(output) => output,
Err(err @ Error::CommandTimeout { .. }) => {
warn!(service, error = ?err, "service status poll timed out");
sleep(RETRY_INTERVAL).await;
continue;
}
Err(err) => return Err(err),
};
let parsed = String::from_utf8_lossy(&output.stdout);
let parsed = parsed.trim();
if parsed == "active" {
Expand All @@ -99,8 +132,8 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
/// Polls the status of a systemd service on a remote instance until it becomes inactive
pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
for _ in 0..MAX_POLL_ATTEMPTS {
let output = Command::new("ssh")
.arg("-i")
let mut cmd = Command::new("ssh");
cmd.arg("-i")
.arg(key_file)
.arg("-o")
.arg("IdentitiesOnly=yes")
Expand All @@ -109,9 +142,16 @@ pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> R
.arg("-o")
.arg("StrictHostKeyChecking=no")
.arg(format!("ubuntu@{ip}"))
.arg(format!("systemctl is-active {service}"))
.output()
.await?;
.arg(format!("systemctl is-active {service}"));
let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await {
Ok(output) => output,
Err(err @ Error::CommandTimeout { .. }) => {
warn!(service, error = ?err, "service status poll timed out");
sleep(RETRY_INTERVAL).await;
continue;
}
Err(err) => return Err(err),
};
let parsed = String::from_utf8_lossy(&output.stdout);
let parsed = parsed.trim();
if parsed == "inactive" {
Expand All @@ -135,8 +175,8 @@ pub async fn scp_download(
local_path: &str,
) -> Result<(), Error> {
for _ in 0..MAX_SSH_ATTEMPTS {
let output = Command::new("scp")
.arg("-i")
let mut cmd = Command::new("scp");
cmd.arg("-i")
.arg(key_file)
.arg("-o")
.arg("IdentitiesOnly=yes")
Expand All @@ -145,9 +185,16 @@ pub async fn scp_download(
.arg("-o")
.arg("StrictHostKeyChecking=no")
.arg(format!("ubuntu@{ip}:{remote_path}"))
.arg(local_path)
.output()
.await?;
.arg(local_path);
let output = match command_output(cmd, "scp", ip, SCP_DOWNLOAD_TIMEOUT).await {
Ok(output) => output,
Err(err @ Error::CommandTimeout { .. }) => {
warn!(ip, error = ?err, "SCP timed out");
sleep(RETRY_INTERVAL).await;
continue;
}
Err(err) => return Err(err),
};
if output.status.success() {
return Ok(());
}
Expand All @@ -157,6 +204,23 @@ pub async fn scp_download(
Err(Error::SshFailed)
}

async fn command_output(
mut command: Command,
program: &str,
ip: &str,
command_timeout: Duration,
) -> Result<Output, Error> {
command.kill_on_drop(true);
match timeout(command_timeout, command.output()).await {
Ok(output) => Ok(output?),
Err(_) => Err(Error::CommandTimeout {
program: program.to_string(),
ip: ip.to_string(),
seconds: command_timeout.as_secs(),
}),
}
}

/// Converts an IP address to a CIDR block
pub fn exact_cidr(ip: &str) -> String {
format!("{ip}/32")
Expand Down
Loading