Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gel-stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ hickory-resolver = { version = "0.25.2", optional = true, default-features = fal
# feature = "rustls"
# We rely on certain aspects of these crates. Use caution when upgrading.
rustls = { version = ">= 0.23.25", optional = true, default-features = false, features = ["ring", "logging", "std", "tls12"] }
rustls-tokio-stream = { version = "0.6.0", optional = true }
rustls-tokio-stream = { version = "0.8.0", optional = true }
rustls-platform-verifier = { version = "0.5.1", optional = true }
webpki = { version = "0.22", optional = true }
webpki-roots = { version = "1", optional = true }
Expand Down
3 changes: 0 additions & 3 deletions gel-stream/src/common/openssl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ impl TlsDriver for OpensslDriver {
let stream = stream
.downcast::<TokioStream>()
.map_err(|_| crate::SslError::SslUnsupported)?;
let TokioStream::Tcp(stream) = stream else {
return Err(crate::SslError::SslUnsupported);
};

let mut stream =
tokio_openssl::SslStream::new(params, Box::new(stream) as Box<dyn Stream + Send>)?;
Expand Down
108 changes: 89 additions & 19 deletions gel-stream/src/common/rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustls_pki_types::{
CertificateDer, CertificateRevocationListDer, DnsName, ServerName, UnixTime,
};
use rustls_platform_verifier::Verifier;
use rustls_tokio_stream::TlsStream;
use rustls_tokio_stream::{TlsStream, UnderlyingStream};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};

use super::tokio_stream::TokioStream;
Expand All @@ -23,12 +23,13 @@ use std::borrow::Cow;
use std::mem::MaybeUninit;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use std::task::{Context, Poll};

#[derive(Default)]
pub struct RustlsDriver;

impl TlsDriver for RustlsDriver {
type Stream = TlsStream;
type Stream = TlsStream<TokioStream>;
type ClientParams = ClientConnection;
type ServerParams = Arc<ServerConfig>;
const DRIVER_NAME: &'static str = "rustls";
Expand Down Expand Up @@ -134,9 +135,6 @@ impl TlsDriver for RustlsDriver {
let stream = stream
.downcast::<TokioStream>()
.map_err(|_| crate::SslError::SslUnsupported)?;
let TokioStream::Tcp(stream) = stream else {
return Err(crate::SslError::SslUnsupported);
};

let mut stream = TlsStream::new_client_side(stream, params, None);
match stream.handshake().await {
Expand Down Expand Up @@ -180,7 +178,7 @@ impl TlsDriver for RustlsDriver {
params: TlsServerParameterProvider,
stream: S,
) -> Result<(Self::Stream, TlsHandshake), SslError> {
let (stream, mut acceptor) = match stream.downcast::<RewindStream<TokioStream>>() {
let (mut stream, mut acceptor) = match stream.downcast::<RewindStream<TokioStream>>() {
Ok(stream) => {
let (stream, buffer) = stream.into_inner();
let mut acceptor = Acceptor::default();
Expand All @@ -195,10 +193,6 @@ impl TlsDriver for RustlsDriver {
}
};

let TokioStream::Tcp(mut stream) = stream else {
return Err(crate::SslError::SslUnsupported);
};

let mut buf = [MaybeUninit::uninit(); 1024];
let accepted = loop {
match acceptor.accept() {
Expand Down Expand Up @@ -638,19 +632,29 @@ impl ServerCertVerifier for ErrorFilteringVerifier {
}
}

impl LocalAddress for TlsStream {
impl LocalAddress for TlsStream<TokioStream> {
fn local_address(&self) -> std::io::Result<ResolvedTarget> {
self.local_addr().map(ResolvedTarget::from)
self.underlying_stream()
.ok_or(std::io::Error::new(
std::io::ErrorKind::Other,
"No underlying stream",
))?
.local_address()
}
}

impl RemoteAddress for TlsStream {
impl RemoteAddress for TlsStream<TokioStream> {
fn remote_address(&self) -> std::io::Result<ResolvedTarget> {
self.peer_addr().map(ResolvedTarget::from)
self.underlying_stream()
.ok_or(std::io::Error::new(
std::io::ErrorKind::Other,
"No underlying stream",
))?
.remote_address()
}
}

impl PeerCred for TlsStream {
impl PeerCred for TlsStream<TokioStream> {
#[cfg(all(unix, feature = "tokio"))]
fn peer_cred(&self) -> std::io::Result<tokio::net::unix::UCred> {
Err(std::io::Error::new(
Expand All @@ -660,20 +664,86 @@ impl PeerCred for TlsStream {
}
}

impl StreamMetadata for TlsStream {
impl StreamMetadata for TlsStream<TokioStream> {
fn transport(&self) -> Transport {
Transport::Tcp
}
}

impl AsHandle for TlsStream {
impl AsHandle for TlsStream<TokioStream> {
#[cfg(windows)]
fn as_handle(&self) -> std::os::windows::io::BorrowedSocket {
std::os::windows::io::AsSocket::as_socket(self.tcp_stream().unwrap())
std::os::windows::io::AsSocket::as_socket(self.underlying_stream().unwrap())
}

#[cfg(unix)]
fn as_fd(&self) -> std::os::fd::BorrowedFd {
std::os::fd::AsFd::as_fd(self.tcp_stream().unwrap())
std::os::fd::AsFd::as_fd(self.underlying_stream().unwrap())
}
}

impl UnderlyingStream for TokioStream {
type StdType = ();

async fn readable(&self) -> std::io::Result<()> {
match self {
TokioStream::Tcp(stream) => stream.readable().await,
#[cfg(unix)]
TokioStream::Unix(stream) => stream.readable().await,
}
}
async fn writable(&self) -> std::io::Result<()> {
match self {
TokioStream::Tcp(stream) => stream.writable().await,
#[cfg(unix)]
TokioStream::Unix(stream) => stream.writable().await,
}
}
fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self {
TokioStream::Tcp(stream) => stream.poll_read_ready(cx),
#[cfg(unix)]
TokioStream::Unix(stream) => stream.poll_read_ready(cx),
}
}
fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self {
TokioStream::Tcp(stream) => stream.poll_write_ready(cx),
#[cfg(unix)]
TokioStream::Unix(stream) => stream.poll_write_ready(cx),
}
}
fn try_read(&self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
TokioStream::Tcp(stream) => stream.try_read(buf),
#[cfg(unix)]
TokioStream::Unix(stream) => stream.try_read(buf),
}
}
fn try_write(&self, buf: &[u8]) -> std::io::Result<usize> {
match self {
TokioStream::Tcp(stream) => stream.try_write(buf),
#[cfg(unix)]
TokioStream::Unix(stream) => stream.try_write(buf),
}
}
fn shutdown(&self, how: std::net::Shutdown) -> std::io::Result<()> {
match self {
TokioStream::Tcp(stream) => stream.shutdown(how),
#[cfg(unix)]
TokioStream::Unix(stream) => stream.shutdown(how),
}
}

fn into_std(self) -> Option<std::io::Result<Self::StdType>> {
unimplemented!()
}

fn downcast<S: UnderlyingStream>(self) -> Result<S, Self> {
match self {
TokioStream::Tcp(stream) => UnderlyingStream::downcast(stream).map_err(Self::Tcp),
#[cfg(unix)]
TokioStream::Unix(stream) => UnderlyingStream::downcast(stream).map_err(Self::Unix),
}
}
}
6 changes: 5 additions & 1 deletion gel-stream/src/common/tokio_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,12 @@ impl futures::Stream for TokioListenerStream {
}

/// Represents a connected Tokio stream, either TCP or Unix
#[derive(derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor)]
#[derive(
derive_io::AsyncRead, derive_io::AsyncWrite, derive_io::AsSocketDescriptor, derive_more::Debug,
)]
pub enum TokioStream {
/// TCP stream
#[debug("{_0:?}")]
Tcp(
#[read]
#[write]
Expand All @@ -155,6 +158,7 @@ pub enum TokioStream {
),
/// Unix stream (only available on Unix systems)
#[cfg(unix)]
#[debug("{_0:?}")]
Unix(
#[read]
#[write]
Expand Down
166 changes: 166 additions & 0 deletions gel-stream/tests/tls_unix.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#![cfg(unix)]

use futures::StreamExt;
use gel_stream::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

fn load_test_cert() -> rustls_pki_types::CertificateDer<'static> {
gel_stream::test_keys::binary::SERVER_CERT.clone()
}

fn load_test_key() -> rustls_pki_types::PrivateKeyDer<'static> {
gel_stream::test_keys::binary::SERVER_KEY.clone_key()
}

fn tls_server_parameters(alpn: TlsAlpn) -> TlsServerParameterProvider {
TlsServerParameterProvider::new(TlsServerParameters {
server_certificate: TlsKey::new(load_test_key(), load_test_cert()),
client_cert_verify: TlsClientCertVerify::Ignore,
min_protocol_version: None,
max_protocol_version: None,
alpn,
})
}

async fn spawn_unix_tls_server<S: TlsDriver>(
server_alpn: TlsAlpn,
expected_alpn: Option<&str>,
) -> Result<
(
std::path::PathBuf,
tokio::task::JoinHandle<Result<(), ConnectionError>>,
),
ConnectionError,
> {
let tempdir = tempfile::tempdir().unwrap();
let path = tempdir.path().join("gel-stream-tls-test");

let unix_addr = ResolvedTarget::from(std::os::unix::net::SocketAddr::from_pathname(&path)?);

let mut acceptor = Acceptor::new_tls(unix_addr, tls_server_parameters(server_alpn))
.bind_explicit::<S>()
.await?;

let expected_alpn = expected_alpn.map(|alpn| alpn.as_bytes().to_vec());
let path_clone = path.clone();
let accept_task = tokio::spawn(async move {
// Keep tempdir alive for the duration of the test
let _tempdir = tempdir;
let mut connection = acceptor.next().await.unwrap()?;
let handshake = connection
.handshake()
.unwrap_or_else(|| panic!("handshake was not available on {connection:?}"));
assert!(handshake.version.is_some());
assert_eq!(
handshake.alpn.as_ref().map(|b| b.as_ref().to_vec()),
expected_alpn
);
let mut buf = String::new();
connection.read_to_string(&mut buf).await.unwrap();
assert_eq!(buf, "Hello, Unix TLS!");
connection.shutdown().await?;
Ok::<_, ConnectionError>(())
});
Ok((path_clone, accept_task))
}

macro_rules! unix_tls_test (
(
$(
$(#[ $attr:meta ])*
async fn $name:ident<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> $body:block
)*
) => {
mod rustls_openssl {
use super::*;
$(
$(#[ $attr ])*
async fn $name() -> Result<(), ConnectionError> {
async fn test_inner<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
$body
}
test_inner::<RustlsDriver, OpensslDriver>().await
}
)*
}

mod openssl_rustls {
use super::*;
$(
$(#[ $attr ])*
async fn $name() -> Result<(), ConnectionError> {
async fn test_inner<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
$body
}
test_inner::<OpensslDriver, RustlsDriver>().await
}
)*
}
}
);

unix_tls_test! {
/// Basic Unix TLS test with ALPN - client connects to server over Unix socket with TLS
#[tokio::test]
#[ntest::timeout(30_000)]
async fn test_unix_tls_basic<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
let (path, accept_task) = spawn_unix_tls_server::<S>(
TlsAlpn::new_str(&["nope", "accepted"]),
Some("accepted"),
)
.await?;

let connect_task = tokio::spawn(async move {
let name = TargetName::new_unix_path(path)?;
let target = Target::new_tls(
name,
TlsParameters {
server_cert_verify: TlsServerCertVerify::Insecure,
alpn: TlsAlpn::new_str(&["accepted", "fake"]),
..Default::default()
},
);
let mut stm = Connector::<C>::new_explicit(target).unwrap().connect().await.unwrap();
stm.write_all(b"Hello, Unix TLS!").await.unwrap();
stm.shutdown().await?;
Ok::<_, std::io::Error>(())
});

accept_task.await.unwrap().unwrap();
connect_task.await.unwrap().unwrap();

Ok(())
}

/// Unix TLS test with custom certificate verification
#[tokio::test]
#[ntest::timeout(30_000)]
async fn test_unix_tls_custom_cert<C: TlsDriver, S: TlsDriver>() -> Result<(), ConnectionError> {
let (path, accept_task) = spawn_unix_tls_server::<S>(
TlsAlpn::new_str(&["unix-tls"]),
Some("unix-tls"),
)
.await?;

let connect_task = tokio::spawn(async move {
let name = TargetName::new_unix_path(path)?;
let target = Target::new_tls(
name,
TlsParameters {
server_cert_verify: TlsServerCertVerify::Insecure,
alpn: TlsAlpn::new_str(&["unix-tls"]),
..Default::default()
},
);
let mut stm = Connector::<C>::new_explicit(target).unwrap().connect().await.unwrap();
stm.write_all(b"Hello, Unix TLS!").await.unwrap();
stm.shutdown().await?;
Ok::<_, std::io::Error>(())
});

accept_task.await.unwrap().unwrap();
connect_task.await.unwrap().unwrap();

Ok(())
}
}
Loading