1- use std:: { error:: Error , sync:: Arc , time:: Duration } ;
2- use futures:: { StreamExt , channel:: mpsc:: { UnboundedReceiver , UnboundedSender , unbounded} } ;
1+ use std:: { error:: Error , sync:: { Arc , RwLock } , time:: Duration } ;
2+ use event_listener:: Event ;
3+ use futures:: { StreamExt , channel:: mpsc:: { UnboundedReceiver , UnboundedSender , unbounded} , future:: { Either , select} } ;
4+ use futures_timer:: Delay ;
35use log:: debug;
46use shvproto:: RpcValue ;
57use shvrpc:: { RpcMessage , RpcMessageMetaTags , rpc:: ShvRI , rpcmessage:: { RpcError , RpcErrorCode } } ;
6- use tokio:: { sync:: { Notify , RwLock } , task:: JoinHandle } ;
7- use crate :: { ConnectionCommand , ConnectionEvent } ;
8+ use crate :: { ConnectionCommand , ConnectionEvent , runtime:: TaskHandle } ;
89
910const TIMEOUT_DURATION : Duration = Duration :: from_secs ( 5 ) ;
1011
@@ -69,11 +70,11 @@ where
6970}
7071
7172pub struct TestApp {
72- app : JoinHandle < Result < ( ) , Box < dyn Error + Send + Sync > > > ,
73- msg_task : JoinHandle < ( ) > ,
73+ app : TaskHandle < Result < ( ) , Box < dyn Error + Send + Sync > > > ,
74+ msg_task : TaskHandle < ( ) > ,
7475 conn_evt_tx : UnboundedSender < ConnectionEvent > ,
7576 pending_msgs : Arc < RwLock < Vec < RpcMessage > > > ,
76- notifier : Arc < Notify > ,
77+ notifier : Arc < Event > ,
7778}
7879
7980impl TestApp {
@@ -95,17 +96,18 @@ impl TestApp {
9596 F : Fn ( & RpcMessage ) -> bool
9697 {
9798 loop {
98- let waiter = self . notifier . notified ( ) ;
99+ let event_listener = self . notifier . listen ( ) ;
99100
100- let mut pending = self . pending_msgs . write ( ) . await ;
101- let matched_msg = pending. iter ( ) . position ( & matcher) . map ( |pos| pending. remove ( pos) ) ;
101+ {
102+ let mut pending = self . pending_msgs . write ( ) . expect ( "pending_msgs write should succeed" ) ;
103+ let matched_msg = pending. iter ( ) . position ( & matcher) . map ( |pos| pending. remove ( pos) ) ;
102104
103- if let Some ( matched_msg) = matched_msg {
104- return matched_msg;
105+ if let Some ( matched_msg) = matched_msg {
106+ return matched_msg;
107+ }
105108 }
106109
107- drop ( pending) ;
108- tokio:: time:: timeout ( TIMEOUT_DURATION , waiter) . await . unwrap_or_else ( |_err| panic ! ( "Timed out while waiting for message" ) ) ;
110+ timeout ( TIMEOUT_DURATION , event_listener) . await . unwrap_or_else ( |err| panic ! ( "{err}" ) ) ;
109111 }
110112 }
111113
@@ -175,18 +177,18 @@ impl TestApp {
175177 . respond ( Ok ( "true" ) ) ;
176178 }
177179
178- pub fn new ( app_maker : impl FnOnce ( UnboundedReceiver < ConnectionEvent > ) -> JoinHandle < Result < ( ) , Box < dyn Error + Send + Sync > > > ) -> Self {
180+ pub fn new ( app_maker : impl FnOnce ( UnboundedReceiver < ConnectionEvent > ) -> TaskHandle < Result < ( ) , Box < dyn Error + Send + Sync > > > ) -> Self {
179181 let ( conn_evt_tx, conn_evt_rx) = futures:: channel:: mpsc:: unbounded :: < ConnectionEvent > ( ) ;
180182 let app = app_maker ( conn_evt_rx) ;
181183
182184 let ( conn_cmd_sender_in, mut conn_cmd_receiver_in) = unbounded ( ) ;
183185 let pending_msgs = Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ;
184- let notifier = Arc :: new ( Notify :: new ( ) ) ;
186+ let notifier = Arc :: new ( Event :: new ( ) ) ;
185187
186188 let msg_task = {
187189 let pending_msgs = pending_msgs. clone ( ) ;
188190 let notifier = notifier. clone ( ) ;
189- tokio :: spawn ( async move {
191+ crate :: runtime :: spawn_task ( async move {
190192 while let Some ( ConnectionCommand :: SendMessage ( rpc_message) ) = conn_cmd_receiver_in. next ( ) . await {
191193 let shv_path = rpc_message. shv_path ( ) . unwrap_or_default ( ) ;
192194 let method = rpc_message. method ( ) . unwrap_or_default ( ) ;
@@ -200,9 +202,9 @@ impl TestApp {
200202 }
201203
202204 {
203- pending_msgs. write ( ) . await . push ( rpc_message) ;
205+ pending_msgs. write ( ) . expect ( "pending_msgs write should succeed" ) . push ( rpc_message) ;
204206 }
205- notifier. notify_waiters ( ) ;
207+ notifier. notify ( usize :: MAX ) ;
206208 }
207209 } )
208210 } ;
@@ -218,30 +220,33 @@ impl TestApp {
218220
219221 pub async fn wait_until_finished ( self ) -> shvrpc:: Result < ( ) > {
220222 // Wait for a bit to ensure silence from the app.
221- tokio :: time :: sleep ( Duration :: from_millis ( 500 ) ) . await ;
223+ futures_timer :: Delay :: new ( Duration :: from_millis ( 500 ) ) . await ;
222224 self . conn_evt_tx . close_channel ( ) ;
223- let mut pending_msgs = self . pending_msgs . write ( ) . await ;
224- // We'll let unsubscribe calls slide.
225- pending_msgs. retain ( |msg|
226- !msg. is_request ( ) || msg. shv_path ( ) != Some ( ".broker/currentClient" ) || msg. method ( ) != Some ( "unsubscribe" )
227- ) ;
228- for rpc_message in pending_msgs. iter ( ) {
229- let shv_path = rpc_message. shv_path ( ) . unwrap_or_default ( ) ;
230- let method = rpc_message. method ( ) . unwrap_or_default ( ) ;
231- let param = rpc_message. param ( ) . cloned ( ) . unwrap_or_else ( RpcValue :: null) ;
232- let msg_prefix = rpc_message. request_id ( ) . map_or_else ( String :: new, |rq_id| format ! ( "rq_id:{rq_id} " ) ) ;
233- if rpc_message. is_response ( ) {
234- let result = rpc_message. response ( ) . map ( |resp| resp. success ( ) . expect ( "Only success responses are supported" ) ) ;
235- debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix} -> {result:?}" ) ;
236- } else if method. is_empty ( ) && shv_path. is_empty ( ) {
237- debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix}{param}" ) ;
238- } else {
239- debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}" ) ;
240- }
241- } ;
242225
243- if !pending_msgs. is_empty ( ) {
244- return Err ( "There were unexpected messages received from the app." . into ( ) ) ;
226+ {
227+ let mut pending_msgs = self . pending_msgs . write ( ) . expect ( "pending_msgs write should succeed" ) ;
228+ // We'll let unsubscribe calls slide.
229+ pending_msgs. retain ( |msg|
230+ !msg. is_request ( ) || msg. shv_path ( ) != Some ( ".broker/currentClient" ) || msg. method ( ) != Some ( "unsubscribe" )
231+ ) ;
232+ for rpc_message in pending_msgs. iter ( ) {
233+ let shv_path = rpc_message. shv_path ( ) . unwrap_or_default ( ) ;
234+ let method = rpc_message. method ( ) . unwrap_or_default ( ) ;
235+ let param = rpc_message. param ( ) . cloned ( ) . unwrap_or_else ( RpcValue :: null) ;
236+ let msg_prefix = rpc_message. request_id ( ) . map_or_else ( String :: new, |rq_id| format ! ( "rq_id:{rq_id} " ) ) ;
237+ if rpc_message. is_response ( ) {
238+ let result = rpc_message. response ( ) . map ( |resp| resp. success ( ) . expect ( "Only success responses are supported" ) ) ;
239+ debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix} -> {result:?}" ) ;
240+ } else if method. is_empty ( ) && shv_path. is_empty ( ) {
241+ debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix}{param}" ) ;
242+ } else {
243+ debug ! ( target: "test-driver" , "UNEXPECTED <== {msg_prefix}{shv_path}:{method}, param: {param}" ) ;
244+ }
245+ } ;
246+
247+ if !pending_msgs. is_empty ( ) {
248+ return Err ( "There were unexpected messages received from the app." . into ( ) ) ;
249+ }
245250 }
246251
247252 let end = async move {
@@ -251,7 +256,20 @@ impl TestApp {
251256 Ok ( ( ) )
252257 } ;
253258
254- tokio :: time :: timeout ( TIMEOUT_DURATION , end) . await ?
259+ timeout ( TIMEOUT_DURATION , end) . await ?
255260 }
256261}
257262
263+ pub async fn timeout < F , T > ( dur : Duration , fut : F ) -> shvrpc:: Result < T >
264+ where
265+ F : Future < Output = T > ,
266+ {
267+ futures:: pin_mut!( fut) ;
268+ let timeout = Delay :: new ( dur) ;
269+ futures:: pin_mut!( timeout) ;
270+
271+ match select ( fut, timeout) . await {
272+ Either :: Left ( ( val, _) ) => Ok ( val) ,
273+ Either :: Right ( _) => Err ( "Timed out while waiting for a future" . into ( ) ) ,
274+ }
275+ }
0 commit comments