diff --git a/Cargo.toml b/Cargo.toml index ce49f2c..3ddb69d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "shvclient" description = "A Rust framework for Silicon Heaven RPC devices" license = "MIT" repository = "https://github.com/silicon-heaven/libshvclient-rs" -version = "4.2.0" +version = "5.0.0" edition = "2024" [lib] @@ -55,6 +55,8 @@ futures-rustls = "0.26.0" rustls-pemfile = "2.2.0" rustls-platform-verifier = "0.6.1" async-trait = "0.1.89" +futures-timer = "3.0.3" +event-listener = { version = "5.4.1", optional = true } # For local development # [patch.crates-io] @@ -65,7 +67,7 @@ async-trait = "0.1.89" default = [] tokio = ["dep:tokio", "dep:tokio-util"] smol = ["dep:smol"] -mocking = [] +mocking = ["dep:event-listener"] [lints.clippy] pedantic = {priority = -1, level = "warn"} diff --git a/src/mocking.rs b/src/mocking.rs index 9484de1..846b741 100644 --- a/src/mocking.rs +++ b/src/mocking.rs @@ -1,10 +1,11 @@ -use std::{error::Error, sync::Arc, time::Duration}; -use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}}; +use std::{error::Error, sync::{Arc, RwLock}, time::Duration}; +use event_listener::Event; +use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, future::{Either, select}}; +use futures_timer::Delay; use log::debug; use shvproto::RpcValue; use shvrpc::{RpcMessage, RpcMessageMetaTags, rpc::ShvRI, rpcmessage::{RpcError, RpcErrorCode}}; -use tokio::{sync::{Notify, RwLock}, task::JoinHandle}; -use crate::{ConnectionCommand, ConnectionEvent}; +use crate::{ConnectionCommand, ConnectionEvent, runtime::TaskHandle}; const TIMEOUT_DURATION: Duration = Duration::from_secs(5); @@ -69,11 +70,11 @@ where } pub struct TestApp { - app: JoinHandle>>, - msg_task: JoinHandle<()>, + app: TaskHandle>>, + msg_task: TaskHandle<()>, conn_evt_tx: UnboundedSender, pending_msgs: Arc>>, - notifier: Arc, + notifier: Arc, } impl TestApp { @@ -95,17 +96,18 @@ impl TestApp { F: Fn(&RpcMessage) -> bool { loop { - let waiter = self.notifier.notified(); + let event_listener = self.notifier.listen(); - let mut pending = self.pending_msgs.write().await; - let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos)); + { + let mut pending = self.pending_msgs.write().expect("pending_msgs write should succeed"); + let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos)); - if let Some(matched_msg) = matched_msg { - return matched_msg; + if let Some(matched_msg) = matched_msg { + return matched_msg; + } } - drop(pending); - tokio::time::timeout(TIMEOUT_DURATION, waiter).await.unwrap_or_else(|_err| panic!("Timed out while waiting for message")); + timeout(TIMEOUT_DURATION, event_listener).await.unwrap_or_else(|err| panic!("{err}")); } } @@ -175,18 +177,18 @@ impl TestApp { .respond(Ok("true")); } - pub fn new(app_maker: impl FnOnce(UnboundedReceiver) -> JoinHandle>>) -> Self { + pub fn new(app_maker: impl FnOnce(UnboundedReceiver) -> TaskHandle>>) -> Self { let (conn_evt_tx, conn_evt_rx) = futures::channel::mpsc::unbounded::(); let app = app_maker(conn_evt_rx); let (conn_cmd_sender_in, mut conn_cmd_receiver_in) = unbounded(); let pending_msgs = Arc::new(RwLock::new(Vec::new())); - let notifier = Arc::new(Notify::new()); + let notifier = Arc::new(Event::new()); let msg_task = { let pending_msgs = pending_msgs.clone(); let notifier = notifier.clone(); - tokio::spawn(async move { + crate::runtime::spawn_task(async move { while let Some(ConnectionCommand::SendMessage(rpc_message)) = conn_cmd_receiver_in.next().await { let shv_path = rpc_message.shv_path().unwrap_or_default(); let method = rpc_message.method().unwrap_or_default(); @@ -200,9 +202,9 @@ impl TestApp { } { - pending_msgs.write().await.push(rpc_message); + pending_msgs.write().expect("pending_msgs write should succeed").push(rpc_message); } - notifier.notify_waiters(); + notifier.notify(usize::MAX); } }) }; @@ -218,30 +220,33 @@ impl TestApp { pub async fn wait_until_finished(self) -> shvrpc::Result<()> { // Wait for a bit to ensure silence from the app. - tokio::time::sleep(Duration::from_millis(500)).await; + futures_timer::Delay::new(Duration::from_millis(500)).await; self.conn_evt_tx.close_channel(); - let mut pending_msgs = self.pending_msgs.write().await; - // We'll let unsubscribe calls slide. - pending_msgs.retain(|msg| - !msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe") - ); - for rpc_message in pending_msgs.iter() { - let shv_path = rpc_message.shv_path().unwrap_or_default(); - let method = rpc_message.method().unwrap_or_default(); - let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null); - let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} ")); - if rpc_message.is_response() { - let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported")); - debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}"); - } else if method.is_empty() && shv_path.is_empty() { - debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}"); - } else { - debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}"); - } - }; - if !pending_msgs.is_empty() { - return Err("There were unexpected messages received from the app.".into()); + { + let mut pending_msgs = self.pending_msgs.write().expect("pending_msgs write should succeed"); + // We'll let unsubscribe calls slide. + pending_msgs.retain(|msg| + !msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe") + ); + for rpc_message in pending_msgs.iter() { + let shv_path = rpc_message.shv_path().unwrap_or_default(); + let method = rpc_message.method().unwrap_or_default(); + let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null); + let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} ")); + if rpc_message.is_response() { + let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported")); + debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}"); + } else if method.is_empty() && shv_path.is_empty() { + debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}"); + } else { + debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}"); + } + }; + + if !pending_msgs.is_empty() { + return Err("There were unexpected messages received from the app.".into()); + } } let end = async move { @@ -251,7 +256,20 @@ impl TestApp { Ok(()) }; - tokio::time::timeout(TIMEOUT_DURATION, end).await? + timeout(TIMEOUT_DURATION, end).await? } } +pub async fn timeout(dur: Duration, fut: F) -> shvrpc::Result +where + F: Future, +{ + futures::pin_mut!(fut); + let timeout = Delay::new(dur); + futures::pin_mut!(timeout); + + match select(fut, timeout).await { + Either::Left((val, _)) => Ok(val), + Either::Right(_) => Err("Timed out while waiting for a future".into()), + } +} diff --git a/src/runtime.rs b/src/runtime.rs index 9d11476..36db66a 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,12 +1,29 @@ + #[must_use = "Task has to be used. If you want to detach the task, call .detach() on it."] -pub struct TaskHandle( +pub struct TaskHandle( +// pub struct TaskHandle( #[cfg(feature = "tokio")] - pub tokio::task::JoinHandle, + pub tokio::task::JoinHandle, #[cfg(feature = "smol")] - pub smol::Task, + // The error type is dummy as smol::Task future resolves to the result right away, + // but we want to keep the same API with tokio JoinHandle, which returns a Result. + pub smol::Task>, ); -impl TaskHandle { +impl Future for TaskHandle { + #[cfg(feature = "tokio")] + type Output = Result; + + #[cfg(feature = "smol")] + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + std::pin::Pin::new(&mut self.0).poll(cx) + } +} + +// impl TaskHandle { +impl TaskHandle { #[cfg(feature = "tokio")] #[expect(clippy::unused_async, reason = "We want the same API as with smol")] pub async fn cancel(self) { @@ -24,7 +41,7 @@ impl TaskHandle { } } -pub fn spawn_task(f: F) -> TaskHandle +pub fn spawn_task(f: F) -> TaskHandle where F: futures::Future + Send + 'static, F::Output: Send + 'static, @@ -32,7 +49,8 @@ where #[cfg(feature = "tokio")] { TaskHandle(tokio::spawn(f)) } #[cfg(feature = "smol")] - { TaskHandle(smol::spawn(f)) } + { TaskHandle(smol::spawn(async move { Ok(f.await) } )) } + } pub fn block_on(future: impl Future) -> T {