Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "shvclient"
description = "A Rust framework for Silicon Heaven RPC devices"
license = "MIT"
repository = "https://github.com/silicon-heaven/libshvclient-rs"
version = "4.2.0"
version = "5.0.0"
edition = "2024"

[lib]
Expand Down Expand Up @@ -55,6 +55,8 @@ futures-rustls = "0.26.0"
rustls-pemfile = "2.2.0"
rustls-platform-verifier = "0.6.1"
async-trait = "0.1.89"
futures-timer = "3.0.3"
event-listener = { version = "5.4.1", optional = true }

# For local development
# [patch.crates-io]
Expand All @@ -65,7 +67,7 @@ async-trait = "0.1.89"
default = []
tokio = ["dep:tokio", "dep:tokio-util"]
smol = ["dep:smol"]
mocking = []
mocking = ["dep:event-listener"]

[lints.clippy]
pedantic = {priority = -1, level = "warn"}
Expand Down
102 changes: 60 additions & 42 deletions src/mocking.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{error::Error, sync::Arc, time::Duration};
use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}};
use std::{error::Error, sync::{Arc, RwLock}, time::Duration};
use event_listener::Event;
use futures::{StreamExt, channel::mpsc::{UnboundedReceiver, UnboundedSender, unbounded}, future::{Either, select}};
use futures_timer::Delay;
use log::debug;
use shvproto::RpcValue;
use shvrpc::{RpcMessage, RpcMessageMetaTags, rpc::ShvRI, rpcmessage::{RpcError, RpcErrorCode}};
use tokio::{sync::{Notify, RwLock}, task::JoinHandle};
use crate::{ConnectionCommand, ConnectionEvent};
use crate::{ConnectionCommand, ConnectionEvent, runtime::TaskHandle};

const TIMEOUT_DURATION: Duration = Duration::from_secs(5);

Expand Down Expand Up @@ -69,11 +70,11 @@ where
}

pub struct TestApp {
app: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>,
msg_task: JoinHandle<()>,
app: TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>,
msg_task: TaskHandle<()>,
conn_evt_tx: UnboundedSender<ConnectionEvent>,
pending_msgs: Arc<RwLock<Vec<RpcMessage>>>,
notifier: Arc<Notify>,
notifier: Arc<Event>,
}

