Skip to content

Commit 6f7acc6

Browse files
authored
Merge pull request #106 from silicon-heaven/mocking-runtime-agnostic
Make mocking async runtime agnostic
2 parents 2e44fa1 + 24cefb6 commit 6f7acc6

3 files changed

Lines changed: 88 additions & 50 deletions

File tree

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "shvclient"
33
description = "A Rust framework for Silicon Heaven RPC devices"
44
license = "MIT"
55
repository = "https://github.com/silicon-heaven/libshvclient-rs"
6-
version = "4.2.0"
6+
version = "5.0.0"
77
edition = "2024"
88

99
[lib]
@@ -55,6 +55,8 @@ futures-rustls = "0.26.0"
5555
rustls-pemfile = "2.2.0"
5656
rustls-platform-verifier = "0.6.1"
5757
async-trait = "0.1.89"
58+
futures-timer = "3.0.3"
59+
event-listener = { version = "5.4.1", optional = true }
5860

5961
# For local development
6062
# [patch.crates-io]
@@ -65,7 +67,7 @@ async-trait = "0.1.89"
6567
default = []
6668
tokio = ["dep:tokio", "dep:tokio-util"]
6769
smol = ["dep:smol"]
68-
mocking = []
70+
mocking = ["dep:event-listener"]
6971

7072
[lints.clippy]
7173
pedantic = {priority = -1, level = "warn"}

src/mocking.rs

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use std::{error::Error, sync::Arc, time::Duration};
2-
use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}};
1+
use std::{error::Error, sync::{Arc, RwLock}, time::Duration};
2+
use event_listener::Event;
3+
use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, future::{Either, select}};
4+
use futures_timer::Delay;
35
use log::debug;
46
use shvproto::RpcValue;
57
use shvrpc::{RpcMessage, RpcMessageMetaTags, rpc::ShvRI, rpcmessage::{RpcError, RpcErrorCode}};
6-
use tokio::{sync::{Notify, RwLock}, task::JoinHandle};
7-
use crate::{ConnectionCommand, ConnectionEvent};
8+
use crate::{ConnectionCommand, ConnectionEvent, runtime::TaskHandle};
89

910
const TIMEOUT_DURATION: Duration = Duration::from_secs(5);
1011

@@ -69,11 +70,11 @@ where
6970
}
7071

7172
pub struct TestApp {
72-
app: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>,
73-
msg_task: JoinHandle<()>,
73+
app: TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>,
74+
msg_task: TaskHandle<()>,
7475
conn_evt_tx: UnboundedSender<ConnectionEvent>,
7576
pending_msgs: Arc<RwLock<Vec<RpcMessage>>>,
76-
notifier: Arc<Notify>,
77+
notifier: Arc<Event>,
7778
}
7879

7980
impl TestApp {
@@ -95,17 +96,18 @@ impl TestApp {
9596
F: Fn(&RpcMessage) -> bool
9697
{
9798
loop {
98-
let waiter = self.notifier.notified();
99+
let event_listener = self.notifier.listen();
99100

100-
let mut pending = self.pending_msgs.write().await;
101-
let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos));
101+
{
102+
let mut pending = self.pending_msgs.write().expect("pending_msgs write should succeed");
103+
let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos));
102104

103-
if let Some(matched_msg) = matched_msg {
104-
return matched_msg;
105+
if let Some(matched_msg) = matched_msg {
106+
return matched_msg;
107+
}
105108
}
106109

107-
drop(pending);
108-
tokio::time::timeout(TIMEOUT_DURATION, waiter).await.unwrap_or_else(|_err| panic!("Timed out while waiting for message"));
110+
timeout(TIMEOUT_DURATION, event_listener).await.unwrap_or_else(|err| panic!("{err}"));
109111
}
110112
}
111113

@@ -175,18 +177,18 @@ impl TestApp {
175177
.respond(Ok("true"));
176178
}
177179

