diff --git a/src/server/tracker/accept.rs b/src/server/tracker/accept.rs index 8ea767e..43c0795 100644 --- a/src/server/tracker/accept.rs +++ b/src/server/tracker/accept.rs @@ -40,6 +40,7 @@ where let (mut stream, service) = acceptor.accept(TlsInspector::new(stream), service).await?; let mut connect_track = ConnectionTrack::default(); connect_track.set_client_hello(stream.get_mut().0.client_hello()); + connect_track.set_tls_version_negotiated(stream.get_ref().1.protocol_version()); let stream = match stream.get_ref().1.alpn_protocol() { // If ALPN is set to HTTP/2, use Http2Inspector diff --git a/src/server/tracker/info.rs b/src/server/tracker/info.rs index aed8631..0fc28f0 100644 --- a/src/server/tracker/info.rs +++ b/src/server/tracker/info.rs @@ -5,6 +5,7 @@ use axum::{ http::{header::USER_AGENT, HeaderValue, Method, Request}, }; use serde::{Serialize, Serializer}; +use tokio_rustls::rustls::ProtocolVersion; use super::inspector::{ClientHello, Frame, Http1Headers, Http2Frame, LazyClientHello}; @@ -28,6 +29,8 @@ pub struct Http2TrackInfo { /// Collects TLS, HTTP/1, and HTTP/2 handshake info for tracking. #[derive(Clone, Default)] pub struct ConnectionTrack { + /// The TLS protocol version that was negotiated for this connection, if any. + tls_version_negotiated: Option, client_hello: Option, http1_headers: Option, http2_frames: Option, @@ -216,6 +219,12 @@ where // ==== impl ConnectionTrack ==== impl ConnectionTrack { + /// Set TLS version negotiated during the handshake. + #[inline] + pub fn set_tls_version_negotiated(&mut self, version: Option) { + self.tls_version_negotiated = version; + } + /// Set TLS client hello #[inline] pub fn set_client_hello(&mut self, client_hello: Option) { @@ -248,17 +257,23 @@ impl TrackInfo { req: Request, connection_track: ConnectionTrack, ) -> TrackInfo { - let headers = req.headers(); + let mut tls = connection_track + .client_hello + .and_then(LazyClientHello::parse) + .map(TlsTrackInfo::new); + + if let Some(tls) = tls.as_mut() { + tls.0 + .set_tls_version_negotiated(connection_track.tls_version_negotiated); + } + let track_info = TrackInfo { donate: Self::DONATE_URL, address: addr, http_version: format!("{:?}", req.version()), method: req.method().clone(), - user_agent: headers.get(USER_AGENT).cloned(), - tls: connection_track - .client_hello - .and_then(LazyClientHello::parse) - .map(TlsTrackInfo::new), + user_agent: req.headers().get(USER_AGENT).cloned(), + tls, http1: connection_track.http1_headers.map(Http1TrackInfo::new), http2: connection_track.http2_frames.and_then(Http2TrackInfo::new), }; diff --git a/src/server/tracker/inspector/tls/hello.rs b/src/server/tracker/inspector/tls/hello.rs index d23a469..48e2b9f 100644 --- a/src/server/tracker/inspector/tls/hello.rs +++ b/src/server/tracker/inspector/tls/hello.rs @@ -2,6 +2,7 @@ use serde::Serialize; use tls_parser::{TlsCipherSuite, TlsExtensionType, TlsMessage, TlsMessageHandshake}; +use tokio_rustls::rustls::ProtocolVersion; use super::{ enums::{ @@ -54,6 +55,8 @@ impl LazyClientHello { pub struct ClientHello { /// TLS version of message tls_version: TlsVersion, + /// The final TLS version negotiated during the handshake + tls_version_negotiated: Option, client_random: String, session_id: Option, /// A list of compression methods supported by client @@ -233,6 +236,15 @@ pub struct OidFilter { } impl ClientHello { + /// Sets the negotiated TLS version for this `ClientHello`. + /// + /// # Parameters + /// - `version`: An `Option` representing the negotiated TLS version. + /// If `Some`, the version is set; if `None`, no version was negotiated. + pub fn set_tls_version_negotiated(&mut self, version: Option) { + self.tls_version_negotiated = version.map(u16::from).map(TlsVersion::from); + } + pub fn parse(buf: &[u8]) -> Option { let (_, r) = tls_parser::parse_tls_raw_record(buf).ok()?; let (_, msg_list) = tls_parser::parse_tls_record_with_header(r.data, &r.hdr).ok()?; @@ -247,6 +259,7 @@ impl ClientHello { let mut client_hello = ClientHello { tls_version: TlsVersion::from(payload.version.0), + tls_version_negotiated: None, client_random: hex::encode(payload.random), session_id: payload.session_id.map(hex::encode), compression_algorithms: payload