impl TestApp {
Expand All @@ -95,17 +96,18 @@ impl TestApp {
F: Fn(&RpcMessage) -> bool
{
loop {
let waiter = self.notifier.notified();
let event_listener = self.notifier.listen();

let mut pending = self.pending_msgs.write().await;
let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos));
{
let mut pending = self.pending_msgs.write().expect("pending_msgs write should succeed");
let matched_msg = pending.iter().position(&matcher).map(|pos| pending.remove(pos));

if let Some(matched_msg) = matched_msg {
return matched_msg;
if let Some(matched_msg) = matched_msg {
return matched_msg;
}
}

drop(pending);
tokio::time::timeout(TIMEOUT_DURATION, waiter).await.unwrap_or_else(|_err| panic!("Timed out while waiting for message"));
timeout(TIMEOUT_DURATION, event_listener).await.unwrap_or_else(|err| panic!("{err}"));
}
}

Expand Down Expand Up @@ -175,18 +177,18 @@ impl TestApp {
.respond(Ok("true"));
}

pub fn new(app_maker: impl FnOnce(UnboundedReceiver<ConnectionEvent>) -> JoinHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
pub fn new(app_maker: impl FnOnce(UnboundedReceiver<ConnectionEvent>) -> TaskHandle<Result<(), Box<dyn Error + Send + Sync>>>) -> Self {
let (conn_evt_tx, conn_evt_rx) = futures::channel::mpsc::unbounded::<ConnectionEvent>();
let app = app_maker(conn_evt_rx);

let (conn_cmd_sender_in, mut conn_cmd_receiver_in) = unbounded();
let pending_msgs = Arc::new(RwLock::new(Vec::new()));
let notifier = Arc::new(Notify::new());
let notifier = Arc::new(Event::new());

let msg_task = {
let pending_msgs = pending_msgs.clone();
let notifier = notifier.clone();
tokio::spawn(async move {
crate::runtime::spawn_task(async move {
while let Some(ConnectionCommand::SendMessage(rpc_message)) = conn_cmd_receiver_in.next().await {
let shv_path = rpc_message.shv_path().unwrap_or_default();
let method = rpc_message.method().unwrap_or_default();
Expand All @@ -200,9 +202,9 @@ impl TestApp {
}

{
pending_msgs.write().await.push(rpc_message);
pending_msgs.write().expect("pending_msgs write should succeed").push(rpc_message);
}
notifier.notify_waiters();
notifier.notify(usize::MAX);
}
})
};
Expand All @@ -218,30 +220,33 @@ impl TestApp {

pub async fn wait_until_finished(self) -> shvrpc::Result<()> {
// Wait for a bit to ensure silence from the app.
tokio::time::sleep(Duration::from_millis(500)).await;
futures_timer::Delay::new(Duration::from_millis(500)).await;
self.conn_evt_tx.close_channel();
let mut pending_msgs = self.pending_msgs.write().await;
// We'll let unsubscribe calls slide.
pending_msgs.retain(|msg|
!msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe")
);
for rpc_message in pending_msgs.iter() {
let shv_path = rpc_message.shv_path().unwrap_or_default();
let method = rpc_message.method().unwrap_or_default();
let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
if rpc_message.is_response() {
let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported"));
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}");
} else if method.is_empty() && shv_path.is_empty() {
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}");
} else {
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}");
}
};

if !pending_msgs.is_empty() {
return Err("There were unexpected messages received from the app.".into());
{
let mut pending_msgs = self.pending_msgs.write().expect("pending_msgs write should succeed");
// We'll let unsubscribe calls slide.
pending_msgs.retain(|msg|
!msg.is_request() || msg.shv_path() != Some(".broker/currentClient") || msg.method() != Some("unsubscribe")
);
for rpc_message in pending_msgs.iter() {
let shv_path = rpc_message.shv_path().unwrap_or_default();
let method = rpc_message.method().unwrap_or_default();
let param = rpc_message.param().cloned().unwrap_or_else(RpcValue::null);
let msg_prefix = rpc_message.request_id().map_or_else(String::new, |rq_id| format!("rq_id:{rq_id} "));
if rpc_message.is_response() {
let result = rpc_message.response().map(|resp| resp.success().expect("Only success responses are supported"));
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix} -> {result:?}");
} else if method.is_empty() && shv_path.is_empty() {
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{param}");
} else {
debug!(target: "test-driver", "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}");
}
};

if !pending_msgs.is_empty() {
return Err("There were unexpected messages received from the app.".into());
}
}

let end = async move {
Expand All @@ -251,7 +256,20 @@ impl TestApp {
Ok(())
};

tokio::time::timeout(TIMEOUT_DURATION, end).await?
timeout(TIMEOUT_DURATION, end).await?
}
}

pub async fn timeout<F, T>(dur: Duration, fut: F) -> shvrpc::Result<T>
where
F: Future<Output = T>,
{
futures::pin_mut!(fut);
let timeout = Delay::new(dur);
futures::pin_mut!(timeout);

match select(fut, timeout).await {
Either::Left((val, _)) => Ok(val),
Either::Right(_) => Err("Timed out while waiting for a future".into()),
}
}
30 changes: 24 additions & 6 deletions src/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@

#[must_use = "Task has to be used. If you want to detach the task, call .detach() on it."]
pub struct TaskHandle<F: futures::Future + Send + 'static>(
pub struct TaskHandle<O: Send + 'static>(
// pub struct TaskHandle<F: futures::Future + Send + 'static>(
#[cfg(feature = "tokio")]
pub tokio::task::JoinHandle<F::Output>,
pub tokio::task::JoinHandle<O>,
#[cfg(feature = "smol")]
pub smol::Task<F::Output>,
// The error type is dummy as smol::Task future resolves to the result right away,
// but we want to keep the same API with tokio JoinHandle, which returns a Result.
pub smol::Task<Result<O, std::convert::Infallible>>,
);

impl<F: futures::Future + Send + 'static> TaskHandle<F> {
impl<O: Send + 'static> Future for TaskHandle<O> {
#[cfg(feature = "tokio")]
type Output = Result<O, tokio::task::JoinError>;

#[cfg(feature = "smol")]
type Output = Result<O, std::convert::Infallible>;

fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
std::pin::Pin::new(&mut self.0).poll(cx)
}
}

// impl<F: futures::Future + Send + 'static> TaskHandle<F> {
impl<O: Send + 'static> TaskHandle<O> {
#[cfg(feature = "tokio")]
#[expect(clippy::unused_async, reason = "We want the same API as with smol")]
pub async fn cancel(self) {
Expand All @@ -24,15 +41,16 @@ impl<F: futures::Future + Send + 'static> TaskHandle<F> {
}
}

pub fn spawn_task<F>(f: F) -> TaskHandle<F>
pub fn spawn_task<F>(f: F) -> TaskHandle<F::Output>
where
F: futures::Future + Send + 'static,
F::Output: Send + 'static,
{
#[cfg(feature = "tokio")]
{ TaskHandle(tokio::spawn(f)) }
#[cfg(feature = "smol")]
{ TaskHandle(smol::spawn(f)) }
{ TaskHandle(smol::spawn(async move { Ok(f.await) } )) }

}

pub fn block_on<T>(future: impl Future<Output = T>) -> T {
Expand Down
Loading