11//! Utility functions for interacting with EC2 instances
22
33use crate :: aws:: Error ;
4- use std:: path:: Path ;
4+ use std:: { path:: Path , process :: Output } ;
55use tokio:: {
66 fs:: File ,
77 io:: AsyncWriteExt ,
88 process:: Command ,
9- time:: { sleep, Duration } ,
9+ time:: { sleep, timeout , Duration } ,
1010} ;
1111use tracing:: { info, warn} ;
1212
@@ -19,6 +19,15 @@ pub const MAX_POLL_ATTEMPTS: usize = 30;
/// Interval between retries
///
/// 15 seconds. Slept, for example, after a timed-out status poll or SCP
/// attempt before the next attempt is made.
pub const RETRY_INTERVAL: Duration = Duration::from_secs(15);

/// Maximum time to wait for a non-polling SSH command to complete
///
/// 30 minutes. `ssh_execute` passes this as the timeout applied to each
/// SSH attempt.
pub const SSH_COMMAND_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Maximum time to wait for a service status poll to complete
///
/// 30 seconds. Applied to each `systemctl is-active` poll; a poll that
/// exceeds it is logged and retried after `RETRY_INTERVAL`.
pub const SSH_POLL_TIMEOUT: Duration = Duration::from_secs(30);

/// Maximum time to wait for an SCP download to complete
///
/// 30 minutes. Applied to each SCP attempt; an attempt that exceeds it is
/// logged and retried after `RETRY_INTERVAL`.
pub const SCP_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(30 * 60);

/// Protocol for deployer ingress
pub const DEPLOYER_PROTOCOL: &str = "tcp";

