Skip to content

Commit c52b2ad

Browse files
committed
feat: add TLS support
Closes #14.
1 parent 45db940 commit c52b2ad

File tree

8 files changed

+453
-76
lines changed

8 files changed

+453
-76
lines changed

Cargo.toml

+6
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ hashbrown = "0.12.0"
2929
hashlink = "0.8.0"
3030
either = "1.9.0"
3131
uuid = { version = "1.4.1", features = ["v4"] }
32+
rustls = "0.23.2"
33+
rustls-pemfile = "2"
34+
webpki-roots = "0.26.1"
35+
derive-where = "1.2.7"
3236

3337
[dev-dependencies]
3438
test-log = "0.2.12"
@@ -40,3 +44,5 @@ testcontainers = { git = "https://github.com/kezhuw/testcontainers-rs.git", bran
4044
assertor = "0.0.2"
4145
assert_matches = "1.5.0"
4246
tempfile = "3.6.0"
47+
maplit = "1.0.2"
48+
rcgen = { version = "0.12.1", features = ["default", "x509-parser"] }

src/client/mod.rs

+72-5
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@ use std::borrow::Cow;
44
use std::fmt::Write as _;
55
use std::future::Future;
66
use std::mem::ManuallyDrop;
7+
use std::sync::Arc;
78
use std::time::Duration;
89

910
use const_format::formatcp;
1011
use either::{Either, Left, Right};
1112
use ignore_result::Ignore;
13+
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
14+
use rustls::{ClientConfig, RootCertStore};
1215
use thiserror::Error;
1316
use tokio::sync::{mpsc, watch};
1417

@@ -1528,6 +1531,9 @@ pub(crate) struct Version(u32, u32, u32);
15281531
/// Builder for [Client] with more options than [Client::connect].
15291532
#[derive(Clone, Debug)]
15301533
pub struct ClientBuilder {
1534+
tls: bool,
1535+
trusted_certs: RootCertStore,
1536+
client_certs: Option<(Vec<CertificateDer<'static>>, Arc<PrivateKeyDer<'static>>)>,
15311537
authes: Vec<AuthPacket>,
15321538
version: Version,
15331539
session: Option<(SessionId, Vec<u8>)>,
@@ -1540,6 +1546,9 @@ pub struct ClientBuilder {
15401546
impl ClientBuilder {
15411547
fn new() -> Self {
15421548
Self {
1549+
tls: false,
1550+
trusted_certs: RootCertStore::empty(),
1551+
client_certs: None,
15431552
authes: Default::default(),
15441553
version: Version(u32::MAX, u32::MAX, u32::MAX),
15451554
session: None,
@@ -1584,6 +1593,43 @@ impl ClientBuilder {
15841593
self
15851594
}
15861595

1596+
/// Assumes tls for server in connection string if no protocol specified individually.
1597+
/// See [Self::connect] for syntax to specify protocol individually.
1598+
pub fn assume_tls(&mut self) -> &mut Self {
1599+
self.tls = true;
1600+
self
1601+
}
1602+
1603+
/// Trusts certificates signed by given ca certificates.
1604+
pub fn trust_ca_pem_certs(&mut self, certs: &str) -> Result<&mut Self> {
1605+
for r in rustls_pemfile::certs(&mut certs.as_bytes()) {
1606+
let cert = match r {
1607+
Ok(cert) => cert,
1608+
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
1609+
};
1610+
if let Err(err) = self.trusted_certs.add(cert) {
1611+
return Err(Error::other(format!("fail to add cert {}", err), err));
1612+
}
1613+
}
1614+
Ok(self)
1615+
}
1616+
1617+
/// Identifies client itself to server with given cert chain and private key.
1618+
pub fn use_client_pem_cert(&mut self, cert: &str, key: &str) -> Result<&mut Self> {
1619+
let r: std::result::Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
1620+
let certs = match r {
1621+
Err(err) => return Err(Error::other(format!("fail to read cert {}", err), err)),
1622+
Ok(certs) => certs,
1623+
};
1624+
let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
1625+
Err(err) => return Err(Error::other(format!("fail to read client private key {err}"), err)),
1626+
Ok(None) => return Err(Error::BadArguments(&"no client private key")),
1627+
Ok(Some(key)) => key,
1628+
};
1629+
self.client_certs = Some((certs, Arc::new(key)));
1630+
Ok(self)
1631+
}
1632+
15871633
/// Specifies client assumed server version of ZooKeeper cluster.
15881634
///
15891635
/// Client will issue server compatible protocol to avoid [Error::Unimplemented] for some
@@ -1606,13 +1652,17 @@ impl ClientBuilder {
16061652

16071653
/// Connects to ZooKeeper cluster.
16081654
///
1655+
/// Parameter `cluster` specifies connection string to ZooKeeper cluster. It has same syntax as
1656+
/// Java client except that you can specifies protocol for server individually. For example,
1657+
/// `tcp://server1,tcp+tls://server2:port,server3`. This claims that `server1` uses plaintext
1658+
/// protocol, `server2` uses tls encrypted protocol while `server3` uses tls if
1659+
/// [Self::assume_tls] is specified or plaintext otherwise.
1660+
///
16091661
/// # Notable errors
16101662
/// * [Error::NoHosts] if no host is available
16111663
/// * [Error::SessionExpired] if specified session expired
16121664
pub async fn connect(&mut self, cluster: &str) -> Result<Client> {
1613-
let (hosts, chroot) = util::parse_connect_string(cluster)?;
1614-
let mut buf = Vec::with_capacity(4096);
1615-
let mut connecting_depot = Depot::for_connecting();
1665+
let (hosts, chroot) = util::parse_connect_string(cluster, self.tls)?;
16161666
if let Some((id, password)) = &self.session {
16171667
if id.0 == 0 {
16181668
return Err(Error::BadArguments(&"session id must not be 0"));
@@ -1628,22 +1678,39 @@ impl ClientBuilder {
16281678
} else if self.connection_timeout < Duration::ZERO {
16291679
return Err(Error::BadArguments(&"connection timeout must not be negative"));
16301680
}
1681+
self.trusted_certs.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
1682+
let tls_config = if let Some((certs, private_key)) = self.client_certs.take() {
1683+
match ClientConfig::builder()
1684+
.with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty()))
1685+
.with_client_auth_cert(certs, Arc::try_unwrap(private_key).unwrap_or_else(|k| k.clone_key()))
1686+
{
1687+
Ok(config) => config,
1688+
Err(err) => return Err(Error::other(format!("invalid client private key {err}"), err)),
1689+
}
1690+
} else {
1691+
ClientConfig::builder()
1692+
.with_root_certificates(std::mem::replace(&mut self.trusted_certs, RootCertStore::empty()))
1693+
.with_no_client_auth()
1694+
};
16311695
let (mut session, state_receiver) = Session::new(
16321696
self.session.take(),
16331697
&self.authes,
16341698
self.readonly,
16351699
self.detached,
1700+
tls_config,
16361701
self.session_timeout,
16371702
self.connection_timeout,
16381703
);
16391704
let mut hosts_iter = hosts.iter().copied();
1640-
let sock = session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await?;
1705+
let mut buf = Vec::with_capacity(4096);
1706+
let mut connecting_depot = Depot::for_connecting();
1707+
let conn = session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await?;
16411708
let (sender, receiver) = mpsc::unbounded_channel();
16421709
let servers = hosts.into_iter().map(|addr| addr.to_value()).collect();
16431710
let session_info = (session.session_id, session.session_password.clone());
16441711
let session_timeout = session.session_timeout;
16451712
tokio::spawn(async move {
1646-
session.serve(servers, sock, buf, connecting_depot, receiver).await;
1713+
session.serve(servers, conn, buf, connecting_depot, receiver).await;
16471714
});
16481715
let client =
16491716
Client::new(chroot.to_owned(), self.version, session_info, session_timeout, sender, state_receiver);

src/error.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::sync::Arc;
2+
3+
use derive_where::derive_where;
14
use static_assertions::assert_impl_all;
25
use thiserror::Error;
36

@@ -82,6 +85,18 @@ pub enum Error {
8285

8386
#[error("runtime condition mismatch")]
8487
RuntimeInconsistent,
88+
89+
#[error(transparent)]
90+
Other(OtherError),
91+
}
92+
93+
#[derive(Error, Clone, Debug)]
94+
#[derive_where(Eq, PartialEq)]
95+
#[error("{message}")]
96+
pub struct OtherError {
97+
message: Arc<String>,
98+
#[derive_where(skip(EqHashOrd))]
99+
source: Option<Arc<dyn std::error::Error + Send + Sync + 'static>>,
85100
}
86101

87102
impl Error {
@@ -111,6 +126,10 @@ impl Error {
111126
_ => false,
112127
}
113128
}
129+
130+
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
131+
Self::Other(OtherError { message: Arc::new(message.into()), source: Some(Arc::new(source)) })
132+
}
114133
}
115134

116135
assert_impl_all!(Error: Send, Sync);

src/session/connection.rs

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
use std::io::{self, Read, Write};
2+
use std::sync::Arc;
3+
4+
use rustls::pki_types::ServerName;
5+
use rustls::{ClientConfig, ClientConnection};
6+
use tokio::net::TcpStream;
7+
8+
use crate::error::Error;
9+
10+
pub struct Connection {
11+
tls: Option<ClientConnection>,
12+
stream: TcpStream,
13+
}
14+
15+
struct WrappingStream<'a> {
16+
stream: &'a TcpStream,
17+
}
18+
19+
impl io::Read for WrappingStream<'_> {
20+
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
21+
self.stream.try_read_buf(&mut buf)
22+
}
23+
}
24+
25+
impl io::Write for WrappingStream<'_> {
26+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
27+
self.stream.try_write(buf)
28+
}
29+
30+
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
31+
self.stream.try_write_vectored(bufs)
32+
}
33+
34+
fn flush(&mut self) -> io::Result<()> {
35+
Ok(())
36+
}
37+
}
38+
39+
impl Connection {
40+
pub fn new_raw(stream: TcpStream) -> Self {
41+
Self { tls: None, stream }
42+
}
43+
44+
pub fn new_tls(host: &str, config: Arc<ClientConfig>, stream: TcpStream) -> Result<Self, Error> {
45+
let name = match ServerName::try_from(host) {
46+
Err(_) => return Err(Error::BadArguments(&"invalid server dns name")),
47+
Ok(name) => name.to_owned(),
48+
};
49+
let client = match ClientConnection::new(config, name) {
50+
Err(err) => return Err(Error::other(format!("fail to create tls client for host({host}): {err}"), err)),
51+
Ok(client) => client,
52+
};
53+
Ok(Self { tls: Some(client), stream })
54+
}
55+
56+
pub fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
57+
let Some(client) = self.tls.as_mut() else {
58+
return self.stream.try_write_vectored(bufs);
59+
};
60+
let n = client.writer().write_vectored(bufs)?;
61+
let mut stream = WrappingStream { stream: &self.stream };
62+
client.write_tls(&mut stream)?;
63+
Ok(n)
64+
}
65+
66+
pub fn read_buf(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
67+
let Some(client) = self.tls.as_mut() else {
68+
return self.stream.try_read_buf(buf);
69+
};
70+
let mut stream = WrappingStream { stream: &self.stream };
71+
let mut read_bytes = 0;
72+
loop {
73+
match client.read_tls(&mut stream) {
74+
// We may have plaintext to return though tcp stream has been closed.
75+
// If not, read_bytes should be zero.
76+
Ok(0) => break,
77+
Ok(_) => {},
78+
Err(err) => match err.kind() {
79+
// backpressure: tls buffer is full, let's process_new_packets.
80+
io::ErrorKind::Other => {},
81+
io::ErrorKind::WouldBlock if read_bytes == 0 => {
82+
return Err(err);
83+
},
84+
_ => break,
85+
},
86+
}
87+
let state = client.process_new_packets().map_err(io::Error::other)?;
88+
let n = state.plaintext_bytes_to_read();
89+
buf.reserve(n);
90+
let slice = unsafe { &mut std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len() + n)[buf.len()..] };
91+
client.reader().read_exact(slice).unwrap();
92+
unsafe { buf.set_len(buf.len() + n) };
93+
read_bytes += n;
94+
}
95+
Ok(read_bytes)
96+
}
97+
98+
pub async fn readable(&self) -> io::Result<()> {
99+
let Some(client) = self.tls.as_ref() else {
100+
return self.stream.readable().await;
101+
};
102+
if client.wants_read() {
103+
self.stream.readable().await
104+
} else {
105+
// plaintext data are available for read
106+
std::future::ready(Ok(())).await
107+
}
108+
}
109+
110+
pub async fn writable(&self) -> io::Result<()> {
111+
self.stream.writable().await
112+
}
113+
114+
pub fn wants_write(&self) -> bool {
115+
self.tls.as_ref().map(|tls| tls.wants_write()).unwrap_or(false)
116+
}
117+
118+
pub fn flush(&mut self) -> io::Result<()> {
119+
let Some(client) = self.tls.as_mut() else {
120+
return Ok(());
121+
};
122+
let mut stream = WrappingStream { stream: &self.stream };
123+
while client.wants_write() {
124+
client.write_tls(&mut stream)?;
125+
}
126+
Ok(())
127+
}
128+
}

