1- use std:: { convert:: Infallible , net:: SocketAddr , time:: Duration } ;
1+ use std:: {
2+ convert:: Infallible ,
3+ net:: SocketAddr ,
4+ pin:: Pin ,
5+ task:: { Context , Poll } ,
6+ time:: Duration ,
7+ } ;
28
3- use futures:: FutureExt ;
9+ use futures:: { FutureExt , StreamExt } ;
410use http:: { Request , Response } ;
511use hyper:: Body ;
12+ use tokio:: {
13+ io:: { AsyncRead , AsyncWrite , ReadBuf } ,
14+ net:: TcpStream ,
15+ time:: { Sleep , sleep} ,
16+ } ;
617use tonic:: {
718 body:: BoxBody ,
819 server:: NamedService ,
9- transport:: server:: { Routes , Server } ,
20+ transport:: server:: { Connected , Routes , Server } ,
1021} ;
1122use tower:: Service ;
1223use tower_http:: {
@@ -18,16 +29,127 @@ use tracing::Span;
1829use crate :: {
1930 internal_events:: { GrpcServerRequestReceived , GrpcServerResponseSent } ,
2031 shutdown:: { ShutdownSignal , ShutdownSignalToken } ,
21- tls:: MaybeTlsSettings ,
32+ tls:: { MaybeTlsIncomingStream , MaybeTlsSettings } ,
2233} ;
34+ use vector_lib:: configurable:: configurable_component;
2335
2436mod decompression;
2537pub use self :: decompression:: { DecompressionAndMetrics , DecompressionAndMetricsLayer } ;
2638
39+ /// Configuration of gRPC server keepalive parameters.
40+ #[ configurable_component]
41+ #[ derive( Clone , Debug , Default , PartialEq , Eq ) ]
42+ #[ serde( deny_unknown_fields) ]
43+ pub struct GrpcKeepaliveConfig {
44+ /// The maximum amount of time a connection may exist before the server closes it.
45+ ///
46+ /// When unset, connections are not closed based on age.
47+ #[ serde( default ) ]
48+ #[ configurable( metadata( docs:: examples = 300 ) ) ]
49+ #[ configurable( metadata( docs:: type_unit = "seconds" ) ) ]
50+ #[ configurable( metadata( docs:: human_name = "Maximum Connection Age" ) ) ]
51+ pub max_connection_age_secs : Option < u64 > ,
52+
53+ /// The grace period added to `max_connection_age_secs` before the server closes the connection.
54+ ///
55+ /// This setting only applies when `max_connection_age_secs` is set.
56+ #[ serde( default ) ]
57+ #[ configurable( metadata( docs:: examples = 30 ) ) ]
58+ #[ configurable( metadata( docs:: type_unit = "seconds" ) ) ]
59+ #[ configurable( metadata( docs:: human_name = "Maximum Connection Age Grace" ) ) ]
60+ pub max_connection_age_grace_secs : Option < u64 > ,
61+ }
62+
63+ impl GrpcKeepaliveConfig {
64+ fn max_connection_lifetime ( & self ) -> Option < Duration > {
65+ self . max_connection_age_secs . map ( |max_connection_age_secs| {
66+ let age = Duration :: from_secs ( max_connection_age_secs) ;
67+ let grace = self
68+ . max_connection_age_grace_secs
69+ . map ( Duration :: from_secs)
70+ . unwrap_or_default ( ) ;
71+
72+ age. checked_add ( grace) . unwrap_or ( Duration :: MAX )
73+ } )
74+ }
75+ }
76+
77+ struct MaxConnectionAgeIo {
78+ inner : MaybeTlsIncomingStream < TcpStream > ,
79+ deadline : Option < Pin < Box < Sleep > > > ,
80+ }
81+
82+ impl MaxConnectionAgeIo {
83+ fn new ( inner : MaybeTlsIncomingStream < TcpStream > , lifetime : Option < Duration > ) -> Self {
84+ Self {
85+ inner,
86+ deadline : lifetime. map ( |lifetime| Box :: pin ( sleep ( lifetime) ) ) ,
87+ }
88+ }
89+
90+ fn is_expired ( & mut self , cx : & mut Context < ' _ > ) -> bool {
91+ self . deadline
92+ . as_mut ( )
93+ . is_some_and ( |deadline| deadline. as_mut ( ) . poll ( cx) . is_ready ( ) )
94+ }
95+ }
96+
97+ impl AsyncRead for MaxConnectionAgeIo {
98+ fn poll_read (
99+ self : Pin < & mut Self > ,
100+ cx : & mut Context < ' _ > ,
101+ buf : & mut ReadBuf < ' _ > ,
102+ ) -> Poll < std:: io:: Result < ( ) > > {
103+ let this = self . get_mut ( ) ;
104+ if this. is_expired ( cx) {
105+ Poll :: Ready ( Ok ( ( ) ) )
106+ } else {
107+ Pin :: new ( & mut this. inner ) . poll_read ( cx, buf)
108+ }
109+ }
110+ }
111+
112+ impl AsyncWrite for MaxConnectionAgeIo {
113+ fn poll_write (
114+ self : Pin < & mut Self > ,
115+ cx : & mut Context < ' _ > ,
116+ buf : & [ u8 ] ,
117+ ) -> Poll < std:: io:: Result < usize > > {
118+ let this = self . get_mut ( ) ;
119+ if this. is_expired ( cx) {
120+ Poll :: Ready ( Err ( std:: io:: ErrorKind :: BrokenPipe . into ( ) ) )
121+ } else {
122+ Pin :: new ( & mut this. inner ) . poll_write ( cx, buf)
123+ }
124+ }
125+
126+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
127+ let this = self . get_mut ( ) ;
128+ if this. is_expired ( cx) {
129+ Poll :: Ready ( Err ( std:: io:: ErrorKind :: BrokenPipe . into ( ) ) )
130+ } else {
131+ Pin :: new ( & mut this. inner ) . poll_flush ( cx)
132+ }
133+ }
134+
135+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
136+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
137+ }
138+ }
139+
140+ impl Connected for MaxConnectionAgeIo {
141+ type ConnectInfo = <MaybeTlsIncomingStream < TcpStream > as Connected >:: ConnectInfo ;
142+
143+ fn connect_info ( & self ) -> Self :: ConnectInfo {
144+ self . inner . connect_info ( )
145+ }
146+ }
147+
27148pub async fn run_grpc_server < S > (
28149 address : SocketAddr ,
29150 tls_settings : MaybeTlsSettings ,
30151 service : S ,
152+ keepalive : GrpcKeepaliveConfig ,
31153 shutdown : ShutdownSignal ,
32154) -> crate :: Result < ( ) >
33155where
@@ -41,7 +163,10 @@ where
41163 let span = Span :: current ( ) ;
42164 let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ShutdownSignalToken > ( ) ;
43165 let listener = tls_settings. bind ( & address) . await ?;
44- let stream = listener. accept_stream ( ) ;
166+ let max_connection_lifetime = keepalive. max_connection_lifetime ( ) ;
167+ let stream = listener
168+ . accept_stream ( )
169+ . map ( move |stream| stream. map ( |io| MaxConnectionAgeIo :: new ( io, max_connection_lifetime) ) ) ;
45170
46171 info ! ( %address, "Building gRPC server." ) ;
47172
@@ -72,12 +197,16 @@ pub async fn run_grpc_server_with_routes(
72197 address : SocketAddr ,
73198 tls_settings : MaybeTlsSettings ,
74199 routes : Routes ,
200+ keepalive : GrpcKeepaliveConfig ,
75201 shutdown : ShutdownSignal ,
76202) -> crate :: Result < ( ) > {
77203 let span = Span :: current ( ) ;
78204 let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ShutdownSignalToken > ( ) ;
79205 let listener = tls_settings. bind ( & address) . await ?;
80- let stream = listener. accept_stream ( ) ;
206+ let max_connection_lifetime = keepalive. max_connection_lifetime ( ) ;
207+ let stream = listener
208+ . accept_stream ( )
209+ . map ( move |stream| stream. map ( |io| MaxConnectionAgeIo :: new ( io, max_connection_lifetime) ) ) ;
81210
82211 info ! ( %address, "Building gRPC server." ) ;
83212
0 commit comments