2433
@@ -42,9 +51,19 @@ pub async fn get_public_ip() -> Result<String, Error> {
4251
4352/// Executes a command on a remote instance via SSH with retries
4453pub async fn ssh_execute ( key_file : & str , ip : & str , command : & str ) -> Result < ( ) , Error > {
54+ ssh_execute_with_timeout ( key_file, ip, command, SSH_COMMAND_TIMEOUT ) . await
55+ }
56+
57+ /// Executes a command on a remote instance via SSH with retries and a per-attempt timeout
58+ pub async fn ssh_execute_with_timeout (
59+ key_file : & str ,
60+ ip : & str ,
61+ command : & str ,
62+ command_timeout : Duration ,
63+ ) -> Result < ( ) , Error > {
4564 for _ in 0 ..MAX_SSH_ATTEMPTS {
46- let output = Command :: new ( "ssh" )
47- . arg ( "-i" )
65+ let mut cmd = Command :: new ( "ssh" ) ;
66+ cmd . arg ( "-i" )
4867 . arg ( key_file)
4968 . arg ( "-o" )
5069 . arg ( "IdentitiesOnly=yes" )
@@ -53,9 +72,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
5372 . arg ( "-o" )
5473 . arg ( "StrictHostKeyChecking=no" )
5574 . arg ( format ! ( "ubuntu@{ip}" ) )
56- . arg ( command)
57- . output ( )
58- . await ?;
75+ . arg ( command) ;
76+ let output = command_output ( cmd, "ssh" , ip, command_timeout) . await ?;
5977 if output. status . success ( ) {
6078 return Ok ( ( ) ) ;
6179 }
@@ -68,8 +86,8 @@ pub async fn ssh_execute(key_file: &str, ip: &str, command: &str) -> Result<(),
6886/// Polls the status of a systemd service on a remote instance until active
6987pub async fn poll_service_active ( key_file : & str , ip : & str , service : & str ) -> Result < ( ) , Error > {
7088 for _ in 0 ..MAX_POLL_ATTEMPTS {
71- let output = Command :: new ( "ssh" )
72- . arg ( "-i" )
89+ let mut cmd = Command :: new ( "ssh" ) ;
90+ cmd . arg ( "-i" )
7391 . arg ( key_file)
7492 . arg ( "-o" )
7593 . arg ( "IdentitiesOnly=yes" )
@@ -78,9 +96,16 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
7896 . arg ( "-o" )
7997 . arg ( "StrictHostKeyChecking=no" )
8098 . arg ( format ! ( "ubuntu@{ip}" ) )
81- . arg ( format ! ( "systemctl is-active {service}" ) )
82- . output ( )
83- . await ?;
99+ . arg ( format ! ( "systemctl is-active {service}" ) ) ;
100+ let output = match command_output ( cmd, "ssh" , ip, SSH_POLL_TIMEOUT ) . await {
101+ Ok ( output) => output,
102+ Err ( err @ Error :: CommandTimeout { .. } ) => {
103+ warn ! ( service, error = ?err, "service status poll timed out" ) ;
104+ sleep ( RETRY_INTERVAL ) . await ;
105+ continue ;
106+ }
107+ Err ( err) => return Err ( err) ,
108+ } ;
84109 let parsed = String :: from_utf8_lossy ( & output. stdout ) ;
85110 let parsed = parsed. trim ( ) ;
86111 if parsed == "active" {
@@ -99,8 +124,8 @@ pub async fn poll_service_active(key_file: &str, ip: &str, service: &str) -> Res
99124/// Polls the status of a systemd service on a remote instance until it becomes inactive
100125pub async fn poll_service_inactive ( key_file : & str , ip : & str , service : & str ) -> Result < ( ) , Error > {
101126 for _ in 0 ..MAX_POLL_ATTEMPTS {
102- let output = Command :: new ( "ssh" )
103- . arg ( "-i" )
127+ let mut cmd = Command :: new ( "ssh" ) ;
128+ cmd . arg ( "-i" )
104129 . arg ( key_file)
105130 . arg ( "-o" )
106131 . arg ( "IdentitiesOnly=yes" )
@@ -109,9 +134,16 @@ pub async fn poll_service_inactive(key_file: &str, ip: &str, service: &str) -> R
109134 . arg ( "-o" )
110135 . arg ( "StrictHostKeyChecking=no" )
111136 . arg ( format ! ( "ubuntu@{ip}" ) )
112- . arg ( format ! ( "systemctl is-active {service}" ) )
113- . output ( )
114- . await ?;
137+ . arg ( format ! ( "systemctl is-active {service}" ) ) ;
138+ let output = match command_output ( cmd, "ssh" , ip, SSH_POLL_TIMEOUT ) . await {
139+ Ok ( output) => output,
140+ Err ( err @ Error :: CommandTimeout { .. } ) => {
141+ warn ! ( service, error = ?err, "service status poll timed out" ) ;
142+ sleep ( RETRY_INTERVAL ) . await ;
143+ continue ;
144+ }
145+ Err ( err) => return Err ( err) ,
146+ } ;
115147 let parsed = String :: from_utf8_lossy ( & output. stdout ) ;
116148 let parsed = parsed. trim ( ) ;
117149 if parsed == "inactive" {
@@ -135,8 +167,8 @@ pub async fn scp_download(
135167 local_path : & str ,
136168) -> Result < ( ) , Error > {
137169 for _ in 0 ..MAX_SSH_ATTEMPTS {
138- let output = Command :: new ( "scp" )
139- . arg ( "-i" )
170+ let mut cmd = Command :: new ( "scp" ) ;
171+ cmd . arg ( "-i" )
140172 . arg ( key_file)
141173 . arg ( "-o" )
142174 . arg ( "IdentitiesOnly=yes" )
@@ -145,9 +177,16 @@ pub async fn scp_download(
145177 . arg ( "-o" )
146178 . arg ( "StrictHostKeyChecking=no" )
147179 . arg ( format ! ( "ubuntu@{ip}:{remote_path}" ) )
148- . arg ( local_path)
149- . output ( )
150- . await ?;
180+ . arg ( local_path) ;
181+ let output = match command_output ( cmd, "scp" , ip, SCP_DOWNLOAD_TIMEOUT ) . await {
182+ Ok ( output) => output,
183+ Err ( err @ Error :: CommandTimeout { .. } ) => {
184+ warn ! ( ip, error = ?err, "SCP timed out" ) ;
185+ sleep ( RETRY_INTERVAL ) . await ;
186+ continue ;
187+ }
188+ Err ( err) => return Err ( err) ,
189+ } ;
151190 if output. status . success ( ) {
152191 return Ok ( ( ) ) ;
153192 }
@@ -157,6 +196,23 @@ pub async fn scp_download(
157196 Err ( Error :: SshFailed )
158197}
159198
199+ async fn command_output (
200+ mut command : Command ,
201+ program : & str ,
202+ ip : & str ,
203+ command_timeout : Duration ,
204+ ) -> Result < Output , Error > {
205+ command. kill_on_drop ( true ) ;
206+ match timeout ( command_timeout, command. output ( ) ) . await {
207+ Ok ( output) => Ok ( output?) ,
208+ Err ( _) => Err ( Error :: CommandTimeout {
209+ program : program. to_string ( ) ,
210+ ip : ip. to_string ( ) ,
211+ seconds : command_timeout. as_secs ( ) ,
212+ } ) ,
213+ }
214+ }
215+
160216/// Converts an IP address to a CIDR block
161217pub fn exact_cidr ( ip : & str ) -> String {
162218 format ! ( "{ip}/32" )
0 commit comments