@@ -37,6 +37,7 @@ use crate::util::ser::Writeable;
3737
3838use core:: fmt;
3939use core:: ops:: Deref ;
40+ use core:: sync:: atomic:: { AtomicBool , Ordering } ;
4041use crate :: io;
4142use crate :: sync:: Mutex ;
4243use crate :: prelude:: * ;
@@ -264,6 +265,7 @@ pub struct OnionMessenger<
264265 intercept_messages_for_offline_peers : bool ,
265266 pending_intercepted_msgs_events : Mutex < Vec < Event > > ,
266267 pending_peer_connected_events : Mutex < Vec < Event > > ,
268+ pending_events_processor : AtomicBool ,
267269}
268270
269271/// [`OnionMessage`]s buffered to be sent.
@@ -1018,6 +1020,28 @@ where
10181020 }
10191021}
10201022
1023+ macro_rules! drop_handled_events_and_abort { ( $self: expr, $res: expr, $offset: expr, $event_queue: expr) => {
1024+ // We want to make sure to cleanly abort upon event handling failure. To this end, we drop all
1025+ // successfully handled events from the given queue, reset the events processing flag, and
1026+ // return, to have the events eventually replayed upon next invocation.
1027+ {
1028+ let mut queue_lock = $event_queue. lock( ) . unwrap( ) ;
1029+
1030+ // We skip `$offset` result entries to reach the ones relevant for the given `$event_queue`.
1031+ let mut res_iter = $res. iter( ) . skip( $offset) ;
1032+
1033+ // Keep all events which previously error'd *or* any that have been added since we dropped
1034+ // the Mutex before.
1035+ queue_lock. retain( |_| res_iter. next( ) . map_or( true , |r| r. is_err( ) ) ) ;
1036+
1037+ if $res. iter( ) . any( |r| r. is_err( ) ) {
1038+ // We failed handling some events. Return to have them eventually replayed.
1039+ $self. pending_events_processor. store( false , Ordering :: Release ) ;
1040+ return ;
1041+ }
1042+ }
1043+ } }
1044+
10211045impl < ES : Deref , NS : Deref , L : Deref , NL : Deref , MR : Deref , OMH : Deref , APH : Deref , CMH : Deref >
10221046OnionMessenger < ES , NS , L , NL , MR , OMH , APH , CMH >
10231047where
@@ -1094,6 +1118,7 @@ where
10941118 intercept_messages_for_offline_peers,
10951119 pending_intercepted_msgs_events : Mutex :: new ( Vec :: new ( ) ) ,
10961120 pending_peer_connected_events : Mutex :: new ( Vec :: new ( ) ) ,
1121+ pending_events_processor : AtomicBool :: new ( false ) ,
10971122 }
10981123 }
10991124
@@ -1332,45 +1357,60 @@ where
13321357 pub async fn process_pending_events_async < Future : core:: future:: Future < Output = Result < ( ) , ReplayEvent > > + core:: marker:: Unpin , H : Fn ( Event ) -> Future > (
13331358 & self , handler : H
13341359 ) {
1335- let mut intercepted_msgs = Vec :: new ( ) ;
1336- let mut peer_connecteds = Vec :: new ( ) ;
1337- {
1338- let mut pending_intercepted_msgs_events =
1339- self . pending_intercepted_msgs_events . lock ( ) . unwrap ( ) ;
1340- let mut pending_peer_connected_events =
1341- self . pending_peer_connected_events . lock ( ) . unwrap ( ) ;
1342- core:: mem:: swap ( & mut * pending_intercepted_msgs_events, & mut intercepted_msgs) ;
1343- core:: mem:: swap ( & mut * pending_peer_connected_events, & mut peer_connecteds) ;
1360+ if self . pending_events_processor . compare_exchange ( false , true , Ordering :: Acquire , Ordering :: Relaxed ) . is_err ( ) {
1361+ return ;
13441362 }
13451363
1346- let mut futures = Vec :: with_capacity ( intercepted_msgs. len ( ) ) ;
1347- for ( node_id, recipient) in self . message_recipients . lock ( ) . unwrap ( ) . iter_mut ( ) {
1348- if let OnionMessageRecipient :: PendingConnection ( _, addresses, _) = recipient {
1349- if let Some ( addresses) = addresses. take ( ) {
1350- let future = ResultFuture :: Pending ( handler ( Event :: ConnectionNeeded { node_id : * node_id, addresses } ) ) ;
1351- futures. push ( future) ;
1364+ {
1365+ let intercepted_msgs = self . pending_intercepted_msgs_events . lock ( ) . unwrap ( ) . clone ( ) ;
1366+ let mut futures = Vec :: with_capacity ( intercepted_msgs. len ( ) ) ;
1367+ for ( node_id, recipient) in self . message_recipients . lock ( ) . unwrap ( ) . iter_mut ( ) {
1368+ if let OnionMessageRecipient :: PendingConnection ( _, addresses, _) = recipient {
1369+ if let Some ( addresses) = addresses. take ( ) {
1370+ let future = ResultFuture :: Pending ( handler ( Event :: ConnectionNeeded { node_id : * node_id, addresses } ) ) ;
1371+ futures. push ( future) ;
1372+ }
13521373 }
13531374 }
1354- }
13551375
1356- for ev in intercepted_msgs {
1357- if let Event :: OnionMessageIntercepted { .. } = ev { } else { debug_assert ! ( false ) ; }
1358- let future = ResultFuture :: Pending ( handler ( ev) ) ;
1359- futures. push ( future) ;
1360- }
1361- // Let the `OnionMessageIntercepted` events finish before moving on to peer_connecteds
1362- MultiResultFuturePoller :: new ( futures) . await ;
1376+ // The offset in the `futures` vec at which `intercepted_msgs` start. We don't bother
1377+ // replaying `ConnectionNeeded` events.
1378+ let intercepted_msgs_offset = futures. len ( ) ;
13631379
1364- if peer_connecteds. len ( ) <= 1 {
1365- for event in peer_connecteds { handler ( event) . await ; }
1366- } else {
1367- let mut futures = Vec :: new ( ) ;
1368- for event in peer_connecteds {
1369- let future = ResultFuture :: Pending ( handler ( event) ) ;
1380+ for ev in intercepted_msgs {
1381+ if let Event :: OnionMessageIntercepted { .. } = ev { } else { debug_assert ! ( false ) ; }
1382+ let future = ResultFuture :: Pending ( handler ( ev) ) ;
13701383 futures. push ( future) ;
13711384 }
1372- MultiResultFuturePoller :: new ( futures) . await ;
1385+ // Let the `OnionMessageIntercepted` events finish before moving on to peer_connecteds
1386+ let res = MultiResultFuturePoller :: new ( futures) . await ;
1387+ drop_handled_events_and_abort ! ( self , res, intercepted_msgs_offset, self . pending_intercepted_msgs_events) ;
13731388 }
1389+
1390+ {
1391+ let peer_connecteds = self . pending_peer_connected_events . lock ( ) . unwrap ( ) . clone ( ) ;
1392+ let num_peer_connecteds = peer_connecteds. len ( ) ;
1393+ if num_peer_connecteds <= 1 {
1394+ for event in peer_connecteds {
1395+ if handler ( event) . await . is_ok ( ) {
1396+ self . pending_peer_connected_events . lock ( ) . unwrap ( ) . drain ( ..num_peer_connecteds) ;
1397+ } else {
1398+ // We failed handling the event. Return to have it eventually replayed.
1399+ self . pending_events_processor . store ( false , Ordering :: Release ) ;
1400+ return ;
1401+ }
1402+ }
1403+ } else {
1404+ let mut futures = Vec :: new ( ) ;
1405+ for event in peer_connecteds {
1406+ let future = ResultFuture :: Pending ( handler ( event) ) ;
1407+ futures. push ( future) ;
1408+ }
1409+ let res = MultiResultFuturePoller :: new ( futures) . await ;
1410+ drop_handled_events_and_abort ! ( self , res, 0 , self . pending_peer_connected_events) ;
1411+ }
1412+ }
1413+ self . pending_events_processor . store ( false , Ordering :: Release ) ;
13741414 }
13751415}
13761416
@@ -1410,17 +1450,24 @@ where
14101450 CMH :: Target : CustomOnionMessageHandler ,
14111451{
14121452 fn process_pending_events < H : Deref > ( & self , handler : H ) where H :: Target : EventHandler {
1453+ if self . pending_events_processor . compare_exchange ( false , true , Ordering :: Acquire , Ordering :: Relaxed ) . is_err ( ) {
1454+ return ;
1455+ }
1456+
14131457 for ( node_id, recipient) in self . message_recipients . lock ( ) . unwrap ( ) . iter_mut ( ) {
14141458 if let OnionMessageRecipient :: PendingConnection ( _, addresses, _) = recipient {
14151459 if let Some ( addresses) = addresses. take ( ) {
14161460 let _ = handler. handle_event ( Event :: ConnectionNeeded { node_id : * node_id, addresses } ) ;
14171461 }
14181462 }
14191463 }
1420- let mut events = Vec :: new ( ) ;
1464+ let intercepted_msgs;
1465+ let peer_connecteds;
14211466 {
1422- let mut pending_intercepted_msgs_events = self . pending_intercepted_msgs_events . lock ( ) . unwrap ( ) ;
1467+ let pending_intercepted_msgs_events = self . pending_intercepted_msgs_events . lock ( ) . unwrap ( ) ;
1468+ intercepted_msgs = pending_intercepted_msgs_events. clone ( ) ;
14231469 let mut pending_peer_connected_events = self . pending_peer_connected_events . lock ( ) . unwrap ( ) ;
1470+ peer_connecteds = pending_peer_connected_events. clone ( ) ;
14241471 #[ cfg( debug_assertions) ] {
14251472 for ev in pending_intercepted_msgs_events. iter ( ) {
14261473 if let Event :: OnionMessageIntercepted { .. } = ev { } else { panic ! ( ) ; }
@@ -1429,13 +1476,16 @@ where
14291476 if let Event :: OnionMessagePeerConnected { .. } = ev { } else { panic ! ( ) ; }
14301477 }
14311478 }
1432- core:: mem:: swap ( & mut * pending_intercepted_msgs_events, & mut events) ;
1433- events. append ( & mut pending_peer_connected_events) ;
14341479 pending_peer_connected_events. shrink_to ( 10 ) ; // Limit total heap usage
14351480 }
1436- for ev in events {
1437- handler. handle_event ( ev) ;
1438- }
1481+
1482+ let res = intercepted_msgs. into_iter ( ) . map ( |ev| handler. handle_event ( ev) ) . collect :: < Vec < _ > > ( ) ;
1483+ drop_handled_events_and_abort ! ( self , res, 0 , self . pending_intercepted_msgs_events) ;
1484+
1485+ let res = peer_connecteds. into_iter ( ) . map ( |ev| handler. handle_event ( ev) ) . collect :: < Vec < _ > > ( ) ;
1486+ drop_handled_events_and_abort ! ( self , res, 0 , self . pending_peer_connected_events) ;
1487+
1488+ self . pending_events_processor . store ( false , Ordering :: Release ) ;
14391489 }
14401490}
14411491
0 commit comments