178-
pub fn new(app_maker: impl FnOnce(UnboundedReceiver<ConnectionEvent>) -> JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
180+
pub fn new(app_maker: impl FnOnce(UnboundedReceiver<ConnectionEvent>) -> TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
179181
let (conn_evt_tx, conn_evt_rx) = futures::channel::mpsc::unbounded::<ConnectionEvent>();
180182
let app = app_maker(conn_evt_rx);
181183

182184
let (conn_cmd_sender_in, mut conn_cmd_receiver_in) = unbounded();
183185
let pending_msgs = Arc::new(RwLock::new(Vec::new()));
184-
let notifier = Arc::new(Notify::new());
186+
let notifier = Arc::new(Event::new());
185187

186188
let msg_task = {
187189
let pending_msgs = pending_msgs.clone();
188190
let notifier = notifier.clone();
189-
tokio::spawn(async move {
191+
crate::runtime::spawn_task(async move {
190192
while let Some(ConnectionCommand::SendMessage(rpc_message)) = conn_cmd_receiver_in.next().await {
191193
let shv_path = rpc_message.shv_path().unwrap_or_default();
192194
let method = rpc_message.method().unwrap_or_default();
@@ -200,9 +202,9 @@ impl TestApp {
200202
}
201203

202204
{
203-
pending_msgs.write().await.push(rpc_message);
205+
pending_msgs.write().expect("pending_msgs write should succeed").push(rpc_message);
204206
}
205-
notifier.notify_waiters();
207+
notifier.notify(usize::MAX);
206208
}
207209
})
208210
};
@@ -218,30 +220,33 @@ impl TestApp {
218220

219221
pub async fn wait_until_finished(self) -> shvrpc::Result<()> {
220222
// Wait for a bit to ensure silence from the app.
221-
tokio::time::sleep(Duration::from_millis(500)).await;
223+
futures_timer::Delay::new(Duration::from_millis(500)).await;
222224
self.conn_evt_tx.close_channel();
223-
let mut pending_msgs = self.pending_msgs.write().await;
224-
// We'll let unsubscribe calls slide.
225-
pending_msgs.retain(|msg|
226-
!msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe")
227-
);
228-
for rpc_message in pending_msgs.iter() {
229-
let shv_path = rpc_message.shv_path().unwrap_or_default();
230-
let method = rpc_message.method().unwrap_or_default();
231-
let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
232-
let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
233-
if rpc_message.is_response() {
234-
let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported"));
235-
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}");
236-
} else if method.is_empty() && shv_path.is_empty() {
237-
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}");
238-
} else {
239-
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}");
240-
}
241-
};
242225

243-
if !pending_msgs.is_empty() {
244-
return Err("There were unexpected messages received from the app.".into());
226+
{
227+
let mut pending_msgs = self.pending_msgs.write().expect("pending_msgs write should succeed");
228+
// We'll let unsubscribe calls slide.
229+
pending_msgs.retain(|msg|
230+
!msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe")
231+
);
232+
for rpc_message in pending_msgs.iter() {
233+
let shv_path = rpc_message.shv_path().unwrap_or_default();
234+
let method = rpc_message.method().unwrap_or_default();
235+
let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
236+
let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
237+
if rpc_message.is_response() {
238+
let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported"));
239+
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}");
240+
} else if method.is_empty() && shv_path.is_empty() {
241+
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}");
242+
} else {
243+
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}");
244+
}
245+
};
246+
247+
if !pending_msgs.is_empty() {
248+
return Err("There were unexpected messages received from the app.".into());
249+
}
245250
}
246251

247252
let end = async move {
@@ -251,7 +256,20 @@ impl TestApp {
251256
Ok(())
252257
};
253258

254-
tokio::time::timeout(TIMEOUT_DURATION, end).await?
259+
timeout(TIMEOUT_DURATION, end).await?
255260
}
256261
}
257262

263+
pub async fn timeout<F, T>(dur: Duration, fut: F) -> shvrpc::Result<T>
264+
where
265+
F: Future<Output = T>,
266+
{
267+
futures::pin_mut!(fut);
268+
let timeout = Delay::new(dur);
269+
futures::pin_mut!(timeout);
270+
271+
match select(fut, timeout).await {
272+
Either::Left((val, _)) => Ok(val),
273+
Either::Right(_) => Err("Timed out while waiting for a future".into()),
274+
}
275+
}

src/runtime.rs

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,29 @@
1+
12
#[must_use = "Task has to be used. If you want to detach the task, call .detach() on it."]
2-
pub struct TaskHandle<F: futures::Future + Send + 'static>(
3+
pub struct TaskHandle<O: Send + 'static>(
4+
// pub struct TaskHandle<F: futures::Future + Send + 'static>(
35
#[cfg(feature = "tokio")]
4-
pub tokio::task::JoinHandle<F::Output>,
6+
pub tokio::task::JoinHandle<O>,
57
#[cfg(feature = "smol")]
6-
pub smol::Task<F::Output>,
8+
// The error type is dummy as smol::Task future resolves to the result right away,
9+
// but we want to keep the same API with tokio JoinHandle, which returns a Result.
10+
pub smol::Task<Result<O, std::convert::Infallible>>,
711
);
812

9-
impl<F: futures::Future + Send + 'static> TaskHandle<F> {
13+
impl<O: Send + 'static> Future for TaskHandle<O> {
14+
#[cfg(feature = "tokio")]
15+
type Output = Result<O, tokio::task::JoinError>;
16+
17+
#[cfg(feature = "smol")]
18+
type Output = Result<O, std::convert::Infallible>;
19+
20+
fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
21+
std::pin::Pin::new(&mut self.0).poll(cx)
22+
}
23+
}
24+
25+
// impl<F: futures::Future + Send + 'static> TaskHandle<F> {
26+
impl<O: Send + 'static> TaskHandle<O> {
1027
#[cfg(feature = "tokio")]
1128
#[expect(clippy::unused_async, reason = "We want the same API as with smol")]
1229
pub async fn cancel(self) {
@@ -24,15 +41,16 @@ impl<F: futures::Future + Send + 'static> TaskHandle<F> {
2441
}
2542
}
2643

27-
pub fn spawn_task<F>(f: F) -> TaskHandle<F>
44+
pub fn spawn_task<F>(f: F) -> TaskHandle<F::Output>
2845
where
2946
F: futures::Future + Send + 'static,
3047
F::Output: Send + 'static,
3148
{
3249
#[cfg(feature = "tokio")]
3350
{ TaskHandle(tokio::spawn(f)) }
3451
#[cfg(feature = "smol")]
35-
{ TaskHandle(smol::spawn(f)) }
52+
{ TaskHandle(smol::spawn(async move { Ok(f.await) } )) }
53+
3654
}
3755

3856
pub fn block_on<T>(future: impl Future<Output = T>) -> T {

0 commit comments

Comments
 (0)