Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions core/services/sftp/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ use std::io::SeekFrom;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use log::debug;
use openssh::KnownHosts;
use tokio::io::AsyncSeekExt;

use super::SFTP_SCHEME;
use super::config::SftpConfig;
use super::core::SftpConnectionOptions;
use super::core::SftpCore;
use super::deleter::SftpDeleter;
use super::error::is_not_found;
Expand Down Expand Up @@ -55,6 +57,28 @@ pub struct SftpBuilder {
}

impl SftpBuilder {
/// set acquire timeout for pooled sftp connections.
pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
self.config.acquire_timeout = if timeout.is_zero() {
None
} else {
Some(format!("{}s", timeout.as_secs()))
};

self
}

/// set connect timeout for sftp backend.
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.config.connect_timeout = if timeout.is_zero() {
None
} else {
Some(format!("{}s", timeout.as_secs()))
};

self
}

/// set endpoint for sftp backend.
/// The format is same as `openssh`, using either `[user@]hostname` or `ssh://[user@]hostname[:port]`. A username or port that is specified in the endpoint overrides the one set in the builder (but does not change the builder).
pub fn endpoint(mut self, endpoint: &str) -> Self {
Expand Down Expand Up @@ -130,6 +154,9 @@ impl Builder for SftpBuilder {

fn build(self) -> Result<impl Access> {
debug!("sftp backend build started: {:?}", &self);
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(10);
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

let endpoint = match self.config.endpoint.clone() {
Some(v) => v,
None => return Err(Error::new(ErrorKind::ConfigInvalid, "endpoint is empty")),
Expand Down Expand Up @@ -163,6 +190,15 @@ impl Builder for SftpBuilder {
None => KnownHosts::Strict,
};

let acquire_timeout = match self.config.acquire_timeout.as_deref() {
Some(value) => signed_to_duration(value)?,
None => DEFAULT_ACQUIRE_TIMEOUT,
};
let connect_timeout = match self.config.connect_timeout.as_deref() {
Some(value) => signed_to_duration(value)?,
None => DEFAULT_CONNECT_TIMEOUT,
};

let info = AccessorInfo::default();
info.set_root(root.as_str())
.set_scheme(SFTP_SCHEME)
Expand Down Expand Up @@ -193,16 +229,40 @@ impl Builder for SftpBuilder {
Arc::new(info),
endpoint,
root,
user,
self.config.key.clone(),
known_hosts_strategy,
SftpConnectionOptions {
user,
key: self.config.key.clone(),
known_hosts_strategy,
acquire_timeout,
connect_timeout,
},
));

debug!("sftp backend finished: {:?}", &self);
Ok(SftpBackend { core })
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn build_rejects_invalid_timeout() {
let builder = SftpBuilder {
config: SftpConfig {
endpoint: Some("host".to_string()),
connect_timeout: Some("invalid".to_string()),
..Default::default()
},
};

let err = builder.build().unwrap_err();
assert_eq!(err.kind(), ErrorKind::ConfigInvalid);
assert!(err.to_string().contains("failed to parse duration"));
}
}

/// Backend is used to serve `Accessor` support for sftp.
#[derive(Clone, Debug)]
pub struct SftpBackend {
Expand Down
21 changes: 21 additions & 0 deletions core/services/sftp/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ pub struct SftpConfig {
pub key: Option<String>,
/// known_hosts_strategy of this backend
pub known_hosts_strategy: Option<String>,
/// acquire_timeout of this backend
pub acquire_timeout: Option<String>,
/// connect_timeout of this backend
pub connect_timeout: Option<String>,
/// enable_copy of this backend
pub enable_copy: bool,
}
Expand Down Expand Up @@ -108,4 +112,21 @@ mod tests {
assert_eq!(cfg.key.as_deref(), Some("/home/alice/.ssh/id_rsa"));
assert_eq!(cfg.known_hosts_strategy.as_deref(), Some("accept"));
}

#[test]
fn from_uri_applies_timeout_overrides() {
let uri = OperatorUri::new(
"sftp://host",
vec![
("acquire_timeout".to_string(), "5s".to_string()),
("connect_timeout".to_string(), "15s".to_string()),
],
)
.unwrap();

let cfg = SftpConfig::from_uri(&uri).unwrap();
assert_eq!(cfg.endpoint.as_deref(), Some("host"));
assert_eq!(cfg.acquire_timeout.as_deref(), Some("5s"));
assert_eq!(cfg.connect_timeout.as_deref(), Some("15s"));
}
}
144 changes: 128 additions & 16 deletions core/services/sftp/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,32 @@ use std::fmt::Debug;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

pub struct SftpCore {
pub info: Arc<AccessorInfo>,
pub endpoint: String,
pub root: String,
pub acquire_timeout: Duration,
pub connect_timeout: Duration,
client: Arc<bounded::Pool<Manager>>,
}

pub struct SftpConnectionOptions {
pub user: Option<String>,
pub key: Option<String>,
pub known_hosts_strategy: KnownHosts,
pub acquire_timeout: Duration,
pub connect_timeout: Duration,
}

impl Debug for SftpCore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SftpCore")
.field("endpoint", &self.endpoint)
.field("root", &self.root)
.field("acquire_timeout", &self.acquire_timeout)
.field("connect_timeout", &self.connect_timeout)
.finish_non_exhaustive()
}
}
Expand All @@ -52,42 +65,62 @@ impl SftpCore {
info: Arc<AccessorInfo>,
endpoint: String,
root: String,
user: Option<String>,
key: Option<String>,
known_hosts_strategy: KnownHosts,
options: SftpConnectionOptions,
) -> Self {
let client = bounded::Pool::new(
bounded::PoolConfig::new(64),
Manager {
endpoint: endpoint.clone(),
root: root.clone(),
user,
key,
known_hosts_strategy,
user: options.user,
key: options.key,
known_hosts_strategy: options.known_hosts_strategy.clone(),
connect_timeout: options.connect_timeout,
},
);

SftpCore {
info,
endpoint,
root,
acquire_timeout: options.acquire_timeout,
connect_timeout: options.connect_timeout,
client,
}
}

