Skip to content

Commit 672a4a9

Browse files
authored
Merge pull request #26 from AvivNaaman/25-support-receive-timeout
#25: support timeout in worker recv, use in tests/cli.
2 parents b5dca07 + 0b9f11d commit 672a4a9

File tree

12 files changed

+142
-25
lines changed

12 files changed

+142
-25
lines changed

smb-cli/src/path.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ impl UncPath {
1919
cli: &Cli,
2020
) -> Result<(Connection, Session, Tree, Option<Resource>), Box<dyn Error>> {
2121
let mut smb = Connection::new();
22+
smb.set_timeout(Some(std::time::Duration::from_secs(10)))
23+
.await?;
2224
smb.connect(format!("{}:{}", self.server, cli.port).as_str())
2325
.await?;
2426
let mut session = smb

smb/src/connection.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,35 @@ use netbios_client::NetBiosClient;
2525
use std::cmp::max;
2626
use std::sync::atomic::{AtomicU16, AtomicU64};
2727
use std::sync::Arc;
28+
use std::time::Duration;
2829
pub use transformer::TransformError;
2930
use worker::{Worker, WorkerImpl};
3031

3132
pub struct Connection {
3233
handler: HandlerReference<ConnectionMessageHandler>,
34+
timeout: Option<std::time::Duration>,
3335
}
3436

3537
impl Connection {
3638
pub fn new() -> Connection {
3739
Connection {
3840
handler: HandlerReference::new(ConnectionMessageHandler::new()),
41+
timeout: None,
3942
}
4043
}
44+
45+
#[maybe_async]
46+
pub async fn set_timeout(&mut self, timeout: Option<Duration>) -> crate::Result<()> {
47+
self.timeout = timeout;
48+
if let Some(worker) = self.handler.worker.get() {
49+
worker.set_timeout(timeout).await?;
50+
}
51+
Ok(())
52+
}
53+
4154
#[maybe_async]
4255
pub async fn connect(&mut self, address: &str) -> crate::Result<()> {
43-
let mut netbios_client = NetBiosClient::new();
56+
let mut netbios_client = NetBiosClient::new(self.timeout);
4457

4558
log::debug!("Connecting to {}...", address);
4659
netbios_client.connect(address).await?;
@@ -105,7 +118,7 @@ impl Connection {
105118
}
106119
}
107120

108-
Ok(WorkerImpl::start(netbios_client).await?)
121+
Ok(WorkerImpl::start(netbios_client, self.timeout).await?)
109122
}
110123

111124
#[maybe_async]

smb/src/connection/netbios_client.rs

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
use maybe_async::*;
2-
use std::io::Cursor;
2+
use std::{io::Cursor, time::Duration};
33

44
#[cfg(feature = "sync")]
55
use std::{
66
io::{self, Read, Write},
7-
net::TcpStream,
7+
net::{TcpStream, ToSocketAddrs},
88
};
99
#[cfg(feature = "async")]
1010
use tokio::{
1111
io::{self, AsyncReadExt, AsyncWriteExt},
1212
net::{tcp, TcpStream},
13+
select,
1314
};
1415

1516
use binrw::prelude::*;
@@ -37,26 +38,57 @@ type TcpWrite = TcpStream;
3738
pub struct NetBiosClient {
3839
reader: Option<TcpRead>,
3940
writer: Option<TcpWrite>,
41+
timeout: Option<Duration>,
4042
}
4143

4244
impl NetBiosClient {
43-
pub fn new() -> NetBiosClient {
45+
pub fn new(timeout: Option<Duration>) -> NetBiosClient {
4446
NetBiosClient {
4547
reader: None,
4648
writer: None,
49+
timeout,
4750
}
4851
}
4952

5053
/// Connects to a NetBios server in the specified address.
5154
#[maybe_async]
5255
pub async fn connect(&mut self, address: &str) -> crate::Result<()> {
53-
let socket = TcpStream::connect(address).await?;
56+
let socket = self.connect_timeout(address).await?;
5457
let (r, w) = Self::split_socket(socket);
5558
self.reader = Some(r);
5659
self.writer = Some(w);
5760
Ok(())
5861
}
5962

63+
#[cfg(feature = "sync")]
64+
fn connect_timeout(&mut self, address: &str) -> crate::Result<TcpStream> {
65+
if let Some(t) = self.timeout {
66+
log::debug!("Connecting to {} with timeout {:?}.", address, t);
67+
// convert to SocketAddr:
68+
let address = address
69+
.to_socket_addrs()?
70+
.next()
71+
.ok_or(crate::Error::InvalidAddress(address.to_string()))?;
72+
TcpStream::connect_timeout(&address, t).map_err(Into::into)
73+
} else {
74+
log::debug!("Connecting to {}.", address);
75+
TcpStream::connect(&address).map_err(Into::into)
76+
}
77+
}
78+
79+
#[cfg(feature = "async")]
80+
async fn connect_timeout(&mut self, address: &str) -> crate::Result<TcpStream> {
81+
if let None = self.timeout {
82+
log::debug!("Connecting to {}.", address);
83+
return TcpStream::connect(&address).await.map_err(Into::into);
84+
}
85+
86+
select! {
87+
res = TcpStream::connect(&address) => res.map_err(Into::into),
88+
_ = tokio::time::sleep(self.timeout.unwrap()) => Err(crate::Error::OperationTimeout("Tcp connect".to_string(), self.timeout.unwrap())),
89+
}
90+
}
91+
6092
pub fn is_connected(&self) -> bool {
6193
self.reader.is_some()
6294
}
@@ -197,10 +229,12 @@ impl NetBiosClient {
197229
NetBiosClient {
198230
reader: self.reader,
199231
writer: None,
232+
timeout: self.timeout,
200233
},
201234
NetBiosClient {
202235
reader: None,
203236
writer: self.writer,
237+
timeout: self.timeout,
204238
},
205239
))
206240
}

smb/src/connection/worker/multi_worker/async_backend.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::sync_helpers::*;
22
use std::sync::Arc;
3+
use std::time::Duration;
34
use tokio::{select, sync::oneshot};
45

56
use crate::{msg_handler::IncomingMessage, packets::netbios::NetBiosTcpMessage, Error};
@@ -190,10 +191,25 @@ impl MultiWorkerBackend for AsyncBackend {
190191
oneshot::channel()
191192
}
192193

193-
async fn wait_on_waiter(waiter: Self::AwaitingWaiter) -> crate::Result<IncomingMessage> {
194-
waiter
195-
.await
196-
.map_err(|_| Error::MessageProcessingError("Failed to receive message.".to_string()))?
194+
async fn wait_on_waiter(
195+
waiter: Self::AwaitingWaiter,
196+
timeout: Option<Duration>,
197+
) -> crate::Result<IncomingMessage> {
198+
match timeout {
199+
Some(timeout) => {
200+
tokio::select! {
201+
msg = waiter => {
202+
msg.map_err(|_| Error::MessageProcessingError("Failed to receive message.".to_string()))?
203+
},
204+
_ = tokio::time::sleep(timeout) => {
205+
Err(Error::OperationTimeout("Waiting for message receive.".to_string(), timeout))
206+
}
207+
}
208+
}
209+
None => waiter.await.map_err(|_| {
210+
Error::MessageProcessingError("Failed to receive message.".to_string())
211+
})?,
212+
}
197213
}
198214

199215
fn send_notify(

smb/src/connection/worker/multi_worker/backend_trait.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::connection::netbios_client::NetBiosClient;
22
use crate::sync_helpers::*;
33
use maybe_async::*;
4-
use std::sync::Arc;
4+
use std::{sync::Arc, time::Duration};
55

66
use crate::{msg_handler::IncomingMessage, packets::netbios::NetBiosTcpMessage};
77

@@ -32,7 +32,10 @@ pub trait MultiWorkerBackend {
3232
mpsc::Receiver<Self::SendMessage>,
3333
);
3434

35-
async fn wait_on_waiter(waiter: Self::AwaitingWaiter) -> crate::Result<IncomingMessage>;
35+
async fn wait_on_waiter(
36+
waiter: Self::AwaitingWaiter,
37+
timeout: Option<Duration>,
38+
) -> crate::Result<IncomingMessage>;
3639
fn send_notify(
3740
tx: Self::AwaitingNotifier,
3841
msg: crate::Result<IncomingMessage>,

smb/src/connection/worker/multi_worker/base.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::connection::worker::Worker;
44
use crate::sync_helpers::*;
55
use maybe_async::*;
66
use std::sync::atomic::AtomicBool;
7+
use std::time::Duration;
78
use std::{collections::HashMap, sync::Arc};
89

910
use crate::{
@@ -34,6 +35,9 @@ where
3435
/// A channel to send messages to the worker.
3536
pub(crate) sender: mpsc::Sender<T::SendMessage>,
3637
stopped: AtomicBool,
38+
39+
/// atomic duration:
40+
timeout: RwLock<Option<Duration>>,
3741
}
3842

3943
/// Holds state for the worker, regarding messages to be received.
@@ -140,7 +144,10 @@ where
140144
T::AwaitingNotifier: std::fmt::Debug,
141145
{
142146
#[maybe_async]
143-
async fn start(netbios_client: NetBiosClient) -> crate::Result<Arc<Self>> {
147+
async fn start(
148+
netbios_client: NetBiosClient,
149+
timeout: Option<Duration>,
150+
) -> crate::Result<Arc<Self>> {
144151
// Build the worker
145152
let (tx, rx) = T::make_send_channel_pair();
146153
let worker = Arc::new(MultiWorkerBase::<T> {
@@ -149,6 +156,7 @@ where
149156
transformer: Transformer::default(),
150157
sender: tx,
151158
stopped: AtomicBool::new(false),
159+
timeout: RwLock::new(timeout),
152160
});
153161

154162
worker
@@ -235,7 +243,8 @@ where
235243
rx
236244
};
237245

238-
let wait_result = T::wait_on_waiter(wait_for_receive).await;
246+
let timeout = { *self.timeout.read().await? };
247+
let wait_result = T::wait_on_waiter(wait_for_receive, timeout).await;
239248

240249
// Wait for the message to be received.
241250
Ok(wait_result.map_err(|_| {
@@ -246,6 +255,12 @@ where
246255
fn transformer(&self) -> &Transformer {
247256
&self.transformer
248257
}
258+
259+
#[maybe_async]
260+
async fn set_timeout(&self, timeout: Option<Duration>) -> crate::Result<()> {
261+
*self.timeout.write().await? = timeout;
262+
Ok(())
263+
}
249264
}
250265

251266
impl<T> std::fmt::Debug for MultiWorkerBase<T>

smb/src/connection/worker/multi_worker/threading_backend.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,22 @@ impl MultiWorkerBackend for ThreadingBackend {
171171
std::sync::mpsc::channel()
172172
}
173173

174-
fn wait_on_waiter(waiter: Self::AwaitingWaiter) -> crate::Result<IncomingMessage> {
175-
waiter
176-
.recv()
177-
.map_err(|_| Error::MessageProcessingError("Failed to receive message.".to_string()))?
174+
fn wait_on_waiter(
175+
waiter: Self::AwaitingWaiter,
176+
timeout: Option<Duration>,
177+
) -> crate::Result<IncomingMessage> {
178+
if let None = timeout {
179+
return waiter.recv().map_err(|_| {
180+
Error::MessageProcessingError("Failed to receive message.".to_string())
181+
})?;
182+
}
183+
waiter.recv_timeout(timeout.unwrap()).map_err(|e| match e {
184+
std::sync::mpsc::RecvTimeoutError::Timeout => Error::OperationTimeout(
185+
"Waiting for message receive.".to_string(),
186+
timeout.unwrap(),
187+
),
188+
_ => Error::MessageProcessingError("Failed to receive message.".to_string()),
189+
})?
178190
}
179191

180192
fn send_notify(

smb/src/connection/worker/single_worker.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::{cell::RefCell, sync::Arc};
1+
use std::{cell::RefCell, sync::Arc, time::Duration};
22

33
use crate::{
44
connection::{netbios_client::NetBiosClient, transformer::Transformer},
@@ -18,10 +18,11 @@ pub struct SingleWorker {
1818
}
1919

2020
impl Worker for SingleWorker {
21-
fn start(netbios_client: NetBiosClient) -> crate::Result<Arc<Self>> {
21+
fn start(netbios_client: NetBiosClient, timeout: Option<Duration>) -> crate::Result<Arc<Self>> {
2222
if !netbios_client.is_connected() {
2323
Err(crate::Error::NotConnected)
2424
} else {
25+
netbios_client.set_read_timeout(timeout)?;
2526
Ok(Arc::new(Self {
2627
netbios_client: RefCell::new(netbios_client),
2728
transformer: Transformer::default(),
@@ -66,4 +67,8 @@ impl Worker for SingleWorker {
6667
fn transformer(&self) -> &Transformer {
6768
&self.transformer
6869
}
70+
71+
fn set_timeout(&self, timeout: Option<Duration>) -> crate::Result<()> {
72+
self.netbios_client.borrow_mut().set_read_timeout(timeout)
73+
}
6974
}

smb/src/connection/worker/worker_trait.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::{sync::Arc, time::Duration};
22

33
use crate::sync_helpers::*;
44

@@ -21,10 +21,16 @@ use crate::{
2121
#[allow(async_fn_in_trait)]
2222
pub trait Worker: Sized + std::fmt::Debug {
2323
/// Instantiates a new connection worker.
24-
async fn start(netbios_client: NetBiosClient) -> crate::Result<Arc<Self>>;
24+
async fn start(
25+
netbios_client: NetBiosClient,
26+
timeout: Option<Duration>,
27+
) -> crate::Result<Arc<Self>>;
2528
/// Stops the worker, shutting down the connection.
2629
async fn stop(&self) -> crate::Result<()>;
2730

31+
/// Sets the timeout for the worker.
32+
async fn set_timeout(&self, timeout: Option<Duration>) -> crate::Result<()>;
33+
2834
async fn send(self: &Self, msg: OutgoingMessage) -> crate::Result<SendMessageResult>;
2935
/// Receive a message from the server.
3036
/// This is a user function that will wait for the message to be received.

smb/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ pub enum Error {
5252
UsernameError(String),
5353
#[error("Message processing failed. {0}")]
5454
MessageProcessingError(String),
55+
#[error("Operation timed out: {0}, took >{1:?}")]
56+
OperationTimeout(String, std::time::Duration),
5557
#[error("Lock error.")]
5658
LockError,
5759
#[cfg(feature = "async")]
@@ -69,6 +71,8 @@ pub enum Error {
6971
UnexpectedMessageId(u64, u64),
7072
#[error("Expected info of type {0} but got {1}")]
7173
UnexpectedInformationType(u8, u8),
74+
#[error("Invalid address {0}")]
75+
InvalidAddress(String),
7276
}
7377

7478
impl<T> From<PoisonError<T>> for Error {

0 commit comments

Comments
 (0)