1- use crate :: {
2- hooks:: { Hook , Hooks } ,
3- recorder:: get_or_init_prometheus,
4- version:: VersionInfo ,
5- } ;
1+ use crate :: { hooks:: Hooks , recorder:: get_or_init_prometheus, version:: VersionInfo } ;
2+ use axum:: { extract:: State , response:: IntoResponse , routing:: get, Router } ;
63use eyre:: WrapErr ;
74use metrics_process:: Collector ;
8- use std:: { net:: SocketAddr , sync :: Arc } ;
5+ use std:: net:: SocketAddr ;
96use tokio:: {
10- io:: AsyncWriteExt ,
11- select,
12- signal:: {
13- ctrl_c,
14- unix:: { signal, SignalKind } ,
15- } ,
167 spawn,
17- sync:: oneshot:: { self , Sender } ,
8+ sync:: {
9+ broadcast,
10+ oneshot:: { self , Sender } ,
11+ } ,
1812} ;
1913use tracing:: { debug, error, info, warn} ;
2014
@@ -31,17 +25,30 @@ pub struct MetricServerConfig {
3125 ready_signal : Option < Sender < ( ) > > ,
3226}
3327
28+ impl Clone for MetricServerConfig {
29+ fn clone ( & self ) -> Self {
30+ Self {
31+ listen_addr : self . listen_addr ,
32+ version_info : self . version_info . clone ( ) ,
33+ hooks : self . hooks . clone ( ) ,
34+ service_name : self . service_name . clone ( ) ,
35+ ready_signal : None ,
36+ }
37+ }
38+ }
39+
3440/// The metrics server must be initialized and running BEFORE any metrics are recorded.
3541/// This ensures proper registration with the Prometheus recorder.
3642///
3743/// Example usage:
3844/// ```ignore
3945/// // 1. Create and spawn metrics server first
4046/// let (ready_tx, ready_rx) = oneshot::channel();
47+ /// let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
4148/// let config = MetricServerConfig::new(addr, version_info, "my-service".to_string())
4249/// .with_ready_signal(ready_tx);
4350/// let server = MetricServer::new(config);
44- /// let handle = tokio::spawn(async move { server.serve().await });
51+ /// let handle = tokio::spawn(async move { server.serve(shutdown_rx ).await });
4552///
4653/// // 2. Wait for server to be ready
4754/// ready_rx.await?;
@@ -68,7 +75,7 @@ impl MetricServerConfig {
6875}
6976
7077/// [`MetricServer`] responsible for serving the metrics endpoint.
71- #[ derive( Debug ) ]
78+ #[ derive( Debug , Clone ) ]
7279pub struct MetricServer {
7380 config : MetricServerConfig ,
7481}
@@ -79,21 +86,16 @@ impl MetricServer {
7986 Self { config }
8087 }
8188
82- /// Spawns the metrics server.
83- pub async fn serve ( self ) -> eyre :: Result < ( ) > {
84- let ( shutdown_tx , shutdown_rx ) = oneshot :: channel ( ) ;
85-
86- // Clone hooks before creating the closure.
87- let hooks = self . config . hooks . clone ( ) ;
89+ /// Spawns the metrics server with an external shutdown signal .
90+ ///
91+ /// This version of serve takes a broadcast receiver that can be used to trigger
92+ /// shutdown from the outside, avoiding race conditions with signal handlers.
93+ pub async fn serve ( self , mut shutdown_signal : broadcast :: Receiver < ( ) > ) -> eyre :: Result < ( ) > {
94+ let ( internal_shutdown_tx , internal_shutdown_rx ) = oneshot :: channel ( ) ;
8895
8996 // Start the endpoint before moving out ready_signal.
9097 let server_handle = self
91- . start_endpoint (
92- self . config . listen_addr ,
93- self . config . service_name . clone ( ) ,
94- Arc :: new ( move || hooks. iter ( ) . for_each ( |hook| hook ( ) ) ) ,
95- shutdown_rx,
96- )
98+ . start_endpoint ( internal_shutdown_rx)
9799 . await
98100 . wrap_err ( "could not start prometheus endpoint" ) ?;
99101
@@ -111,24 +113,14 @@ impl MetricServer {
111113 self . config . version_info . register_version_metrics ( ) ;
112114 describe_io_stats ( ) ;
113115
114- // Handle shutdown signals.
115- spawn ( async move {
116- select ! {
117- _ = ctrl_c( ) => {
118- info!( "ctrl-c received, initiating graceful shutdown..." ) ;
119- }
120- _ = async {
121- signal( SignalKind :: terminate( ) )
122- . expect( "failed to install signal handler" )
123- . recv( )
124- . await ;
125- } => {
126- info!( "SIGTERM received, initiating graceful shutdown..." ) ;
116+ // Listen for the external shutdown signal
117+ tokio:: spawn ( async move {
118+ if shutdown_signal. recv ( ) . await . is_ok ( ) {
119+ info ! ( "received external shutdown signal, initiating graceful shutdown..." ) ;
120+ if internal_shutdown_tx. send ( ( ) ) . is_err ( ) {
121+ warn ! ( "failed to send shutdown signal to metrics server" ) ;
127122 }
128123 }
129- if shutdown_tx. send ( ( ) ) . is_err ( ) {
130- warn ! ( "failed to send shutdown signal to metrics server" ) ;
131- }
132124 } ) ;
133125
134126 // Wait for the server to complete.
@@ -138,67 +130,46 @@ impl MetricServer {
138130 Ok ( ( ) )
139131 }
140132
141- async fn start_endpoint < F : Hook + ' static > (
133+ async fn start_endpoint (
142134 & self ,
143- listen_addr : SocketAddr ,
144- service_name : String ,
145- hook : Arc < F > ,
146- mut shutdown_rx : oneshot:: Receiver < ( ) > ,
135+ shutdown_rx : oneshot:: Receiver < ( ) > ,
147136 ) -> eyre:: Result < tokio:: task:: JoinHandle < ( ) > > {
148- let listener = tokio:: net:: TcpListener :: bind ( listen_addr)
149- . await
150- . wrap_err ( "could not bind to address" ) ?;
151-
152137 // Initialize the prometheus recorder.
153- get_or_init_prometheus ( & service_name) ;
138+ get_or_init_prometheus ( & self . config . service_name ) ;
139+
140+ let app = Router :: new ( )
141+ . route ( "/" , get ( Self :: metrics_handler) )
142+ . route ( "/metrics" , get ( Self :: metrics_handler) )
143+ . with_state ( self . clone ( ) ) ;
144+
145+ let listen_addr = self . config . listen_addr ;
154146 info ! ( "metrics server listening on {}" , listen_addr) ;
155147
156148 // Spawn a task to accept connections.
157149 Ok ( spawn ( async move {
158- loop {
159- select ! {
160- // Handle shutdown signal.
161- _ = & mut shutdown_rx => {
150+ // Use axum's built-in server functionality with simplified shutdown
151+ if let Err ( err) =
152+ axum:: serve ( tokio:: net:: TcpListener :: bind ( listen_addr) . await . unwrap ( ) , app)
153+ . with_graceful_shutdown ( async {
154+ let _ = shutdown_rx. await ;
162155 info ! ( "shutdown signal received for metrics server" ) ;
163- break ;
164- }
165- // Accept incoming connections.
166- accept_result = listener. accept( ) => {
167- match accept_result {
168- Ok ( ( mut stream, _remote_addr) ) => {
169- let hook = hook. clone( ) ;
170- let service_name = service_name. clone( ) ;
171-
172- // Spawn a new task to handle the connection.
173- spawn( async move {
174- ( hook) ( ) ;
175- let handle = get_or_init_prometheus( & service_name) ;
176- let metrics = handle. render( ) ;
177-
178- let response = format!(
179- "HTTP/1.1 200 OK\r \n Content-Type: text/plain; charset=utf-8\r \n Content-Length: {}\r \n \r \n {}" ,
180- metrics. len( ) ,
181- metrics
182- ) ;
183-
184- if let Err ( err) = stream. write_all( response. as_bytes( ) ) . await {
185- error!( %err, "failed to write response" ) ;
186- }
187- if let Err ( err) = stream. flush( ) . await {
188- error!( %err, "failed to flush response" ) ;
189- }
190- } ) ;
191- }
192- Err ( err) => {
193- error!( %err, "failed to accept connection" ) ;
194- continue ;
195- }
196- }
197- }
198- }
156+ } )
157+ . await
158+ {
159+ error ! ( %err, "metrics server error" ) ;
199160 }
200161 } ) )
201162 }
163+
164+ /// Handler for the metrics endpoint.
165+ async fn metrics_handler ( State ( server) : State < Self > ) -> impl IntoResponse {
166+ // Execute all hooks
167+ server. config . hooks . iter ( ) . for_each ( |hook| hook ( ) ) ;
168+
169+ // Get metrics from prometheus
170+ let handle = get_or_init_prometheus ( & server. config . service_name ) ;
171+ handle. render ( )
172+ }
202173}
203174
204175#[ cfg( target_os = "linux" ) ]
@@ -248,9 +219,12 @@ mod tests {
248219 let listen_addr = get_random_available_addr ( ) ;
249220 let config = MetricServerConfig :: new ( listen_addr, version_info, "test" . to_string ( ) ) ;
250221
222+ // Create a shutdown channel for the server
223+ let ( _shutdown_tx, shutdown_rx) = broadcast:: channel ( 1 ) ;
224+
251225 // Start server in separate task
252226 let server = MetricServer :: new ( config) ;
253- let server_handle = tokio:: spawn ( async move { server. serve ( ) . await } ) ;
227+ let server_handle = tokio:: spawn ( async move { server. serve ( shutdown_rx ) . await } ) ;
254228
255229 // Give the server a moment to start
256230 tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) . await ;
0 commit comments