@@ -7,26 +7,26 @@ use std::{
77 os:: unix:: fs:: PermissionsExt ,
88 path:: PathBuf ,
99 sync:: {
10- atomic:: { AtomicU32 , AtomicU64 , Ordering } ,
1110 Arc ,
11+ atomic:: { AtomicU32 , AtomicU64 , Ordering } ,
1212 } ,
1313 time:: { Duration , Instant } ,
1414} ;
1515
16- use anyhow:: { anyhow , Context } ;
16+ use anyhow:: { Context , anyhow } ;
1717use async_trait:: async_trait;
18- use axum:: { extract:: Request , Router } ;
18+ use axum:: { Router , extract:: Request } ;
1919use http:: Response ;
2020use hyper:: body:: Incoming ;
2121use hyper_util:: {
2222 rt:: { TokioExecutor , TokioIo , TokioTimer } ,
2323 server:: conn:: auto:: Builder ,
2424} ;
2525use prometheus:: {
26- register_histogram_vec_with_registry , register_int_counter_vec_with_registry ,
27- register_int_gauge_vec_with_registry , HistogramVec , IntCounterVec , IntGaugeVec , Registry ,
26+ HistogramVec , IntCounterVec , IntGaugeVec , Registry , register_histogram_vec_with_registry ,
27+ register_int_counter_vec_with_registry , register_int_gauge_vec_with_registry ,
2828} ;
29- use rustls:: { server :: ServerConnection , CipherSuite , ProtocolVersion } ;
29+ use rustls:: { CipherSuite , ProtocolVersion , server :: ServerConnection } ;
3030use tokio:: {
3131 io:: { AsyncRead , AsyncWrite , AsyncWriteExt } ,
3232 net:: { TcpListener , TcpSocket , UnixListener , UnixSocket } ,
@@ -41,7 +41,7 @@ use tower_service::Service;
4141use tracing:: { debug, info, warn} ;
4242use uuid:: Uuid ;
4343
44- use super :: { body :: NotifyingBody , AsyncCounter , Error , Stats , ALPN_ACME } ;
44+ use super :: { ALPN_ACME , AsyncCounter , Error , Stats , body :: NotifyingBody } ;
4545use crate :: tasks:: Run ;
4646
4747const HANDSHAKE_DURATION_BUCKETS : & [ f64 ] = & [ 0.005 , 0.01 , 0.02 , 0.05 , 0.1 , 0.2 , 0.4 , 0.8 , 1.6 ] ;
@@ -322,6 +322,27 @@ enum RequestState {
322322 End ,
323323}
324324
325+ async fn tls_handshake (
326+ rustls_cfg : Arc < rustls:: ServerConfig > ,
327+ stream : impl AsyncReadWrite ,
328+ ) -> Result < ( impl AsyncReadWrite , TlsInfo ) , Error > {
329+ let tls_acceptor = TlsAcceptor :: from ( rustls_cfg) ;
330+
331+ // Perform the TLS handshake
332+ let start = Instant :: now ( ) ;
333+ let stream = tls_acceptor
334+ . accept ( stream)
335+ . await
336+ . context ( "TLS accept failed" ) ?;
337+ let duration = start. elapsed ( ) ;
338+
339+ let conn = stream. get_ref ( ) . 1 ;
340+ let mut tls_info = TlsInfo :: try_from ( conn) ?;
341+ tls_info. handshake_dur = duration;
342+
343+ Ok ( ( stream, tls_info) )
344+ }
345+
325346struct Conn {
326347 addr : Addr ,
327348 remote_addr : Addr ,
@@ -332,7 +353,7 @@ struct Conn {
332353 options : Options ,
333354 metrics : Metrics ,
334355 requests : AtomicU32 ,
335- tls_acceptor : Option < TlsAcceptor > ,
356+ rustls_cfg : Option < Arc < rustls :: ServerConfig > > ,
336357}
337358
338359impl Display for Conn {
@@ -342,40 +363,6 @@ impl Display for Conn {
342363}
343364
344365impl Conn {
345- async fn tls_handshake (
346- & self ,
347- stream : impl AsyncReadWrite ,
348- ) -> Result < ( impl AsyncReadWrite , TlsInfo ) , Error > {
349- debug ! ( "{}: performing TLS handshake" , self ) ;
350-
351- // Perform the TLS handshake
352- let start = Instant :: now ( ) ;
353- let stream = self
354- . tls_acceptor
355- . as_ref ( )
356- . unwrap ( ) // Caller makes sure it's Some()
357- . accept ( stream)
358- . await
359- . context ( "TLS accept failed" ) ?;
360- let duration = start. elapsed ( ) ;
361-
362- let conn = stream. get_ref ( ) . 1 ;
363- let mut tls_info = TlsInfo :: try_from ( conn) ?;
364- tls_info. handshake_dur = duration;
365-
366- debug ! (
367- "{}: handshake finished in {}ms (server: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})" ,
368- self ,
369- duration. as_millis( ) ,
370- tls_info. sni,
371- tls_info. protocol,
372- tls_info. cipher,
373- tls_info. alpn,
374- ) ;
375-
376- Ok ( ( stream, tls_info) )
377- }
378-
379366 async fn handle ( & self , stream : Box < dyn AsyncReadWrite > ) -> Result < ( ) , Error > {
380367 let accepted_at = Instant :: now ( ) ;
381368
@@ -405,15 +392,29 @@ impl Conn {
405392 } ) ;
406393
407394 // Perform TLS handshake if we're in TLS mode
408- let ( stream, tls_info) : ( Box < dyn AsyncReadWrite > , _ ) = if self . tls_acceptor . is_some ( ) {
409- let ( mut stream, tls_info) = timeout (
395+ let ( stream, tls_info) : ( Box < dyn AsyncReadWrite > , _ ) = if let Some ( rustls_cfg) =
396+ & self . rustls_cfg
397+ {
398+ debug ! ( "{}: performing TLS handshake" , self ) ;
399+
400+ let ( mut stream_tls, tls_info) = timeout (
410401 self . options . tls_handshake_timeout ,
411- self . tls_handshake ( stream) ,
402+ tls_handshake ( rustls_cfg . clone ( ) , stream) ,
412403 )
413404 . await
414405 . context ( "TLS handshake timed out" ) ?
415406 . context ( "TLS handshake failed" ) ?;
416407
408+ debug ! (
409+ "{}: handshake finished in {}ms (server: {:?}, proto: {:?}, cipher: {:?}, ALPN: {:?})" ,
410+ self ,
411+ tls_info. handshake_dur. as_millis( ) ,
412+ tls_info. sni,
413+ tls_info. protocol,
414+ tls_info. cipher,
415+ tls_info. alpn,
416+ ) ;
417+
417418 // Close the connection if agreed ALPN is ACME - the handshake is enough for the challenge
418419 if tls_info
419420 . alpn
@@ -422,15 +423,15 @@ impl Conn {
422423 {
423424 debug ! ( "{self}: ACME ALPN - closing connection" ) ;
424425
425- timeout ( Duration :: from_secs ( 5 ) , stream . shutdown ( ) )
426+ timeout ( Duration :: from_secs ( 5 ) , stream_tls . shutdown ( ) )
426427 . await
427428 . context ( "socket shutdown timed out" ) ?
428429 . context ( "socket shutdown failed" ) ?;
429430
430431 return Ok ( ( ) ) ;
431432 }
432433
433- ( Box :: new ( stream ) , Some ( Arc :: new ( tls_info) ) )
434+ ( Box :: new ( stream_tls ) , Some ( Arc :: new ( tls_info) ) )
434435 } else {
435436 ( Box :: new ( stream) , None )
436437 } ;
@@ -648,7 +649,7 @@ pub struct Server {
648649 options : Options ,
649650 metrics : Metrics ,
650651 builder : Builder < TokioExecutor > ,
651- tls_acceptor : Option < TlsAcceptor > ,
652+ rustls_cfg : Option < Arc < rustls :: ServerConfig > > ,
652653}
653654
654655impl Server {
@@ -681,7 +682,7 @@ impl Server {
681682 metrics,
682683 tracker : TaskTracker :: new ( ) ,
683684 builder,
684- tls_acceptor : rustls_cfg. map ( |x| TlsAcceptor :: from ( Arc :: new ( x ) ) ) ,
685+ rustls_cfg : rustls_cfg. map ( Arc :: new) ,
685686 }
686687 }
687688
@@ -719,7 +720,7 @@ impl Server {
719720 options : self . options ,
720721 metrics : self . metrics . clone ( ) , // All metrics have Arc inside
721722 requests : AtomicU32 :: new ( 0 ) ,
722- tls_acceptor : self . tls_acceptor . clone ( ) ,
723+ rustls_cfg : self . rustls_cfg . clone ( ) ,
723724 } ;
724725
725726 // Spawn a task to handle connection & track it
@@ -743,7 +744,7 @@ impl Server {
743744 warn ! (
744745 "Server {}: running (TLS: {})" ,
745746 self . addr,
746- self . tls_acceptor . is_some( )
747+ self . rustls_cfg . is_some( )
747748 ) ;
748749
749750 loop {
0 commit comments