@@ -13,9 +13,11 @@ mod sniffer;
1313use std:: net:: IpAddr ;
1414use std:: sync:: atomic:: AtomicBool ;
1515use std:: sync:: Arc ;
16+ use std:: time:: Duration ;
1617
1718use tracing:: { error, info} ;
1819use tracing_subscriber:: EnvFilter ;
20+ use tokio_util:: sync:: CancellationToken ;
1921
2022fn main ( ) {
2123 tracing_subscriber:: fmt ( )
@@ -91,6 +93,7 @@ fn main() {
9193 let ( cmd_tx, cmd_rx) = std:: sync:: mpsc:: channel :: < proto:: SnifferCommand > ( ) ;
9294
9395 let stop = Arc :: new ( AtomicBool :: new ( false ) ) ;
96+ let token = CancellationToken :: new ( ) ;
9497
9598 let sniffer_stop = stop. clone ( ) ;
9699 let sniffer_local_ips = local_ips. clone ( ) ;
@@ -103,24 +106,35 @@ fn main() {
103106 . expect ( "failed to spawn sniffer thread" ) ;
104107
105108 let rt = tokio:: runtime:: Runtime :: new ( ) . expect ( "failed to create tokio runtime" ) ;
109+ let graceful_shutdown_sec = cfg. graceful_shutdown_sec ;
106110 rt. block_on ( async {
107- let signal_stop = stop. clone ( ) ;
108- tokio:: spawn ( async move {
109- shutdown:: wait_for_signal ( signal_stop) . await ;
110- tokio:: time:: sleep ( std:: time:: Duration :: from_secs ( 1 ) ) . await ;
111- std:: process:: exit ( 0 ) ;
112- } ) ;
113-
114111 let mut handles = Vec :: new ( ) ;
115112 for lc in cfg. listeners {
116113 let tx = cmd_tx. clone ( ) ;
117114 let lip = resolve_local_ip ( lc. connect . ip ( ) ) . unwrap_or ( local_ips[ 0 ] ) ;
118- handles. push ( tokio:: spawn ( listener:: run_listener ( lc, lip, tx, cfg. idle_timeout , cfg. buffer_size ) ) ) ;
115+ let tk = token. clone ( ) ;
116+ handles. push ( tokio:: spawn ( listener:: run_listener ( lc, lip, tx, cfg. idle_timeout , cfg. buffer_size , tk) ) ) ;
119117 }
120118
121- for h in handles {
122- let _ = h. await ;
119+ shutdown:: wait_for_signal ( stop, token) . await ;
120+
121+ if graceful_shutdown_sec == 0 {
122+ info ! ( "graceful_shutdown_sec=0, exiting immediately" ) ;
123+ } else {
124+ info ! ( "waiting up to {}s for active connections to drain" , graceful_shutdown_sec) ;
125+
126+ let drain_all = async {
127+ for h in handles {
128+ let _ = h. await ;
129+ }
130+ } ;
131+
132+ if tokio:: time:: timeout ( Duration :: from_secs ( graceful_shutdown_sec) , drain_all) . await . is_err ( ) {
133+ info ! ( "drain timeout ({}s), forcing exit" , graceful_shutdown_sec) ;
134+ }
123135 }
136+
137+ info ! ( "shutdown complete" ) ;
124138 } ) ;
125139}
126140
0 commit comments