Skip to content

Commit b2c90ca

Browse files
committed
[deployer] Bound SSH child execution
1 parent de6ecda commit b2c90ca

3 files changed

Lines changed: 93 additions & 25 deletions

File tree

deployer/src/aws/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,12 @@ cfg_if::cfg_if! {
481481
Reqwest(#[from] reqwest::Error),
482482
#[error("SSH failed")]
483483
SshFailed,
484+
#[error("command timeout({ip}): {program} after {seconds}s")]
485+
CommandTimeout {
486+
program: String,
487+
ip: String,
488+
seconds: u64,
489+
},
484490
#[error("keygen failed")]
485491
KeygenFailed,
486492
#[error("service timeout({0}): {1}")]

deployer/src/aws/profile.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ use crate::aws::{
55
ec2::{self, *},
66
s3::{self, *},
77
services::*,
8-
utils::{download_file, scp_download, ssh_execute},
8+
utils::{download_file, scp_download, ssh_execute_with_timeout},
99
Config, Error, CREATED_FILE_NAME, DESTROYED_FILE_NAME, MONITORING_REGION,
1010
};
1111
use aws_sdk_ec2::types::Filter;
1212
use std::{
1313
fs::File,
1414
path::{Path, PathBuf},
15-
time::{SystemTime, UNIX_EPOCH},
15+
time::{Duration, SystemTime, UNIX_EPOCH},
1616
};
1717
use tokio::process::Command;
1818
use tracing::info;
@@ -172,7 +172,13 @@ echo "Profile captured successfully"
172172
duration = duration,
173173
"starting profile capture"
174174
);
175-
ssh_execute(private_key, &instance_ip, &profile_script).await?;
175+
ssh_execute_with_timeout(
176+
private_key,
177+
&instance_ip,
178+
&profile_script,
179+
Duration::from_secs(duration).saturating_add(Duration::from_secs(5 * 60)),
180+
)
181+
.await?;
176182
info!("profile capture complete");
177183

178184
// Download the profile locally via scp

deployer/src/aws/utils.rs

Lines changed: 78 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
//! Utility functions for interacting with EC2 instances
22
33
use crate::aws::Error;
4-
use std::path::Path;
4+
use std::{path::Path, process::Output};
55
use tokio::{
66
fs::File,
77
io::AsyncWriteExt,
88
process::Command,
9-
time::{sleep, Duration},
9+
time::{sleep, timeout, Duration},
1010
};
1111
use tracing::{info, warn};
1212

@@ -19,6 +19,15 @@ pub const MAX_POLL_ATTEMPTS: usize = 30;
1919
/// Interval between retries
2020
pub const RETRY_INTERVAL: Duration = Duration::from_secs(15);
2121

22+
/// Maximum time to wait for a non-polling SSH command to complete
///
/// Set to 30 minutes. `ssh_execute` passes this value as the per-attempt
/// timeout of `ssh_execute_with_timeout`.
23+
pub const SSH_COMMAND_TIMEOUT: Duration = Duration::from_secs(30 * 60);
24+
25+
/// Maximum time to wait for a service status poll to complete
///
/// Set to 30 seconds. Applied to each individual `systemctl is-active` SSH
/// invocation made by the service polling loops.
26+
pub const SSH_POLL_TIMEOUT: Duration = Duration::from_secs(30);
27+
28+
/// Maximum time to wait for an SCP download to complete
///
/// Set to 30 minutes. Applied to each individual `scp` attempt made by
/// `scp_download`.
29+
pub const SCP_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(30 * 60);
30+
2231
/// Protocol for deployer ingress
2332
pub const DEPLOYER_PROTOCOL: &str = "tcp";
2433

@@ -42,9 +51,19 @@ pub async fn get_public_ip() -> Result<String, Error> {
4251

4352
/// Executes a command on a remote instance via SSH with retries
4453
pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(), Error> {
54+
ssh_execute_with_timeout(key_file, ip, command, SSH_COMMAND_TIMEOUT).await
55+
}
56+
57+
/// Executes a command on a remote instance via SSH with retries and a per-attempt timeout
58+
pub async fn ssh_execute_with_timeout(
59+
key_file: &str,
60+
ip: &str,
61+
command: &str,
62+
command_timeout: Duration,
63+
) -> Result<(), Error> {
4564
for _ in 0..MAX_SSH_ATTEMPTS {
46-
let output = Command::new("ssh")
47-
.arg("-i")
65+
let mut cmd = Command::new("ssh");
66+
cmd.arg("-i")
4867
.arg(key_file)
4968
.arg("-o")
5069
.arg("IdentitiesOnly=yes")
@@ -53,9 +72,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
5372
.arg("-o")
5473
.arg("StrictHostKeyChecking=no")
5574
.arg(format!("ubuntu@{ip}"))
56-
.arg(command)
57-
.output()
58-
.await?;
75+
.arg(command);
76+
let output = command_output(cmd, "ssh", ip, command_timeout).await?;
5977
if output.status.success() {
6078
return Ok(());
6179
}
@@ -68,8 +86,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
6886
/// Polls the status of a systemd service on a remote instance until active
6987
pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
7088
for _ in 0..MAX_POLL_ATTEMPTS {
71-
let output = Command::new("ssh")
72-
.arg("-i")
89+
let mut cmd = Command::new("ssh");
90+
cmd.arg("-i")
7391
.arg(key_file)
7492
.arg("-o")
7593
.arg("IdentitiesOnly=yes")
@@ -78,9 +96,16 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
7896
.arg("-o")
7997
.arg("StrictHostKeyChecking=no")
8098
.arg(format!("ubuntu@{ip}"))
81-
.arg(format!("systemctl is-active {service}"))
82-
.output()
83-
.await?;
99+
.arg(format!("systemctl is-active {service}"));
100+
let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await {
101+
Ok(output) => output,
102+
Err(err @ Error::CommandTimeout { .. }) => {
103+
warn!(service, error = ?err, "service status poll timed out");
104+
sleep(RETRY_INTERVAL).await;
105+
continue;
106+
}
107+
Err(err) => return Err(err),
108+
};
84109
let parsed = String::from_utf8_lossy(&output.stdout);
85110
let parsed = parsed.trim();
86111
if parsed == "active" {
@@ -99,8 +124,8 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
99124
/// Polls the status of a systemd service on a remote instance until it becomes inactive
100125
pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> Result<(), Error> {
101126
for _ in 0..MAX_POLL_ATTEMPTS {
102-
let output = Command::new("ssh")
103-
.arg("-i")
127+
let mut cmd = Command::new("ssh");
128+
cmd.arg("-i")
104129
.arg(key_file)
105130
.arg("-o")
106131
.arg("IdentitiesOnly=yes")
@@ -109,9 +134,16 @@ pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> R
109134
.arg("-o")
110135
.arg("StrictHostKeyChecking=no")
111136
.arg(format!("ubuntu@{ip}"))
112-
.arg(format!("systemctl is-active {service}"))
113-
.output()
114-
.await?;
137+
.arg(format!("systemctl is-active {service}"));
138+
let output = match command_output(cmd, "ssh", ip, SSH_POLL_TIMEOUT).await {
139+
Ok(output) => output,
140+
Err(err @ Error::CommandTimeout { .. }) => {
141+
warn!(service, error = ?err, "service status poll timed out");
142+
sleep(RETRY_INTERVAL).await;
143+
continue;
144+
}
145+
Err(err) => return Err(err),
146+
};
115147
let parsed = String::from_utf8_lossy(&output.stdout);
116148
let parsed = parsed.trim();
117149
if parsed == "inactive" {
@@ -135,8 +167,8 @@ pub async fn scp_download(
135167
local_path: &str,
136168
) -> Result<(), Error> {
137169
for _ in 0..MAX_SSH_ATTEMPTS {
138-
let output = Command::new("scp")
139-
.arg("-i")
170+
let mut cmd = Command::new("scp");
171+
cmd.arg("-i")
140172
.arg(key_file)
141173
.arg("-o")
142174
.arg("IdentitiesOnly=yes")
@@ -145,9 +177,16 @@ pub async fn scp_download(
145177
.arg("-o")
146178
.arg("StrictHostKeyChecking=no")
147179
.arg(format!("ubuntu@{ip}:{remote_path}"))
148-
.arg(local_path)
149-
.output()
150-
.await?;
180+
.arg(local_path);
181+
let output = match command_output(cmd, "scp", ip, SCP_DOWNLOAD_TIMEOUT).await {
182+
Ok(output) => output,
183+
Err(err @ Error::CommandTimeout { .. }) => {
184+
warn!(ip, error = ?err, "SCP timed out");
185+
sleep(RETRY_INTERVAL).await;
186+
continue;
187+
}
188+
Err(err) => return Err(err),
189+
};
151190
if output.status.success() {
152191
return Ok(());
153192
}
@@ -157,6 +196,23 @@ pub async fn scp_download(
157196
Err(Error::SshFailed)
158197
}
159198

199+
async fn command_output(
200+
mut command: Command,
201+
program: &str,
202+
ip: &str,
203+
command_timeout: Duration,
204+
) -> Result<Output, Error> {
205+
command.kill_on_drop(true);
206+
match timeout(command_timeout, command.output()).await {
207+
Ok(output) => Ok(output?),
208+
Err(_) => Err(Error::CommandTimeout {
209+
program: program.to_string(),
210+
ip: ip.to_string(),
211+
seconds: command_timeout.as_secs(),
212+
}),
213+
}
214+
}
215+
160216
/// Converts an IP address to a CIDR block
161217
pub fn exact_cidr(ip: &str) -> String {
162218
format!("{ip}/32")

0 commit comments

Comments
 (0)