@@ -584,7 +584,7 @@ fn timeout_param_to_duration(timeout_ms: u64) -> Option<Duration> {
584584 }
585585}
586586
587- async fn make_burst_call ( opts : & Opts ) -> Result {
587+ async fn make_burst_call ( opts : & Opts , shutdown_receiver : Receiver < ( ) > ) -> Result {
588588 if opts. method . is_none ( ) {
589589 return Err ( "--method parameter missing" . into ( ) ) ;
590590 }
@@ -610,16 +610,41 @@ async fn make_burst_call(opts: &Opts) -> Result {
610610 taskno : i32 ,
611611 count : i32 ,
612612 timeout : Option < Duration > ,
613- user_agent : String
613+ user_agent : String ,
614+ shutdown_receiver : Receiver < ( ) > ,
614615 ) {
615616 println ! ( "Starting burst task #{taskno}, {count} calls of {path}:{method}" ) ;
616- let ( mut frame_reader, mut frame_writer) = login ( & url, user_agent) . await . unwrap ( ) ;
617+ let ( mut frame_reader, mut frame_writer) = match login ( & url, user_agent) . await {
618+ Ok ( conn) => conn,
619+ Err ( err) => {
620+ error ! ( target: "Burst" , "Burst task #{taskno} failed to login: {err}" ) ;
621+ return ;
622+ }
623+ } ;
617624 for _ in 0 ..count {
625+ if shutdown_receiver. try_recv ( ) . is_ok ( ) {
626+ info ! ( target: "Burst" , "Shutdown requested, stopping burst task #{taskno}" ) ;
627+ return ;
628+ }
618629 let rqid = frame_writer
619630 . send_request ( & path, & method, param. clone ( ) )
620631 . await
621632 . unwrap ( ) ;
622- receive_response ( & mut frame_reader, rqid, timeout) . await . unwrap ( ) ;
633+ let response_fut = receive_response ( & mut frame_reader, rqid, timeout) . fuse ( ) ;
634+ let shutdown_fut = shutdown_receiver. recv ( ) . fuse ( ) ;
635+ futures:: pin_mut!( response_fut, shutdown_fut) ;
636+ match select ( response_fut, shutdown_fut) . await {
637+ futures:: future:: Either :: Left ( ( response, _) ) => {
638+ if let Err ( err) = response {
639+ error ! ( target: "Burst" , "Burst task #{taskno} request failed: {err}" ) ;
640+ return ;
641+ }
642+ }
643+ futures:: future:: Either :: Right ( ( _, _) ) => {
644+ info ! ( target: "Burst" , "Shutdown requested, stopping burst task #{taskno}" ) ;
645+ return ;
646+ }
647+ }
623648 }
624649 println ! ( "Burst task #{taskno} finished, after {count} calls made successfully." ) ;
625650 }
@@ -635,7 +660,8 @@ async fn make_burst_call(opts: &Opts) -> Result {
635660 taskno + 1 ,
636661 nmsg,
637662 timeout_param_to_duration ( opts. timeout ) ,
638- opts. extract_user_agent ( )
663+ opts. extract_user_agent ( ) ,
664+ shutdown_receiver. clone ( ) ,
639665 ) )
640666 } )
641667 . collect :: < FuturesUnordered < _ > > ( )
@@ -660,14 +686,48 @@ fn split_quoted(s: &str) -> Vec<&str> {
660686}
661687#[ derive( Debug ) ]
662688struct Tunnel {
689+ tunid : Option < u64 > ,
663690 create_rqid : RqId ,
664691 write_rqid : RqId ,
692+ close_rqid : Option < RqId > ,
665693 frame_sender : Sender < RpcFrame > ,
666694}
695+
696+ enum TunnelEvent {
697+ ClientConnectionClosed ( u64 ) ,
698+ }
699+
700+ async fn send_tunnel_close ( tunnel_path : & str , tunid : u64 , write_frame_sender : & mut Sender < RpcFrame > ) -> Option < RqId > {
701+ let mut rq = RpcMessage :: new_request ( format ! ( "{tunnel_path}/{tunid}" ) , "close" ) ;
702+ let rqid = RpcMessage :: next_request_id ( ) ;
703+ rq. set_request_id ( rqid) ;
704+ match rq. to_frame ( ) {
705+ Ok ( frame) => {
706+ if let Err ( err) = write_frame_sender. send ( frame) . await {
707+ error ! ( target: "Tunnel" , "Failed to send tunnel close request: {err}" ) ;
708+ None
709+ } else {
710+ Some ( rqid)
711+ }
712+ }
713+ Err ( err) => {
714+ error ! ( target: "Tunnel" , "Failed to build tunnel close request: {err}" ) ;
715+ None
716+ }
717+ }
718+ }
719+
720+ fn remove_pre_tunid_closed_tunnels ( tunnels : & mut Vec < Tunnel > ) {
721+ tunnels. extract_if ( .., |tunnel| tunnel. tunid . is_none ( ) && tunnel. frame_sender . is_closed ( ) ) . for_each ( |tunnel| {
722+ debug ! ( target: "Tunnel" , "Removing tunnel task that finished before tunid assignment: {:?}" , tunnel) ;
723+ } ) ;
724+ }
725+
667726async fn start_tunnel_server (
668727 mut broker_frame_reader : BoxedFrameReader ,
669728 mut broker_frame_writer : BoxedFrameWriter ,
670729 opts : & Opts ,
730+ shutdown_receiver : Receiver < ( ) > ,
671731) -> Result {
672732 if opts. tunnel_path . is_none ( ) {
673733 warn ! ( "Using default .app/tunnel endpoint. This is usually not what you want. Set tunnel path to the broker you want to create a tunnel to." ) ;
@@ -701,30 +761,48 @@ async fn start_tunnel_server(
701761 let local_host = local_host. to_owned ( ) ;
702762
703763 let mut tunnels: Vec < Tunnel > = Vec :: new ( ) ;
764+ let mut tunnel_tasks: Vec < smol:: Task < ( ) > > = Vec :: new ( ) ;
765+ let mut shutting_down = false ;
704766
705767 debug ! ( target: "Tunnel" , "Starting TCP server on {local_host}:{local_port}" ) ;
706768 let listener = TcpListener :: bind ( format ! ( "{local_host}:{local_port}" ) ) . await ?;
707769 let mut incoming = listener. incoming ( ) ;
708770
709- let ( write_frame_sender, write_frame_receiver) = async_channel:: unbounded ( ) ;
771+ let ( tunnel_event_sender, tunnel_event_receiver) = async_channel:: unbounded ( ) ;
772+ let ( mut write_frame_sender, write_frame_receiver) = async_channel:: unbounded ( ) ;
773+ let ( shutdown_timeout_sender, shutdown_timeout_receiver) = async_channel:: bounded :: < ( ) > ( 1 ) ;
710774 loop {
775+ remove_pre_tunid_closed_tunnels ( & mut tunnels) ;
776+ tunnel_tasks. retain ( |task| !task. is_finished ( ) ) ;
777+ if shutting_down && tunnels. iter ( ) . all ( |tunnel| tunnel. close_rqid . is_none ( ) ) {
778+ info ! ( target: "Tunnel" , "All tunnel close responses received, finishing tunnel loop" ) ;
779+ break ;
780+ }
711781 select ! {
712782 stream = incoming. next( ) . fuse( ) => {
713783 if let Some ( stream) = stream {
784+ if shutting_down {
785+ drop( stream?) ;
786+ continue ;
787+ }
714788 let stream = stream?;
715789 debug!( target: "Tunnel" , "New connection from {:?}" , stream. local_addr( ) ) ;
716790 let create_rqid = RpcMessage :: next_request_id( ) ;
717791 let write_rqid = RpcMessage :: next_request_id( ) ;
718792 let ( read_frame_sender, read_frame_receiver) = async_channel:: unbounded( ) ;
719- let tunnel = Tunnel { create_rqid, write_rqid, frame_sender: read_frame_sender} ;
793+ let tunnel = Tunnel { tunid : None , create_rqid, write_rqid, close_rqid : None , frame_sender: read_frame_sender} ;
720794 tunnels. push( tunnel) ;
721795 let read_frame_receiver = read_frame_receiver. clone( ) ;
722796 let write_frame_sender = write_frame_sender. clone( ) ;
723797 let remote_host_port = remote_host_port. clone( ) ;
724798 let tunnel_path = tunnel_path. clone( ) ;
725- spawn_and_log_error( async move {
726- handle_tunnel_socket( stream, remote_host_port, tunnel_path, create_rqid, write_rqid, read_frame_receiver, write_frame_sender. clone( ) ) . await . map_err( |e | e. to_string( ) )
799+ let tunnel_event_sender = tunnel_event_sender. clone( ) ;
800+ let task = smol:: spawn( async move {
801+ if let Err ( err) = handle_tunnel_socket( stream, remote_host_port, tunnel_path, create_rqid, write_rqid, read_frame_receiver, write_frame_sender. clone( ) , tunnel_event_sender) . await {
802+ error!( target: "Tunnel" , "Tunnel task finished with error: {err}" ) ;
803+ }
727804 } ) ;
805+ tunnel_tasks. push( task) ;
728806 } else {
729807 break ;
730808 }
@@ -733,22 +811,88 @@ async fn start_tunnel_server(
733811 match frame {
734812 Ok ( frame) => {
735813 let rqid = frame. request_id( ) . unwrap_or( 0 ) ;
736- for tunnel in & tunnels {
814+ if tunnels. iter( ) . any( |tunnel| tunnel. close_rqid == Some ( rqid) ) {
815+ tunnels. extract_if( .., |tunnel| tunnel. close_rqid == Some ( rqid) ) . for_each( |tunnel| {
816+ debug!( target: "Tunnel" , "Tunnel close ACK received, removing tunnel {:?}" , tunnel) ;
817+ } ) ;
818+ if shutting_down && tunnels. iter( ) . all( |tunnel| tunnel. close_rqid. is_none( ) ) {
819+ info!( target: "Tunnel" , "All tunnel close responses received, finishing tunnel loop" ) ;
820+ break ;
821+ }
822+ continue ;
823+ }
824+ for tunnel in & mut tunnels {
737825 if tunnel. write_rqid == rqid || tunnel. create_rqid == rqid {
826+ if tunnel. create_rqid == rqid && tunnel. tunid. is_none( ) {
827+ if let Ok ( rpcmsg) = frame. to_rpcmesage( )
828+ && let Ok ( shvrpc:: rpcmessage:: Response :: Success ( val) ) = rpcmsg. response( )
829+ && let Ok ( tunid) = val. as_str( ) . parse:: <u64 >( )
830+ {
831+ tunnel. tunid = Some ( tunid) ;
832+ }
833+ }
738834 tunnel. frame_sender. send( frame) . await ?;
739835 break ;
740836 }
741837 }
742- tunnels. extract_if( .., |tunnel| tunnel. frame_sender. is_closed( ) ) . for_each( |tunnel| {
743- debug!( target: "Tunnel" , "Removing closed tunnel {:?}" , tunnel) ;
744- } ) ;
745838 }
746839 Err ( e) => {
747840 error!( "Get response receiver error: {e}" ) ;
748841 break ;
749842 }
750843 }
751844 }
845+ event = tunnel_event_receiver. recv( ) . fuse( ) => {
846+ match event {
847+ Ok ( TunnelEvent :: ClientConnectionClosed ( tunid) ) => {
848+ if let Some ( tunnel) = tunnels. iter_mut( ) . find( |tunnel| tunnel. tunid == Some ( tunid) ) {
849+ if tunnel. close_rqid. is_none( ) {
850+ if let Some ( close_rqid) = send_tunnel_close( & tunnel_path, tunid, & mut write_frame_sender) . await {
851+ tunnel. close_rqid = Some ( close_rqid) ;
852+ } else {
853+ tunnels. extract_if( .., |candidate| candidate. tunid == Some ( tunid) ) . for_each( |removed| {
854+ debug!( target: "Tunnel" , "Tunnel close send failed, removing tunnel {:?}" , removed) ;
855+ } ) ;
856+ }
857+ }
858+ }
859+ }
860+ Err ( err) => {
861+ debug!( target: "Tunnel" , "Tunnel event receiver closed: {err}" ) ;
862+ }
863+ }
864+ }
865+ shutdown = shutdown_receiver. recv( ) . fuse( ) => {
866+ if shutdown. is_ok( ) && !shutting_down {
867+ info!( target: "Tunnel" , "Received shutdown signal, shutting down TCP tunnel server" ) ;
868+ shutting_down = true ;
869+ let shutdown_timeout_sender = shutdown_timeout_sender. clone( ) ;
870+ smol:: spawn( async move {
871+ smol:: Timer :: after( std:: time:: Duration :: from_secs( 2 ) ) . await ;
872+ let _ = shutdown_timeout_sender. send( ( ) ) . await ;
873+ } ) . detach( ) ;
874+ if tunnels. is_empty( ) {
875+ info!( target: "Tunnel" , "No active tunnels, finishing tunnel loop" ) ;
876+ break ;
877+ }
878+ for tunnel in & mut tunnels {
879+ if tunnel. close_rqid. is_some( ) {
880+ continue ;
881+ }
882+ if let Some ( tunid) = tunnel. tunid {
883+ if let Some ( rqid) = send_tunnel_close( & tunnel_path, tunid, & mut write_frame_sender) . await {
884+ tunnel. close_rqid = Some ( rqid) ;
885+ }
886+ }
887+ }
888+ }
889+ }
890+ timeout = shutdown_timeout_receiver. recv( ) . fuse( ) => {
891+ if shutting_down && timeout. is_ok( ) {
892+ info!( target: "Tunnel" , "Shutdown deadline reached, finishing tunnel loop" ) ;
893+ break ;
894+ }
895+ }
752896 frame = write_frame_receiver. recv( ) . fuse( ) => {
753897 match frame {
754898 Ok ( frame) => {
@@ -762,10 +906,17 @@ async fn start_tunnel_server(
762906 }
763907 }
764908 }
909+
910+ for task in tunnel_tasks. drain ( ..) {
911+ if !task. is_finished ( ) {
912+ task. cancel ( ) . await ;
913+ }
914+ }
915+
765916 Ok ( ( ) )
766917}
767918
768- async fn handle_tunnel_socket ( stream : TcpStream , remote_host_port : String , tunnel_path : String , create_rqid : RqId , write_rqid : RqId , read_frame_receiver : Receiver < RpcFrame > , mut write_frame_sender : Sender < RpcFrame > ) -> Result {
919+ async fn handle_tunnel_socket ( stream : TcpStream , remote_host_port : String , tunnel_path : String , create_rqid : RqId , write_rqid : RqId , read_frame_receiver : Receiver < RpcFrame > , mut write_frame_sender : Sender < RpcFrame > , tunnel_event_sender : Sender < TunnelEvent > ) -> Result {
769920 let tunid = {
770921 debug ! ( target: "Tunnel" , "Creating tunnel" ) ;
771922 let tun_opts = Map :: from ( [ ( "host" . into ( ) , ( remote_host_port) . into ( ) ) ] ) ;
@@ -799,20 +950,30 @@ async fn handle_tunnel_socket(stream: TcpStream, remote_host_port: String, tunne
799950 rq. set_seqno ( seqno_to_write) ;
800951 seqno_to_write += 1 ;
801952 debug ! ( target: "Tunnel" , "Starting data exchange" ) ;
802- write_frame_sender. send ( rq. to_frame ( ) ?) . await ?;
953+ if let Err ( err) = write_frame_sender. send ( rq. to_frame ( ) ?) . await {
954+ return Err ( err. into ( ) ) ;
955+ }
803956 } ;
957+
804958 let ( mut sock_reader, mut sock_writer) = stream. split ( ) ;
805959 let mut sock_read_buff: [ u8 ; 1024 ] = [ 0 ; 1024 ] ;
806960 loop {
807961 select ! {
808962 n = sock_reader. read( & mut sock_read_buff) . fuse( ) => {
809- let n = n?;
810- if n == 0 {
811- debug!( target: "Tunnel" , "Tunnel client socket closed" ) ;
812- break ;
963+ match n {
964+ Ok ( n) => {
965+ if n == 0 {
966+ debug!( target: "Tunnel" , "Tunnel client socket closed" ) ;
967+ let _ = tunnel_event_sender. send( TunnelEvent :: ClientConnectionClosed ( tunid) ) . await ;
968+ return Ok ( ( ) ) ;
969+ }
970+ let data = & sock_read_buff[ 0 .. n] ;
971+ seqno_to_write = process_socket_to_broker_data( & tunnel_path, tunid, seqno_to_write, write_rqid, data, & mut write_frame_sender) . await ?;
972+ }
973+ Err ( err) => {
974+ return Err ( err. into( ) ) ;
975+ }
813976 }
814- let data = & sock_read_buff[ 0 .. n] ;
815- seqno_to_write = process_socket_to_broker_data( & tunnel_path, tunid, seqno_to_write, write_rqid, data, & mut write_frame_sender) . await ?;
816977 }
817978 frame = read_frame_receiver. recv( ) . fuse( ) => {
818979 match frame {
@@ -866,14 +1027,40 @@ async fn process_socket_to_broker_data(tunnel_path: &str, tunid: u64, seqno_to_w
8661027}
8671028
8681029pub async fn try_main ( opts : Opts ) -> Result {
1030+ let ( shutdown_sender, shutdown_receiver) = async_channel:: bounded :: < ( ) > ( 1 ) ;
1031+ smol:: spawn ( async move {
1032+ match async_signal:: Signals :: new ( & [
1033+ async_signal:: Signal :: Term ,
1034+ async_signal:: Signal :: Int ,
1035+ ] ) {
1036+ Ok ( mut signals) => {
1037+ if signals. next ( ) . await . is_some ( ) {
1038+ let _ = shutdown_sender. send ( ( ) ) . await ;
1039+ }
1040+ }
1041+ Err ( err) => {
1042+ error ! ( "Failed to initialize signal handling: {err}" ) ;
1043+ }
1044+ }
1045+ } ) . detach ( ) ;
1046+
8691047 if opts. burst . is_some ( ) {
870- return make_burst_call ( & opts) . await ;
1048+ return make_burst_call ( & opts, shutdown_receiver ) . await ;
8711049 }
8721050 let ( frame_reader, frame_writer) = login ( & opts. url , opts. extract_user_agent ( ) ) . await ?;
8731051 let res = if opts. tunnel . is_some ( ) {
874- start_tunnel_server ( frame_reader, frame_writer, & opts) . await
1052+ start_tunnel_server ( frame_reader, frame_writer, & opts, shutdown_receiver ) . await
8751053 } else {
876- make_call ( frame_reader, frame_writer, & opts) . await
1054+ let call_fut = make_call ( frame_reader, frame_writer, & opts) . fuse ( ) ;
1055+ let shutdown_fut = shutdown_receiver. recv ( ) . fuse ( ) ;
1056+ futures:: pin_mut!( call_fut, shutdown_fut) ;
1057+ match select ( call_fut, shutdown_fut) . await {
1058+ futures:: future:: Either :: Left ( ( call_res, _) ) => call_res,
1059+ futures:: future:: Either :: Right ( ( _, _) ) => {
1060+ info ! ( "Received shutdown signal, exiting" ) ;
1061+ Ok ( ( ) )
1062+ }
1063+ }
8771064 } ;
8781065 match res {
8791066 Ok ( _) => Ok ( ( ) ) ,
0 commit comments