Skip to content

Commit 9193232

Browse files
author
Martin Špirk
committed
handle SIGTERM
1 parent c547b7b commit 9193232

4 files changed

Lines changed: 214 additions & 24 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
/.idea
3+
/.qtcreator

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ futures-rustls = "0.26.0"
3030
rustls-pemfile = "2.2.0"
3131
rustls-platform-verifier = "0.7.0"
3232
async-channel = "2.3.1"
33+
async-signal = "0.2.13"
3334
serialport = { version = "4.7.1", optional = true }
3435

3536
# For local development

src/lib.rs

Lines changed: 211 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ fn timeout_param_to_duration(timeout_ms: u64) -> Option<Duration> {
584584
}
585585
}
586586

587-
async fn make_burst_call(opts: &Opts) -> Result {
587+
async fn make_burst_call(opts: &Opts, shutdown_receiver: Receiver<()>) -> Result {
588588
if opts.method.is_none() {
589589
return Err("--method parameter missing".into());
590590
}
@@ -610,16 +610,41 @@ async fn make_burst_call(opts: &Opts) -> Result {
610610
taskno: i32,
611611
count: i32,
612612
timeout: Option<Duration>,
613-
user_agent: String
613+
user_agent: String,
614+
shutdown_receiver: Receiver<()>,
614615
) {
615616
println!("Starting burst task #{taskno}, {count} calls of {path}:{method}");
616-
let (mut frame_reader, mut frame_writer) = login(&url, user_agent).await.unwrap();
617+
let (mut frame_reader, mut frame_writer) = match login(&url, user_agent).await {
618+
Ok(conn) => conn,
619+
Err(err) => {
620+
error!(target: "Burst", "Burst task #{taskno} failed to login: {err}");
621+
return;
622+
}
623+
};
617624
for _ in 0..count {
625+
if shutdown_receiver.try_recv().is_ok() {
626+
info!(target: "Burst", "Shutdown requested, stopping burst task #{taskno}");
627+
return;
628+
}
618629
let rqid = frame_writer
619630
.send_request(&path, &method, param.clone())
620631
.await
621632
.unwrap();
622-
receive_response(&mut frame_reader, rqid, timeout).await.unwrap();
633+
let response_fut = receive_response(&mut frame_reader, rqid, timeout).fuse();
634+
let shutdown_fut = shutdown_receiver.recv().fuse();
635+
futures::pin_mut!(response_fut, shutdown_fut);
636+
match select(response_fut, shutdown_fut).await {
637+
futures::future::Either::Left((response, _)) => {
638+
if let Err(err) = response {
639+
error!(target: "Burst", "Burst task #{taskno} request failed: {err}");
640+
return;
641+
}
642+
}
643+
futures::future::Either::Right((_, _)) => {
644+
info!(target: "Burst", "Shutdown requested, stopping burst task #{taskno}");
645+
return;
646+
}
647+
}
623648
}
624649
println!("Burst task #{taskno} finished, after {count} calls made successfully.");
625650
}
@@ -635,7 +660,8 @@ async fn make_burst_call(opts: &Opts) -> Result {
635660
taskno + 1,
636661
nmsg,
637662
timeout_param_to_duration(opts.timeout),
638-
opts.extract_user_agent()
663+
opts.extract_user_agent(),
664+
shutdown_receiver.clone(),
639665
))
640666
})
641667
.collect::<FuturesUnordered<_>>()
@@ -660,14 +686,48 @@ fn split_quoted(s: &str) -> Vec<&str> {
660686
}
661687
#[derive(Debug)]
662688
struct Tunnel {
689+
tunid: Option<u64>,
663690
create_rqid: RqId,
664691
write_rqid: RqId,
692+
close_rqid: Option<RqId>,
665693
frame_sender: Sender<RpcFrame>,
666694
}
695+
696+
enum TunnelEvent {
697+
ClientConnectionClosed(u64),
698+
}
699+
700+
async fn send_tunnel_close(tunnel_path: &str, tunid: u64, write_frame_sender: &mut Sender<RpcFrame>) -> Option<RqId> {
701+
let mut rq = RpcMessage::new_request(format!("{tunnel_path}/{tunid}"), "close");
702+
let rqid = RpcMessage::next_request_id();
703+
rq.set_request_id(rqid);
704+
match rq.to_frame() {
705+
Ok(frame) => {
706+
if let Err(err) = write_frame_sender.send(frame).await {
707+
error!(target: "Tunnel", "Failed to send tunnel close request: {err}");
708+
None
709+
} else {
710+
Some(rqid)
711+
}
712+
}
713+
Err(err) => {
714+
error!(target: "Tunnel", "Failed to build tunnel close request: {err}");
715+
None
716+
}
717+
}
718+
}
719+
720+
fn remove_pre_tunid_closed_tunnels(tunnels: &mut Vec<Tunnel>) {
721+
tunnels.extract_if(.., |tunnel| tunnel.tunid.is_none() && tunnel.frame_sender.is_closed()).for_each(|tunnel| {
722+
debug!(target: "Tunnel", "Removing tunnel task that finished before tunid assignment: {:?}", tunnel);
723+
});
724+
}
725+
667726
async fn start_tunnel_server(
668727
mut broker_frame_reader: BoxedFrameReader,
669728
mut broker_frame_writer: BoxedFrameWriter,
670729
opts: &Opts,
730+
shutdown_receiver: Receiver<()>,
671731
) -> Result {
672732
if opts.tunnel_path.is_none() {
673733
warn!("Using default .app/tunnel endpoint. This is usually not what you want. Set tunnel path to the broker you want to create a tunnel to.");
@@ -701,30 +761,48 @@ async fn start_tunnel_server(
701761
let local_host = local_host.to_owned();
702762

703763
let mut tunnels: Vec<Tunnel> = Vec::new();
764+
let mut tunnel_tasks: Vec<smol::Task<()>> = Vec::new();
765+
let mut shutting_down = false;
704766

705767
debug!(target: "Tunnel", "Starting TCP server on {local_host}:{local_port}");
706768
let listener = TcpListener::bind(format!("{local_host}:{local_port}")).await?;
707769
let mut incoming = listener.incoming();
708770

709-
let (write_frame_sender, write_frame_receiver) = async_channel::unbounded();
771+
let (tunnel_event_sender, tunnel_event_receiver) = async_channel::unbounded();
772+
let (mut write_frame_sender, write_frame_receiver) = async_channel::unbounded();
773+
let (shutdown_timeout_sender, shutdown_timeout_receiver) = async_channel::bounded::<()>(1);
710774
loop {
775+
remove_pre_tunid_closed_tunnels(&mut tunnels);
776+
tunnel_tasks.retain(|task| !task.is_finished());
777+
if shutting_down && tunnels.iter().all(|tunnel| tunnel.close_rqid.is_none()) {
778+
info!(target: "Tunnel", "All tunnel close responses received, finishing tunnel loop");
779+
break;
780+
}
711781
select! {
712782
stream = incoming.next().fuse() => {
713783
if let Some(stream) = stream {
784+
if shutting_down {
785+
drop(stream?);
786+
continue;
787+
}
714788
let stream = stream?;
715789
debug!(target: "Tunnel", "New connection from {:?}", stream.local_addr());
716790
let create_rqid = RpcMessage::next_request_id();
717791
let write_rqid = RpcMessage::next_request_id();
718792
let (read_frame_sender, read_frame_receiver) = async_channel::unbounded();
719-
let tunnel = Tunnel {create_rqid, write_rqid, frame_sender: read_frame_sender};
793+
let tunnel = Tunnel {tunid: None, create_rqid, write_rqid, close_rqid: None, frame_sender: read_frame_sender};
720794
tunnels.push(tunnel);
721795
let read_frame_receiver = read_frame_receiver.clone();
722796
let write_frame_sender = write_frame_sender.clone();
723797
let remote_host_port = remote_host_port.clone();
724798
let tunnel_path = tunnel_path.clone();
725-
spawn_and_log_error(async move {
726-
handle_tunnel_socket(stream, remote_host_port, tunnel_path, create_rqid, write_rqid, read_frame_receiver, write_frame_sender.clone()).await.map_err(|e | e.to_string())
799+
let tunnel_event_sender = tunnel_event_sender.clone();
800+
let task = smol::spawn(async move {
801+
if let Err(err) = handle_tunnel_socket(stream, remote_host_port, tunnel_path, create_rqid, write_rqid, read_frame_receiver, write_frame_sender.clone(), tunnel_event_sender).await {
802+
error!(target: "Tunnel", "Tunnel task finished with error: {err}");
803+
}
727804
});
805+
tunnel_tasks.push(task);
728806
} else {
729807
break;
730808
}
@@ -733,22 +811,88 @@ async fn start_tunnel_server(
733811
match frame {
734812
Ok(frame) => {
735813
let rqid = frame.request_id().unwrap_or(0);
736-
for tunnel in &tunnels {
814+
if tunnels.iter().any(|tunnel| tunnel.close_rqid == Some(rqid)) {
815+
tunnels.extract_if(.., |tunnel| tunnel.close_rqid == Some(rqid)).for_each(|tunnel| {
816+
debug!(target: "Tunnel", "Tunnel close ACK received, removing tunnel {:?}", tunnel);
817+
});
818+
if shutting_down && tunnels.iter().all(|tunnel| tunnel.close_rqid.is_none()) {
819+
info!(target: "Tunnel", "All tunnel close responses received, finishing tunnel loop");
820+
break;
821+
}
822+
continue;
823+
}
824+
for tunnel in &mut tunnels {
737825
if tunnel.write_rqid == rqid || tunnel.create_rqid == rqid {
826+
if tunnel.create_rqid == rqid && tunnel.tunid.is_none() {
827+
if let Ok(rpcmsg) = frame.to_rpcmesage()
828+
&& let Ok(shvrpc::rpcmessage::Response::Success(val)) = rpcmsg.response()
829+
&& let Ok(tunid) = val.as_str().parse::<u64>()
830+
{
831+
tunnel.tunid = Some(tunid);
832+
}
833+
}
738834
tunnel.frame_sender.send(frame).await?;
739835
break;
740836
}
741837
}
742-
tunnels.extract_if(.., |tunnel| tunnel.frame_sender.is_closed()).for_each(|tunnel| {
743-
debug!(target: "Tunnel", "Removing closed tunnel {:?}", tunnel);
744-
});
745838
}
746839
Err(e) => {
747840
error!("Get response receiver error: {e}");
748841
break;
749842
}
750843
}
751844
}
845+
event = tunnel_event_receiver.recv().fuse() => {
846+
match event {
847+
Ok(TunnelEvent::ClientConnectionClosed(tunid)) => {
848+
if let Some(tunnel) = tunnels.iter_mut().find(|tunnel| tunnel.tunid == Some(tunid)) {
849+
if tunnel.close_rqid.is_none() {
850+
if let Some(close_rqid) = send_tunnel_close(&tunnel_path, tunid, &mut write_frame_sender).await {
851+
tunnel.close_rqid = Some(close_rqid);
852+
} else {
853+
tunnels.extract_if(.., |candidate| candidate.tunid == Some(tunid)).for_each(|removed| {
854+
debug!(target: "Tunnel", "Tunnel close send failed, removing tunnel {:?}", removed);
855+
});
856+
}
857+
}
858+
}
859+
}
860+
Err(err) => {
861+
debug!(target: "Tunnel", "Tunnel event receiver closed: {err}");
862+
}
863+
}
864+
}
865+
shutdown = shutdown_receiver.recv().fuse() => {
866+
if shutdown.is_ok() && !shutting_down {
867+
info!(target: "Tunnel", "Received shutdown signal, shutting down TCP tunnel server");
868+
shutting_down = true;
869+
let shutdown_timeout_sender = shutdown_timeout_sender.clone();
870+
smol::spawn(async move {
871+
smol::Timer::after(std::time::Duration::from_secs(2)).await;
872+
let _ = shutdown_timeout_sender.send(()).await;
873+
}).detach();
874+
if tunnels.is_empty() {
875+
info!(target: "Tunnel", "No active tunnels, finishing tunnel loop");
876+
break;
877+
}
878+
for tunnel in &mut tunnels {
879+
if tunnel.close_rqid.is_some() {
880+
continue;
881+
}
882+
if let Some(tunid) = tunnel.tunid {
883+
if let Some(rqid) = send_tunnel_close(&tunnel_path, tunid, &mut write_frame_sender).await {
884+
tunnel.close_rqid = Some(rqid);
885+
}
886+
}
887+
}
888+
}
889+
}
890+
timeout = shutdown_timeout_receiver.recv().fuse() => {
891+
if shutting_down && timeout.is_ok() {
892+
info!(target: "Tunnel", "Shutdown deadline reached, finishing tunnel loop");
893+
break;
894+
}
895+
}
752896
frame = write_frame_receiver.recv().fuse() => {
753897
match frame {
754898
Ok(frame) => {
@@ -762,10 +906,17 @@ async fn start_tunnel_server(
762906
}
763907
}
764908
}
909+
910+
for task in tunnel_tasks.drain(..) {
911+
if !task.is_finished() {
912+
task.cancel().await;
913+
}
914+
}
915+
765916
Ok(())
766917
}
767918

768-
async fn handle_tunnel_socket(stream: TcpStream, remote_host_port: String, tunnel_path: String, create_rqid: RqId, write_rqid: RqId, read_frame_receiver: Receiver<RpcFrame>, mut write_frame_sender: Sender<RpcFrame>) -> Result {
919+
async fn handle_tunnel_socket(stream: TcpStream, remote_host_port: String, tunnel_path: String, create_rqid: RqId, write_rqid: RqId, read_frame_receiver: Receiver<RpcFrame>, mut write_frame_sender: Sender<RpcFrame>, tunnel_event_sender: Sender<TunnelEvent>) -> Result {
769920
let tunid = {
770921
debug!(target: "Tunnel", "Creating tunnel");
771922
let tun_opts = Map::from([("host".into(), (remote_host_port).into())]);
@@ -799,20 +950,30 @@ async fn handle_tunnel_socket(stream: TcpStream, remote_host_port: String, tunne
799950
rq.set_seqno(seqno_to_write);
800951
seqno_to_write += 1;
801952
debug!(target: "Tunnel", "Starting data exchange");
802-
write_frame_sender.send(rq.to_frame()?).await?;
953+
if let Err(err) = write_frame_sender.send(rq.to_frame()?).await {
954+
return Err(err.into());
955+
}
803956
};
957+
804958
let (mut sock_reader, mut sock_writer) = stream.split();
805959
let mut sock_read_buff: [u8; 1024] = [0; 1024];
806960
loop {
807961
select! {
808962
n = sock_reader.read(&mut sock_read_buff).fuse() => {
809-
let n = n?;
810-
if n == 0 {
811-
debug!(target: "Tunnel", "Tunnel client socket closed");
812-
break;
963+
match n {
964+
Ok(n) => {
965+
if n == 0 {
966+
debug!(target: "Tunnel", "Tunnel client socket closed");
967+
let _ = tunnel_event_sender.send(TunnelEvent::ClientConnectionClosed(tunid)).await;
968+
return Ok(());
969+
}
970+
let data = &sock_read_buff[0 .. n];
971+
seqno_to_write = process_socket_to_broker_data(&tunnel_path, tunid, seqno_to_write, write_rqid, data, &mut write_frame_sender).await?;
972+
}
973+
Err(err) => {
974+
return Err(err.into());
975+
}
813976
}
814-
let data = &sock_read_buff[0 .. n];
815-
seqno_to_write = process_socket_to_broker_data(&tunnel_path, tunid, seqno_to_write, write_rqid, data, &mut write_frame_sender).await?;
816977
}
817978
frame = read_frame_receiver.recv().fuse() => {
818979
match frame {
@@ -866,14 +1027,40 @@ async fn process_socket_to_broker_data(tunnel_path: &str, tunid: u64, seqno_to_w
8661027
}
8671028

8681029
pub async fn try_main(opts: Opts) -> Result {
1030+
let (shutdown_sender, shutdown_receiver) = async_channel::bounded::<()>(1);
1031+
smol::spawn(async move {
1032+
match async_signal::Signals::new(&[
1033+
async_signal::Signal::Term,
1034+
async_signal::Signal::Int,
1035+
]) {
1036+
Ok(mut signals) => {
1037+
if signals.next().await.is_some() {
1038+
let _ = shutdown_sender.send(()).await;
1039+
}
1040+
}
1041+
Err(err) => {
1042+
error!("Failed to initialize signal handling: {err}");
1043+
}
1044+
}
1045+
}).detach();
1046+
8691047
if opts.burst.is_some() {
870-
return make_burst_call(&opts).await;
1048+
return make_burst_call(&opts, shutdown_receiver).await;
8711049
}
8721050
let (frame_reader, frame_writer) = login(&opts.url, opts.extract_user_agent()).await?;
8731051
let res = if opts.tunnel.is_some() {
874-
start_tunnel_server(frame_reader, frame_writer, &opts).await
1052+
start_tunnel_server(frame_reader, frame_writer, &opts, shutdown_receiver).await
8751053
} else {
876-
make_call(frame_reader, frame_writer, &opts).await
1054+
let call_fut = make_call(frame_reader, frame_writer, &opts).fuse();
1055+
let shutdown_fut = shutdown_receiver.recv().fuse();
1056+
futures::pin_mut!(call_fut, shutdown_fut);
1057+
match select(call_fut, shutdown_fut).await {
1058+
futures::future::Either::Left((call_res, _)) => call_res,
1059+
futures::future::Either::Right((_, _)) => {
1060+
info!("Received shutdown signal, exiting");
1061+
Ok(())
1062+
}
1063+
}
8771064
};
8781065
match res {
8791066
Ok(_) => Ok(()),

0 commit comments

Comments
 (0)