src/session/depot.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::io::{self, IoSlice};
33

44
use hashbrown::HashMap;
55
use strum::IntoEnumIterator;
6-
use tokio::net::TcpStream;
76

7+
use super::connection::Connection;
88
use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};
99
use super::types::WatchMode;
1010
use super::xid::Xid;
@@ -205,8 +205,18 @@ impl Depot {
205205
.any(|mode| self.watching_paths.contains_key(&(path, mode)))
206206
}
207207

208-
pub fn write_operations(&mut self, sock: &TcpStream, session_id: SessionId) -> Result<(), Error> {
209-
let result = sock.try_write_vectored(self.writing_slices.as_slice());
208+
pub fn write_operations(&mut self, conn: &mut Connection, session_id: SessionId) -> Result<(), Error> {
209+
if !self.has_pending_writes() {
210+
if let Err(err) = conn.flush() {
211+
if err.kind() == io::ErrorKind::WouldBlock {
212+
return Ok(());
213+
}
214+
log::debug!("ZooKeeper session {} write failed {}", session_id, err);
215+
return Err(Error::ConnectionLoss);
216+
}
217+
return Ok(());
218+
}
219+
let result = conn.write_vectored(self.writing_slices.as_slice());
210220
let mut written_bytes = match result {
211221
Err(err) => {
212222
if err.kind() == io::ErrorKind::WouldBlock {

0 commit comments

Comments
 (0)