diff --git a/Cargo.toml b/Cargo.toml index 2fe6fe8..bd2b3e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ webpki-roots = "0.22" rustls-native-certs = "0.6" x509-parser = "0.12" zeroize = { version = "1", features = ["zeroize_derive"] } +derive-where = "1.0.0-rc.1" [dev-dependencies] anyhow = "1" diff --git a/src/lib.rs b/src/lib.rs index 67dc903..15f1183 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,5 +75,8 @@ pub mod dangerous { pub mod error; mod quic; -pub use quic::{Builder, Connecting, Connection, Endpoint, Incoming, Receiver, Sender, Store}; +pub use quic::{ + Builder, Connecting, Connection, Endpoint, Incoming, Receiver, Sender, ServerCertificateConfig, + Store, +}; pub use x509::{Certificate, CertificateChain, KeyPair, PrivateKey}; diff --git a/src/quic/endpoint/builder/mod.rs b/src/quic/endpoint/builder/mod.rs index 3810286..3915154 100644 --- a/src/quic/endpoint/builder/mod.rs +++ b/src/quic/endpoint/builder/mod.rs @@ -9,8 +9,9 @@ use std::{ }; pub(super) use config::Config; +use derive_where::DeriveWhere; use rustls::{ - server::{ClientCertVerified, ClientCertVerifier}, + server::{ClientCertVerified, ClientCertVerifier, ResolvesServerCert}, DistinguishedNames, }; use serde::{Deserialize, Serialize}; @@ -38,7 +39,7 @@ pub struct Builder { /// Custom root [`Certificate`]s. root_certificates: Vec, /// Server certificate key-pair. - server_key_pair: Option, + server_certificate_config: Option, /// Client certificate key-pair. client_key_pair: Option, /// [`Store`] option. @@ -47,6 +48,40 @@ pub struct Builder { config: Config, } +/// A server certificate configuration. +#[derive(Clone, DeriveWhere)] +#[derive_where(Debug)] +pub enum ServerCertificateConfig { + /// All TLS connections should use a single certificate. + KeyPair(KeyPair), + /// Resolve TLS connections' certificates dynamically. + #[derive_where(skip_inner)] + Resolver(Arc), +} + +#[allow(clippy::vtable_address_comparisons)] +impl PartialEq for ServerCertificateConfig { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::KeyPair(l0), Self::KeyPair(r0)) => l0 == r0, + (Self::Resolver(l0), Self::Resolver(r0)) => Arc::ptr_eq(l0, r0), + _ => false, + } + } +} + +impl From for ServerCertificateConfig { + fn from(key_pair: KeyPair) -> Self { + Self::KeyPair(key_pair) + } +} + +impl From> for ServerCertificateConfig { + fn from(resolver: Arc) -> Self { + Self::Resolver(resolver) + } +} + impl Default for Builder { fn default() -> Self { Self::new() @@ -74,7 +109,7 @@ impl Builder { #[cfg(feature = "test")] address: ([0, 0, 0, 0, 0, 0, 0, 1], 0).into(), root_certificates: Vec::new(), - server_key_pair: None, + server_certificate_config: None, client_key_pair: None, store: Store::Embedded, config, @@ -135,29 +170,29 @@ impl Builder { /// use fabruic::{Builder, KeyPair}; /// /// let mut builder = Builder::new(); - /// builder.set_server_key_pair(Some(KeyPair::new_self_signed("test"))); + /// builder.set_server_certificate_config(Some(KeyPair::new_self_signed("test"))); /// ``` - pub fn set_server_key_pair(&mut self, key_pair: Option) { - self.server_key_pair = key_pair; + pub fn set_server_certificate_config(&mut self, key_pair: Option>) { + self.server_certificate_config = key_pair.map(Into::into); } /// Returns the server certificate [`KeyPair`]. /// - /// See [`set_server_key_pair`](Self::set_server_key_pair). + /// See [`set_server_certificate_config`](Self::set_server_certificate_config). /// /// # Examples /// ``` - /// use fabruic::{Builder, KeyPair}; + /// use fabruic::{Builder, KeyPair, ServerCertificateConfig}; /// /// let mut builder = Builder::new(); /// /// let key_pair = KeyPair::new_self_signed("test"); - /// builder.set_server_key_pair(Some(key_pair.clone())); - /// assert_eq!(builder.server_key_pair(), &Some(key_pair)) + /// builder.set_server_certificate_config(Some(key_pair.clone())); + /// assert_eq!(builder.server_certificate_config(), &Some(ServerCertificateConfig::KeyPair(key_pair))) /// ``` #[must_use] - pub const fn server_key_pair(&self) -> &Option { - &self.server_key_pair + pub const fn server_certificate_config(&self) -> &Option { + &self.server_certificate_config } /// Set a client certificate [`KeyPair`], use [`None`] to @@ -505,37 +540,47 @@ impl Builder { false, ) { Ok(client) => client, - Err(error) => + Err(error) => { return Err(error::Builder { error: error.into(), builder: self, - }), + }) + } }; // build server only if we have a key-pair - let server = self.server_key_pair.as_ref().map(|key_pair| { - let mut crypto = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .expect("failed to configure correct protocol") - .with_client_cert_verifier(Arc::new(ClientVerifier)) - .with_single_cert( - key_pair.certificate_chain().clone().into_rustls(), - key_pair.private_key().clone().into_rustls(), - ) - .expect("`CertificateChain` couldn't be verified"); - - // set protocols - crypto.alpn_protocols = self.config.protocols().to_vec(); - - let mut server = quinn::ServerConfig::with_crypto(Arc::new(crypto)); - - // set transport - server.transport = Arc::new(self.config.transport()); - - server - }); + let server = self + .server_certificate_config + .as_ref() + .map(|certificate_config| { + let crypto = rustls::ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .expect("failed to configure correct protocol") + .with_client_cert_verifier(Arc::new(ClientVerifier)); + let mut crypto = match certificate_config { + ServerCertificateConfig::KeyPair(key_pair) => crypto + .with_single_cert( + key_pair.certificate_chain().clone().into_rustls(), + key_pair.private_key().clone().into_rustls(), + ) + .expect("`CertificateChain` couldn't be verified"), + ServerCertificateConfig::Resolver(resolver) => { + crypto.with_cert_resolver(Arc::clone(resolver)) + } + }; + + // set protocols + crypto.alpn_protocols = self.config.protocols().to_vec(); + + let mut server = quinn::ServerConfig::with_crypto(Arc::new(crypto)); + + // set transport + server.transport = Arc::new(self.config.transport()); + + server + }); Endpoint::new(self.address, client, server, self.config.clone()) } { @@ -702,7 +747,7 @@ mod test { // build server let mut builder = Builder::new(); - builder.set_server_key_pair(Some(key_pair.clone())); + builder.set_server_certificate_config(Some(key_pair.clone())); let mut server = builder.build()?; // test connection to server @@ -734,15 +779,16 @@ mod test { // build client let mut builder = Builder::new(); - Dangerous::set_root_certificates(&mut builder, [server_key_pair - .end_entity_certificate() - .clone()]); + Dangerous::set_root_certificates( + &mut builder, + [server_key_pair.end_entity_certificate().clone()], + ); builder.set_client_key_pair(Some(client_key_pair.clone())); let client = builder.build()?; // build server let mut builder = Builder::new(); - builder.set_server_key_pair(Some(server_key_pair)); + builder.set_server_certificate_config(Some(server_key_pair)); let mut server = builder.build()?; // test connection to server @@ -799,7 +845,7 @@ mod test { // build server let mut builder = Builder::new(); - builder.set_server_key_pair(Some(key_pair.clone())); + builder.set_server_certificate_config(Some(key_pair.clone())); builder.set_protocols(protocols.clone()); let mut server = builder.build()?; @@ -855,7 +901,7 @@ mod test { // build server let mut builder = Builder::new(); - builder.set_server_key_pair(Some(key_pair.clone())); + builder.set_server_certificate_config(Some(key_pair.clone())); builder.set_protocols([b"test2".to_vec()]); let mut server = builder.build()?; diff --git a/src/quic/endpoint/mod.rs b/src/quic/endpoint/mod.rs index ce4e758..6059496 100644 --- a/src/quic/endpoint/mod.rs +++ b/src/quic/endpoint/mod.rs @@ -14,7 +14,7 @@ use std::{ use async_trait::async_trait; use builder::Config; -pub use builder::{Builder, Dangerous as BuilderDangerous, Store}; +pub use builder::{Builder, Dangerous as BuilderDangerous, ServerCertificateConfig, Store}; use flume::{r#async::RecvStream, Sender}; use futures_channel::oneshot::Receiver; use futures_util::{ @@ -171,7 +171,7 @@ impl Endpoint { // while testing always use the default loopback address #[cfg(feature = "test")] builder.set_address(([0, 0, 0, 0, 0, 0, 0, 1], port).into()); - builder.set_server_key_pair(Some(key_pair)); + builder.set_server_certificate_config(Some(key_pair)); builder .build() @@ -467,7 +467,7 @@ impl Endpoint { /// finish first. This will always return [`error::AlreadyClosed`] if the /// [`Endpoint`] wasn't started with a listener. /// - /// See [`Builder::set_server_key_pair`]. + /// See [`Builder::set_server_certificate_config`]. /// /// # Errors /// [`error::AlreadyClosed`] if it was already closed. diff --git a/src/quic/mod.rs b/src/quic/mod.rs index e5ed449..c134d6c 100644 --- a/src/quic/mod.rs +++ b/src/quic/mod.rs @@ -5,5 +5,7 @@ mod endpoint; mod task; pub use connection::{Connecting, Connection, Incoming, Receiver, Sender}; -pub use endpoint::{Builder, BuilderDangerous, Dangerous, Endpoint, Store}; +pub use endpoint::{ + Builder, BuilderDangerous, Dangerous, Endpoint, ServerCertificateConfig, Store, +}; use task::Task;