diff --git a/gel-stream/Cargo.toml b/gel-stream/Cargo.toml index 9c315dd6..016e0a5b 100644 --- a/gel-stream/Cargo.toml +++ b/gel-stream/Cargo.toml @@ -50,7 +50,7 @@ hickory-resolver = { version = "0.25.2", optional = true, default-features = fal # feature = "rustls" # We rely on certain aspects of these crates. Use caution when upgrading. rustls = { version = ">= 0.23.25", optional = true, default-features = false, features = ["ring", "logging", "std", "tls12"] } -rustls-tokio-stream = { version = "0.6.0", optional = true } +rustls-tokio-stream = { version = "0.8.0", optional = true } rustls-platform-verifier = { version = "0.5.1", optional = true } webpki = { version = "0.22", optional = true } webpki-roots = { version = "1", optional = true } diff --git a/gel-stream/src/common/openssl.rs b/gel-stream/src/common/openssl.rs index 66a193ac..c7a33d30 100644 --- a/gel-stream/src/common/openssl.rs +++ b/gel-stream/src/common/openssl.rs @@ -289,9 +289,6 @@ impl TlsDriver for OpensslDriver { let stream = stream .downcast::() .map_err(|_| crate::SslError::SslUnsupported)?; - let TokioStream::Tcp(stream) = stream else { - return Err(crate::SslError::SslUnsupported); - }; let mut stream = tokio_openssl::SslStream::new(params, Box::new(stream) as Box)?; diff --git a/gel-stream/src/common/rustls.rs b/gel-stream/src/common/rustls.rs index 86342e97..34d8ba8f 100644 --- a/gel-stream/src/common/rustls.rs +++ b/gel-stream/src/common/rustls.rs @@ -9,7 +9,7 @@ use rustls_pki_types::{ CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime, }; use rustls_platform_verifier::Verifier; -use rustls_tokio_stream::TlsStream; +use rustls_tokio_stream::{TlsStream, UnderlyingStream}; use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; use super::tokio_stream::TokioStream; @@ -23,12 +23,13 @@ use std::borrow::Cow; use std::mem::MaybeUninit; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; +use std::task::{Context, Poll}; #[derive(Default)] pub struct RustlsDriver; impl TlsDriver for RustlsDriver { - type Stream = TlsStream; + type Stream = TlsStream; type ClientParams = ClientConnection; type ServerParams = Arc; const DRIVER_NAME: &'static str = "rustls"; @@ -134,9 +135,6 @@ impl TlsDriver for RustlsDriver { let stream = stream .downcast::() .map_err(|_| crate::SslError::SslUnsupported)?; - let TokioStream::Tcp(stream) = stream else { - return Err(crate::SslError::SslUnsupported); - }; let mut stream = TlsStream::new_client_side(stream, params, None); match stream.handshake().await { @@ -180,7 +178,7 @@ impl TlsDriver for RustlsDriver { params: TlsServerParameterProvider, stream: S, ) -> Result<(Self::Stream, TlsHandshake), SslError> { - let (stream, mut acceptor) = match stream.downcast::>() { + let (mut stream, mut acceptor) = match stream.downcast::>() { Ok(stream) => { let (stream, buffer) = stream.into_inner(); let mut acceptor = Acceptor::default(); @@ -195,10 +193,6 @@ impl TlsDriver for RustlsDriver { } }; - let TokioStream::Tcp(mut stream) = stream else { - return Err(crate::SslError::SslUnsupported); - }; - let mut buf = [MaybeUninit::uninit(); 1024]; let accepted = loop { match acceptor.accept() { @@ -638,19 +632,29 @@ impl ServerCertVerifier for ErrorFilteringVerifier { } } -impl LocalAddress for TlsStream { +impl LocalAddress for TlsStream { fn local_address(&self) -> std::io::Result { - self.local_addr().map(ResolvedTarget::from) + self.underlying_stream() + .ok_or(std::io::Error::new( + std::io::ErrorKind::Other, + "No underlying stream", + ))? + .local_address() } } -impl RemoteAddress for TlsStream { +impl RemoteAddress for TlsStream { fn remote_address(&self) -> std::io::Result { - self.peer_addr().map(ResolvedTarget::from) + self.underlying_stream() + .ok_or(std::io::Error::new( + std::io::ErrorKind::Other, + "No underlying stream", + ))? + .remote_address() } } -impl PeerCred for TlsStream { +impl PeerCred for TlsStream { #[cfg(all(unix, feature = "tokio"))] fn peer_cred(&self) -> std::io::Result { Err(std::io::Error::new( @@ -660,20 +664,86 @@ impl PeerCred for TlsStream { } } -impl StreamMetadata for TlsStream { +impl StreamMetadata for TlsStream { fn transport(&self) -> Transport { Transport::Tcp } } -impl AsHandle for TlsStream { +impl AsHandle for TlsStream { #[cfg(windows)] fn as_handle(&self) -> std::os::windows::io::BorrowedSocket { - std::os::windows::io::AsSocket::as_socket(self.tcp_stream().unwrap()) + std::os::windows::io::AsSocket::as_socket(self.underlying_stream().unwrap()) } #[cfg(unix)] fn as_fd(&self) -> std::os::fd::BorrowedFd { - std::os::fd::AsFd::as_fd(self.tcp_stream().unwrap()) + std::os::fd::AsFd::as_fd(self.underlying_stream().unwrap()) + } +} + +impl UnderlyingStream for TokioStream { + type StdType = (); + + async fn readable(&self) -> std::io::Result<()> { + match self { + TokioStream::Tcp(stream) => stream.readable().await, + #[cfg(unix)] + TokioStream::Unix(stream) => stream.readable().await, + } + } + async fn writable(&self) -> std::io::Result<()> { + match self { + TokioStream::Tcp(stream) => stream.writable().await, + #[cfg(unix)] + TokioStream::Unix(stream) => stream.writable().await, + } + } + fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + match self { + TokioStream::Tcp(stream) => stream.poll_read_ready(cx), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.poll_read_ready(cx), + } + } + fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + match self { + TokioStream::Tcp(stream) => stream.poll_write_ready(cx), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.poll_write_ready(cx), + } + } + fn try_read(&self, buf: &mut [u8]) -> std::io::Result { + match self { + TokioStream::Tcp(stream) => stream.try_read(buf), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.try_read(buf), + } + } + fn try_write(&self, buf: &[u8]) -> std::io::Result { + match self { + TokioStream::Tcp(stream) => stream.try_write(buf), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.try_write(buf), + } + } + fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> { + match self { + TokioStream::Tcp(stream) => stream.shutdown(how), + #[cfg(unix)] + TokioStream::Unix(stream) => stream.shutdown(how), + } + } + + fn into_std(self) -> Option> { + unimplemented!() + } + + fn downcast(self) -> Result { + match self { + TokioStream::Tcp(stream) => UnderlyingStream::downcast(stream).map_err(Self::Tcp), + #[cfg(unix)] + TokioStream::Unix(stream) => UnderlyingStream::downcast(stream).map_err(Self::Unix), + } } } diff --git a/gel-stream/src/common/tokio_stream.rs b/gel-stream/src/common/tokio_stream.rs index 272a2846..bf447b4b 100644 --- a/gel-stream/src/common/tokio_stream.rs +++ b/gel-stream/src/common/tokio_stream.rs @@ -144,9 +144,12 @@ impl futures::Stream for TokioListenerStream { } /// Represents a connected Tokio stream, either TCP or Unix -#[derive(derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)] +#[derive( + derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor, derive_more::Debug, +)] pub enum TokioStream { /// TCP stream + #[debug("{_0:?}")] Tcp( #[read] #[write] @@ -155,6 +158,7 @@ pub enum TokioStream { ), /// Unix stream (only available on Unix systems) #[cfg(unix)] + #[debug("{_0:?}")] Unix( #[read] #[write] diff --git a/gel-stream/tests/tls_unix.rs b/gel-stream/tests/tls_unix.rs new file mode 100644 index 00000000..a3afd775 --- /dev/null +++ b/gel-stream/tests/tls_unix.rs @@ -0,0 +1,166 @@ +#![cfg(unix)] + +use futures::StreamExt; +use gel_stream::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +fn load_test_cert() -> rustls_pki_types::CertificateDer<'static> { + gel_stream::test_keys::binary::SERVER_CERT.clone() +} + +fn load_test_key() -> rustls_pki_types::PrivateKeyDer<'static> { + gel_stream::test_keys::binary::SERVER_KEY.clone_key() +} + +fn tls_server_parameters(alpn: TlsAlpn) -> TlsServerParameterProvider { + TlsServerParameterProvider::new(TlsServerParameters { + server_certificate: TlsKey::new(load_test_key(), load_test_cert()), + client_cert_verify: TlsClientCertVerify::Ignore, + min_protocol_version: None, + max_protocol_version: None, + alpn, + }) +} + +async fn spawn_unix_tls_server( + server_alpn: TlsAlpn, + expected_alpn: Option<&str>, +) -> Result< + ( + std::path::PathBuf, + tokio::task::JoinHandle>, + ), + ConnectionError, +> { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join("gel-stream-tls-test"); + + let unix_addr = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(&path)?); + + let mut acceptor = Acceptor::new_tls(unix_addr, tls_server_parameters(server_alpn)) + .bind_explicit::() + .await?; + + let expected_alpn = expected_alpn.map(|alpn| alpn.as_bytes().to_vec()); + let path_clone = path.clone(); + let accept_task = tokio::spawn(async move { + // Keep tempdir alive for the duration of the test + let _tempdir = tempdir; + let mut connection = acceptor.next().await.unwrap()?; + let handshake = connection + .handshake() + .unwrap_or_else(|| panic!("handshake was not available on {connection:?}")); + assert!(handshake.version.is_some()); + assert_eq!( + handshake.alpn.as_ref().map(|b| b.as_ref().to_vec()), + expected_alpn + ); + let mut buf = String::new(); + connection.read_to_string(&mut buf).await.unwrap(); + assert_eq!(buf, "Hello, Unix TLS!"); + connection.shutdown().await?; + Ok::<_, ConnectionError>(()) + }); + Ok((path_clone, accept_task)) +} + +macro_rules! unix_tls_test ( + ( + $( + $(#[ $attr:meta ])* + async fn $name:ident() -> Result<(), ConnectionError> $body:block + )* + ) => { + mod rustls_openssl { + use super::*; + $( + $(#[ $attr ])* + async fn $name() -> Result<(), ConnectionError> { + async fn test_inner() -> Result<(), ConnectionError> { + $body + } + test_inner::().await + } + )* + } + + mod openssl_rustls { + use super::*; + $( + $(#[ $attr ])* + async fn $name() -> Result<(), ConnectionError> { + async fn test_inner() -> Result<(), ConnectionError> { + $body + } + test_inner::().await + } + )* + } + } +); + +unix_tls_test! { + /// Basic Unix TLS test with ALPN - client connects to server over Unix socket with TLS + #[tokio::test] + #[ntest::timeout(30_000)] + async fn test_unix_tls_basic() -> Result<(), ConnectionError> { + let (path, accept_task) = spawn_unix_tls_server::( + TlsAlpn::new_str(&["nope", "accepted"]), + Some("accepted"), + ) + .await?; + + let connect_task = tokio::spawn(async move { + let name = TargetName::new_unix_path(path)?; + let target = Target::new_tls( + name, + TlsParameters { + server_cert_verify: TlsServerCertVerify::Insecure, + alpn: TlsAlpn::new_str(&["accepted", "fake"]), + ..Default::default() + }, + ); + let mut stm = Connector::::new_explicit(target).unwrap().connect().await.unwrap(); + stm.write_all(b"Hello, Unix TLS!").await.unwrap(); + stm.shutdown().await?; + Ok::<_, std::io::Error>(()) + }); + + accept_task.await.unwrap().unwrap(); + connect_task.await.unwrap().unwrap(); + + Ok(()) + } + + /// Unix TLS test with custom certificate verification + #[tokio::test] + #[ntest::timeout(30_000)] + async fn test_unix_tls_custom_cert() -> Result<(), ConnectionError> { + let (path, accept_task) = spawn_unix_tls_server::( + TlsAlpn::new_str(&["unix-tls"]), + Some("unix-tls"), + ) + .await?; + + let connect_task = tokio::spawn(async move { + let name = TargetName::new_unix_path(path)?; + let target = Target::new_tls( + name, + TlsParameters { + server_cert_verify: TlsServerCertVerify::Insecure, + alpn: TlsAlpn::new_str(&["unix-tls"]), + ..Default::default() + }, + ); + let mut stm = Connector::::new_explicit(target).unwrap().connect().await.unwrap(); + stm.write_all(b"Hello, Unix TLS!").await.unwrap(); + stm.shutdown().await?; + Ok::<_, std::io::Error>(()) + }); + + accept_task.await.unwrap().unwrap(); + connect_task.await.unwrap().unwrap(); + + Ok(()) + } +}