diff --git a/core/services/sftp/src/backend.rs b/core/services/sftp/src/backend.rs index 272c02ba6a8f..ed5879ffc2b4 100644 --- a/core/services/sftp/src/backend.rs +++ b/core/services/sftp/src/backend.rs @@ -19,6 +19,7 @@ use std::io::SeekFrom; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use log::debug; use openssh::KnownHosts; @@ -26,6 +27,7 @@ use tokio::io::AsyncSeekExt; use super::SFTP_SCHEME; use super::config::SftpConfig; +use super::core::SftpConnectionOptions; use super::core::SftpCore; use super::deleter::SftpDeleter; use super::error::is_not_found; @@ -55,6 +57,28 @@ pub struct SftpBuilder { } impl SftpBuilder { + /// set acquire timeout for pooled sftp connections. + pub fn acquire_timeout(mut self, timeout: Duration) -> Self { + self.config.acquire_timeout = if timeout.is_zero() { + None + } else { + Some(format!("{}s", timeout.as_secs())) + }; + + self + } + + /// set connect timeout for sftp backend. + pub fn connect_timeout(mut self, timeout: Duration) -> Self { + self.config.connect_timeout = if timeout.is_zero() { + None + } else { + Some(format!("{}s", timeout.as_secs())) + }; + + self + } + /// set endpoint for sftp backend. /// The format is same as `openssh`, using either `[user@]hostname` or `ssh://[user@]hostname[:port]`. A username or port that is specified in the endpoint overrides the one set in the builder (but does not change the builder). pub fn endpoint(mut self, endpoint: &str) -> Self { @@ -130,6 +154,9 @@ impl Builder for SftpBuilder { fn build(self) -> Result { debug!("sftp backend build started: {:?}", &self); + const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(10); + const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); + let endpoint = match self.config.endpoint.clone() { Some(v) => v, None => return Err(Error::new(ErrorKind::ConfigInvalid, "endpoint is empty")), @@ -163,6 +190,15 @@ impl Builder for SftpBuilder { None => KnownHosts::Strict, }; + let acquire_timeout = match self.config.acquire_timeout.as_deref() { + Some(value) => signed_to_duration(value)?, + None => DEFAULT_ACQUIRE_TIMEOUT, + }; + let connect_timeout = match self.config.connect_timeout.as_deref() { + Some(value) => signed_to_duration(value)?, + None => DEFAULT_CONNECT_TIMEOUT, + }; + let info = AccessorInfo::default(); info.set_root(root.as_str()) .set_scheme(SFTP_SCHEME) @@ -193,9 +229,13 @@ impl Builder for SftpBuilder { Arc::new(info), endpoint, root, - user, - self.config.key.clone(), - known_hosts_strategy, + SftpConnectionOptions { + user, + key: self.config.key.clone(), + known_hosts_strategy, + acquire_timeout, + connect_timeout, + }, )); debug!("sftp backend finished: {:?}", &self); @@ -203,6 +243,26 @@ impl Builder for SftpBuilder { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_rejects_invalid_timeout() { + let builder = SftpBuilder { + config: SftpConfig { + endpoint: Some("host".to_string()), + connect_timeout: Some("invalid".to_string()), + ..Default::default() + }, + }; + + let err = builder.build().unwrap_err(); + assert_eq!(err.kind(), ErrorKind::ConfigInvalid); + assert!(err.to_string().contains("failed to parse duration")); + } +} + /// Backend is used to serve `Accessor` support for sftp. #[derive(Clone, Debug)] pub struct SftpBackend { diff --git a/core/services/sftp/src/config.rs b/core/services/sftp/src/config.rs index 48445770a527..9112f33313c0 100644 --- a/core/services/sftp/src/config.rs +++ b/core/services/sftp/src/config.rs @@ -37,6 +37,10 @@ pub struct SftpConfig { pub key: Option, /// known_hosts_strategy of this backend pub known_hosts_strategy: Option, + /// acquire_timeout of this backend + pub acquire_timeout: Option, + /// connect_timeout of this backend + pub connect_timeout: Option, /// enable_copy of this backend pub enable_copy: bool, } @@ -108,4 +112,21 @@ mod tests { assert_eq!(cfg.key.as_deref(), Some("/home/alice/.ssh/id_rsa")); assert_eq!(cfg.known_hosts_strategy.as_deref(), Some("accept")); } + + #[test] + fn from_uri_applies_timeout_overrides() { + let uri = OperatorUri::new( + "sftp://host", + vec![ + ("acquire_timeout".to_string(), "5s".to_string()), + ("connect_timeout".to_string(), "15s".to_string()), + ], + ) + .unwrap(); + + let cfg = SftpConfig::from_uri(&uri).unwrap(); + assert_eq!(cfg.endpoint.as_deref(), Some("host")); + assert_eq!(cfg.acquire_timeout.as_deref(), Some("5s")); + assert_eq!(cfg.connect_timeout.as_deref(), Some("15s")); + } } diff --git a/core/services/sftp/src/core.rs b/core/services/sftp/src/core.rs index bb8e24bb3a93..b1393d2218df 100644 --- a/core/services/sftp/src/core.rs +++ b/core/services/sftp/src/core.rs @@ -30,19 +30,32 @@ use std::fmt::Debug; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; pub struct SftpCore { pub info: Arc, pub endpoint: String, pub root: String, + pub acquire_timeout: Duration, + pub connect_timeout: Duration, client: Arc>, } +pub struct SftpConnectionOptions { + pub user: Option, + pub key: Option, + pub known_hosts_strategy: KnownHosts, + pub acquire_timeout: Duration, + pub connect_timeout: Duration, +} + impl Debug for SftpCore { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SftpCore") .field("endpoint", &self.endpoint) .field("root", &self.root) + .field("acquire_timeout", &self.acquire_timeout) + .field("connect_timeout", &self.connect_timeout) .finish_non_exhaustive() } } @@ -52,18 +65,17 @@ impl SftpCore { info: Arc, endpoint: String, root: String, - user: Option, - key: Option, - known_hosts_strategy: KnownHosts, + options: SftpConnectionOptions, ) -> Self { let client = bounded::Pool::new( bounded::PoolConfig::new(64), Manager { endpoint: endpoint.clone(), root: root.clone(), - user, - key, - known_hosts_strategy, + user: options.user, + key: options.key, + known_hosts_strategy: options.known_hosts_strategy.clone(), + connect_timeout: options.connect_timeout, }, ); @@ -71,23 +83,44 @@ impl SftpCore { info, endpoint, root, + acquire_timeout: options.acquire_timeout, + connect_timeout: options.connect_timeout, client, } } pub async fn connect(&self) -> Result> { - let fut = self.client.get(); + acquire_pooled_sftp_connection(&self.client, self.acquire_timeout).await + } +} - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(10)) => { - Err(Error::new(ErrorKind::Unexpected, "connection request: timeout").set_temporary()) - } - result = fut => match result { - Ok(conn) => Ok(conn), - Err(err) => Err(err), - } - } +// Only apply acquire timeout when the pool is saturated. Otherwise `fastpool::get()` +// may perform connection creation, and wrapping that path in a generic timeout would +// hide the underlying SSH/SFTP error again. +async fn acquire_pooled_sftp_connection( + pool: &Arc>, + acquire_timeout: Duration, +) -> Result> +where + M: ManageObject, +{ + let status = pool.status(); + let should_timeout = !acquire_timeout.is_zero() + && status.current_size >= status.max_size + && status.idle_count == 0; + + if should_timeout { + return match tokio::time::timeout(acquire_timeout, pool.get()).await { + Ok(result) => result, + Err(_) => Err(Error::new( + ErrorKind::Unexpected, + "timed out waiting for pooled sftp connection", + ) + .set_temporary()), + }; } + + pool.get().await } pub struct Manager { @@ -96,6 +129,7 @@ pub struct Manager { user: Option, key: Option, known_hosts_strategy: KnownHosts, + connect_timeout: Duration, } impl ManageObject for Manager { @@ -113,6 +147,7 @@ impl ManageObject for Manager { session.keyfile(key); } + session.connect_timeout(self.connect_timeout); session.known_hosts_check(self.known_hosts_strategy.clone()); let session = session @@ -159,3 +194,80 @@ impl ManageObject for Manager { } } } + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicUsize, Ordering}; + + use super::*; + + #[derive(Clone, Default)] + struct TestManager { + create_delay: Duration, + recycle_delay: Duration, + created: Arc, + } + + impl ManageObject for TestManager { + type Object = usize; + type Error = Error; + + async fn create(&self) -> Result { + if !self.create_delay.is_zero() { + tokio::time::sleep(self.create_delay).await; + } + + Ok(self.created.fetch_add(1, Ordering::SeqCst)) + } + + async fn is_recyclable( + &self, + _: &mut Self::Object, + _: &ObjectStatus, + ) -> Result<(), Self::Error> { + if !self.recycle_delay.is_zero() { + tokio::time::sleep(self.recycle_delay).await; + } + + Ok(()) + } + } + + #[tokio::test] + async fn acquire_timeout_only_applies_when_pool_is_saturated() { + let pool = bounded::Pool::new( + bounded::PoolConfig::new(1), + TestManager { + create_delay: Duration::from_millis(50), + ..Default::default() + }, + ); + + let started = std::time::Instant::now(); + let conn = acquire_pooled_sftp_connection(&pool, Duration::from_millis(10)) + .await + .expect("pool should create a new connection"); + + assert!(started.elapsed() >= Duration::from_millis(50)); + drop(conn); + } + + #[tokio::test] + async fn acquire_timeout_reports_waiting_for_pooled_connection() { + let pool = bounded::Pool::new(bounded::PoolConfig::new(1), TestManager::default()); + let held = pool.get().await.expect("first connection should succeed"); + + let err = acquire_pooled_sftp_connection(&pool, Duration::from_millis(20)) + .await + .expect_err("second acquire should time out"); + + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.is_temporary()); + assert!( + err.to_string() + .contains("timed out waiting for pooled sftp connection") + ); + + drop(held); + } +} diff --git a/core/services/sftp/src/docs.md b/core/services/sftp/src/docs.md index 123a167f7196..1aa1ceed86d1 100644 --- a/core/services/sftp/src/docs.md +++ b/core/services/sftp/src/docs.md @@ -19,6 +19,8 @@ This service can be used to: - `user`: Set the login user - `key`: Set the public key for login - `known_hosts_strategy`: Set the strategy for known hosts, default to `Strict` +- `acquire_timeout`: Set how long to wait for an already-saturated SFTP connection pool, default to `10s` +- `connect_timeout`: Set the SSH connect timeout, default to `10s` - `enable_copy`: Set whether the remote server has copy-file extension For security reasons, it doesn't support password login, you can use public key or ssh-copy-id instead. diff --git a/core/services/sftp/src/error.rs b/core/services/sftp/src/error.rs index f05c9a811b67..7b3026e2002f 100644 --- a/core/services/sftp/src/error.rs +++ b/core/services/sftp/src/error.rs @@ -45,7 +45,7 @@ pub fn parse_sftp_error(e: SftpClientError) -> Error { } pub fn parse_ssh_error(e: SshError) -> Error { - Error::new(ErrorKind::Unexpected, "ssh error").set_source(e) + Error::new(ErrorKind::Unexpected, "failed to establish ssh connection").set_source(e) } pub(super) fn is_not_found(e: &SftpClientError) -> bool {