pub async fn connect(&self) -> Result<bounded::Object<Manager>> {
let fut = self.client.get();
acquire_pooled_sftp_connection(&self.client, self.acquire_timeout).await
}
}

tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(10)) => {
Err(Error::new(ErrorKind::Unexpected, "connection request: timeout").set_temporary())
}
result = fut => match result {
Ok(conn) => Ok(conn),
Err(err) => Err(err),
}
}
// Only apply acquire timeout when the pool is saturated. Otherwise `fastpool::get()`
// may perform connection creation, and wrapping that path in a generic timeout would
// hide the underlying SSH/SFTP error again.
async fn acquire_pooled_sftp_connection<M>(
pool: &Arc<bounded::Pool<M>>,
acquire_timeout: Duration,
) -> Result<bounded::Object<M>>
where
M: ManageObject<Error = Error>,
{
let status = pool.status();
let should_timeout = !acquire_timeout.is_zero()
&& status.current_size >= status.max_size
&& status.idle_count == 0;

if should_timeout {
return match tokio::time::timeout(acquire_timeout, pool.get()).await {
Ok(result) => result,
Err(_) => Err(Error::new(
ErrorKind::Unexpected,
"timed out waiting for pooled sftp connection",
)
.set_temporary()),
};
}

pool.get().await
}

pub struct Manager {
Expand All @@ -96,6 +129,7 @@ pub struct Manager {
user: Option<String>,
key: Option<String>,
known_hosts_strategy: KnownHosts,
connect_timeout: Duration,
}

impl ManageObject for Manager {
Expand All @@ -113,6 +147,7 @@ impl ManageObject for Manager {
session.keyfile(key);
}

session.connect_timeout(self.connect_timeout);
session.known_hosts_check(self.known_hosts_strategy.clone());

let session = session
Expand Down Expand Up @@ -159,3 +194,80 @@ impl ManageObject for Manager {
}
}
}

#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};

use super::*;

#[derive(Clone, Default)]
struct TestManager {
create_delay: Duration,
recycle_delay: Duration,
created: Arc<AtomicUsize>,
}

impl ManageObject for TestManager {
type Object = usize;
type Error = Error;

async fn create(&self) -> Result<Self::Object> {
if !self.create_delay.is_zero() {
tokio::time::sleep(self.create_delay).await;
}

Ok(self.created.fetch_add(1, Ordering::SeqCst))
}

async fn is_recyclable(
&self,
_: &mut Self::Object,
_: &ObjectStatus,
) -> Result<(), Self::Error> {
if !self.recycle_delay.is_zero() {
tokio::time::sleep(self.recycle_delay).await;
}

Ok(())
}
}

#[tokio::test]
async fn acquire_timeout_only_applies_when_pool_is_saturated() {
let pool = bounded::Pool::new(
bounded::PoolConfig::new(1),
TestManager {
create_delay: Duration::from_millis(50),
..Default::default()
},
);

let started = std::time::Instant::now();
let conn = acquire_pooled_sftp_connection(&pool, Duration::from_millis(10))
.await
.expect("pool should create a new connection");

assert!(started.elapsed() >= Duration::from_millis(50));
drop(conn);
}

#[tokio::test]
async fn acquire_timeout_reports_waiting_for_pooled_connection() {
let pool = bounded::Pool::new(bounded::PoolConfig::new(1), TestManager::default());
let held = pool.get().await.expect("first connection should succeed");

let err = acquire_pooled_sftp_connection(&pool, Duration::from_millis(20))
.await
.expect_err("second acquire should time out");

assert_eq!(err.kind(), ErrorKind::Unexpected);
assert!(err.is_temporary());
assert!(
err.to_string()
.contains("timed out waiting for pooled sftp connection")
);

drop(held);
}
}
2 changes: 2 additions & 0 deletions core/services/sftp/src/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ This service can be used to:
- `user`: Set the login user
- `key`: Set the public key for login
- `known_hosts_strategy`: Set the strategy for known hosts, default to `Strict`
- `acquire_timeout`: Set how long to wait for an already-saturated SFTP connection pool, default to `10s`
- `connect_timeout`: Set the SSH connect timeout, default to `10s`
- `enable_copy`: Set whether the remote server has copy-file extension

For security reasons, it doesn't support password login, you can use public key or ssh-copy-id instead.
Expand Down
2 changes: 1 addition & 1 deletion core/services/sftp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub fn parse_sftp_error(e: SftpClientError) -> Error {
}

pub fn parse_ssh_error(e: SshError) -> Error {
Error::new(ErrorKind::Unexpected, "ssh error").set_source(e)
Error::new(ErrorKind::Unexpected, "failed to establish ssh connection").set_source(e)
}

pub(super) fn is_not_found(e: &SftpClientError) -> bool {
Expand Down
Loading