@@ -2,7 +2,8 @@ use std::collections::HashMap;
22use std:: convert:: TryInto ;
33use std:: io:: { BufRead , BufReader , Write } ;
44use std:: net:: { Shutdown , SocketAddr , TcpListener , TcpStream } ;
5- use std:: sync:: mpsc:: { Sender , SyncSender , TrySendError } ;
5+ use std:: sync:: atomic:: AtomicBool ;
6+ use std:: sync:: mpsc:: { Receiver , Sender } ;
67use std:: sync:: { Arc , Mutex } ;
78use std:: thread;
89
@@ -101,6 +102,7 @@ struct Connection {
101102 chan : SyncChannel < Message > ,
102103 stats : Arc < Stats > ,
103104 txs_limit : usize ,
105+ die_please : Option < Receiver < ( ) > > ,
104106 #[ cfg( feature = "electrum-discovery" ) ]
105107 discovery : Option < Arc < DiscoveryManager > > ,
106108}
@@ -112,6 +114,7 @@ impl Connection {
112114 addr : SocketAddr ,
113115 stats : Arc < Stats > ,
114116 txs_limit : usize ,
117+ die_please : Receiver < ( ) > ,
115118 #[ cfg( feature = "electrum-discovery" ) ] discovery : Option < Arc < DiscoveryManager > > ,
116119 ) -> Connection {
117120 Connection {
@@ -123,6 +126,7 @@ impl Connection {
123126 chan : SyncChannel :: new ( 10 ) ,
124127 stats,
125128 txs_limit,
129+ die_please : Some ( die_please) ,
126130 #[ cfg( feature = "electrum-discovery" ) ]
127131 discovery,
128132 }
@@ -501,38 +505,46 @@ impl Connection {
501505 Ok ( ( ) )
502506 }
503507
504- fn handle_replies ( & mut self ) -> Result < ( ) > {
508+ fn handle_replies ( & mut self , shutdown : crossbeam_channel :: Receiver < ( ) > ) -> Result < ( ) > {
505509 let empty_params = json ! ( [ ] ) ;
506510 loop {
507- let msg = self . chan . receiver ( ) . recv ( ) . chain_err ( || "channel closed" ) ?;
508- trace ! ( "RPC {:?}" , msg) ;
509- match msg {
510- Message :: Request ( line) => {
511- let cmd: Value = from_str ( & line) . chain_err ( || "invalid JSON format" ) ?;
512- let reply = match (
513- cmd. get ( "method" ) ,
514- cmd. get ( "params" ) . unwrap_or ( & empty_params) ,
515- cmd. get ( "id" ) ,
516- ) {
517- ( Some ( Value :: String ( method) ) , Value :: Array ( params) , Some ( id) ) => {
518- self . handle_command ( method, params, id) ?
511+ crossbeam_channel:: select! {
512+ recv( self . chan. receiver( ) ) -> msg => {
513+ let msg = msg. chain_err( || "channel closed" ) ?;
514+ trace!( "RPC {:?}" , msg) ;
515+ match msg {
516+ Message :: Request ( line) => {
517+ let cmd: Value = from_str( & line) . chain_err( || "invalid JSON format" ) ?;
518+ let reply = match (
519+ cmd. get( "method" ) ,
520+ cmd. get( "params" ) . unwrap_or( & empty_params) ,
521+ cmd. get( "id" ) ,
522+ ) {
523+ ( Some ( Value :: String ( method) ) , Value :: Array ( params) , Some ( id) ) => {
524+ self . handle_command( method, params, id) ?
525+ }
526+ _ => bail!( "invalid command: {}" , cmd) ,
527+ } ;
528+ self . send_values( & [ reply] ) ?
519529 }
520- _ => bail ! ( "invalid command: {}" , cmd) ,
521- } ;
522- self . send_values ( & [ reply] ) ?
523- }
524- Message :: PeriodicUpdate => {
525- let values = self
526- . update_subscriptions ( )
527- . chain_err ( || "failed to update subscriptions" ) ?;
528- self . send_values ( & values) ?
530+ Message :: PeriodicUpdate => {
531+ let values = self
532+ . update_subscriptions( )
533+ . chain_err( || "failed to update subscriptions" ) ?;
534+ self . send_values( & values) ?
535+ }
536+ Message :: Done => return Ok ( ( ) ) ,
537+ }
529538 }
530- Message :: Done => return Ok ( ( ) ) ,
539+ recv ( shutdown ) -> _ => return Ok ( ( ) ) ,
531540 }
532541 }
533542 }
534543
535- fn handle_requests ( mut reader : BufReader < TcpStream > , tx : SyncSender < Message > ) -> Result < ( ) > {
544+ fn handle_requests (
545+ mut reader : BufReader < TcpStream > ,
546+ tx : crossbeam_channel:: Sender < Message > ,
547+ ) -> Result < ( ) > {
536548 loop {
537549 let mut line = Vec :: < u8 > :: new ( ) ;
538550 reader
@@ -564,8 +576,18 @@ impl Connection {
564576 self . stats . clients . inc ( ) ;
565577 let reader = BufReader :: new ( self . stream . try_clone ( ) . expect ( "failed to clone TcpStream" ) ) ;
566578 let tx = self . chan . sender ( ) ;
579+
580+ let stream = self . stream . try_clone ( ) . expect ( "failed to clone TcpStream" ) ;
581+ let die_please = self . die_please . take ( ) . unwrap ( ) ;
582+ let ( reply_killer, reply_receiver) = crossbeam_channel:: unbounded ( ) ;
583+ spawn_thread ( "properly-die" , move || {
584+ let _ = die_please. recv ( ) ;
585+ let _ = stream. shutdown ( Shutdown :: Both ) ;
586+ let _ = reply_killer. send ( ( ) ) ;
587+ } ) ;
588+
567589 let child = spawn_thread ( "reader" , || Connection :: handle_requests ( reader, tx) ) ;
568- if let Err ( e) = self . handle_replies ( ) {
590+ if let Err ( e) = self . handle_replies ( reply_receiver ) {
569591 error ! (
570592 "[{}] connection handling failed: {}" ,
571593 self . addr,
@@ -631,30 +653,38 @@ struct Stats {
631653impl RPC {
632654 fn start_notifier (
633655 notification : Channel < Notification > ,
634- senders : Arc < Mutex < Vec < SyncSender < Message > > > > ,
656+ senders : Arc < Mutex < Vec < crossbeam_channel :: Sender < Message > > > > ,
635657 acceptor : Sender < Option < ( TcpStream , SocketAddr ) > > ,
658+ acceptor_shutdown : Sender < ( ) > ,
636659 ) {
637660 spawn_thread ( "notification" , move || {
638661 for msg in notification. receiver ( ) . iter ( ) {
639662 let mut senders = senders. lock ( ) . unwrap ( ) ;
640663 match msg {
641664 Notification :: Periodic => {
642665 for sender in senders. split_off ( 0 ) {
643- if let Err ( TrySendError :: Disconnected ( _) ) =
666+ if let Err ( crossbeam_channel :: TrySendError :: Disconnected ( _) ) =
644667 sender. try_send ( Message :: PeriodicUpdate )
645668 {
646669 continue ;
647670 }
648671 senders. push ( sender) ;
649672 }
650673 }
651- Notification :: Exit => acceptor. send ( None ) . unwrap ( ) , // mark acceptor as done
674+ Notification :: Exit => {
675+ acceptor_shutdown. send ( ( ) ) . unwrap ( ) ; // Stop the acceptor itself
676+ acceptor. send ( None ) . unwrap ( ) ; // mark acceptor as done
677+ break ;
678+ }
652679 }
653680 }
654681 } ) ;
655682 }
656683
657- fn start_acceptor ( addr : SocketAddr ) -> Channel < Option < ( TcpStream , SocketAddr ) > > {
684+ fn start_acceptor (
685+ addr : SocketAddr ,
686+ shutdown_channel : Channel < ( ) > ,
687+ ) -> Channel < Option < ( TcpStream , SocketAddr ) > > {
658688 let chan = Channel :: unbounded ( ) ;
659689 let acceptor = chan. sender ( ) ;
660690 spawn_thread ( "acceptor" , move || {
@@ -664,10 +694,29 @@ impl RPC {
664694 . set_nonblocking ( false )
665695 . expect ( "cannot set nonblocking to false" ) ;
666696 let listener = TcpListener :: from ( socket) ;
697+ let local_addr = listener. local_addr ( ) . unwrap ( ) ;
698+ let shutdown_bool = Arc :: new ( AtomicBool :: new ( false ) ) ;
699+
700+ {
701+ let shutdown_bool = Arc :: clone ( & shutdown_bool) ;
702+ crate :: util:: spawn_thread ( "shutdown-acceptor" , move || {
703+ // Block until shutdown is sent.
704+ let _ = shutdown_channel. receiver ( ) . recv ( ) ;
705+ // Store the bool so after the next accept it will break the loop
706+ shutdown_bool. store ( true , std:: sync:: atomic:: Ordering :: Release ) ;
707+ // Connect to the socket to cause it to unblock
708+ let _ = TcpStream :: connect ( local_addr) ;
709+ } ) ;
710+ }
667711
668712 info ! ( "Electrum RPC server running on {}" , addr) ;
669713 loop {
670714 let ( stream, addr) = listener. accept ( ) . expect ( "accept failed" ) ;
715+
716+ if shutdown_bool. load ( std:: sync:: atomic:: Ordering :: Acquire ) {
717+ break ;
718+ }
719+
671720 stream
672721 . set_nonblocking ( false )
673722 . expect ( "failed to set connection as blocking" ) ;
@@ -724,10 +773,19 @@ impl RPC {
724773 RPC {
725774 notification : notification. sender ( ) ,
726775 server : Some ( spawn_thread ( "rpc" , move || {
727- let senders = Arc :: new ( Mutex :: new ( Vec :: < SyncSender < Message > > :: new ( ) ) ) ;
728-
729- let acceptor = RPC :: start_acceptor ( rpc_addr) ;
730- RPC :: start_notifier ( notification, senders. clone ( ) , acceptor. sender ( ) ) ;
776+ let senders =
777+ Arc :: new ( Mutex :: new ( Vec :: < crossbeam_channel:: Sender < Message > > :: new ( ) ) ) ;
778+ let killers = Arc :: new ( Mutex :: new ( Vec :: < Sender < ( ) > > :: new ( ) ) ) ;
779+
780+ let acceptor_shutdown = Channel :: unbounded ( ) ;
781+ let acceptor_shutdown_sender = acceptor_shutdown. sender ( ) ;
782+ let acceptor = RPC :: start_acceptor ( rpc_addr, acceptor_shutdown) ;
783+ RPC :: start_notifier (
784+ notification,
785+ senders. clone ( ) ,
786+ acceptor. sender ( ) ,
787+ acceptor_shutdown_sender,
788+ ) ;
731789
732790 let mut threads = HashMap :: new ( ) ;
733791 let ( garbage_sender, garbage_receiver) = crossbeam_channel:: unbounded ( ) ;
@@ -738,6 +796,11 @@ impl RPC {
738796 let senders = Arc :: clone ( & senders) ;
739797 let stats = Arc :: clone ( & stats) ;
740798 let garbage_sender = garbage_sender. clone ( ) ;
799+
800+ // Kill the peers properly
801+ let ( killer, peace_receiver) = std:: sync:: mpsc:: channel ( ) ;
802+ killers. lock ( ) . unwrap ( ) . push ( killer) ;
803+
741804 #[ cfg( feature = "electrum-discovery" ) ]
742805 let discovery = discovery. clone ( ) ;
743806
@@ -749,6 +812,7 @@ impl RPC {
749812 addr,
750813 stats,
751814 txs_limit,
815+ peace_receiver,
752816 #[ cfg( feature = "electrum-discovery" ) ]
753817 discovery,
754818 ) ;
@@ -769,10 +833,16 @@ impl RPC {
769833 }
770834 }
771835 }
836+ // Drop these
837+ drop ( acceptor) ;
838+ drop ( garbage_receiver) ;
772839
773840 trace ! ( "closing {} RPC connections" , senders. lock( ) . unwrap( ) . len( ) ) ;
841+ for killer in killers. lock ( ) . unwrap ( ) . iter ( ) {
842+ let _ = killer. send ( ( ) ) ;
843+ }
774844 for sender in senders. lock ( ) . unwrap ( ) . iter ( ) {
775- let _ = sender. send ( Message :: Done ) ;
845+ let _ = sender. try_send ( Message :: Done ) ;
776846 }
777847
778848 for ( id, thread) in threads {
@@ -800,5 +870,8 @@ impl Drop for RPC {
800870 handle. join ( ) . unwrap ( ) ;
801871 }
802872 trace ! ( "RPC server is stopped" ) ;
873+ crate :: util:: with_spawned_threads ( |threads| {
874+ trace ! ( "Threads after dropping RPC: {:?}" , threads) ;
875+ } ) ;
803876 }
804877}
0 commit comments