Skip to content

Commit 8c2056d

Browse files
committed
refactor: toggle TLS support with feature gate
1 parent b68c967 commit 8c2056d

File tree

7 files changed

+269
-109
lines changed

7 files changed

+269
-109
lines changed

Cargo.toml

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ rust-version = "1.65"
1616

1717
[features]
1818
default = []
19+
tls = ["rustls", "rustls-pemfile", "webpki-roots"]
1920
sasl = ["sasl-gssapi", "sasl-digest-md5"]
2021
sasl-digest-md5 = ["rsasl/unstable_custom_mechanism", "md5", "linkme", "hex"]
2122
sasl-gssapi = ["rsasl/gssapi"]
@@ -34,9 +35,9 @@ hashbrown = "0.12.0"
3435
hashlink = "0.8.0"
3536
either = "1.9.0"
3637
uuid = { version = "1.4.1", features = ["v4"] }
37-
rustls = "0.23.2"
38-
rustls-pemfile = "2"
39-
webpki-roots = "0.26.1"
38+
rustls = { version = "0.23.2", optional = true }
39+
rustls-pemfile = { version = "2", optional = true }
40+
webpki-roots = { version = "0.26.1", optional = true }
4041
derive-where = "1.2.7"
4142
tokio-rustls = "0.26.0"
4243
fastrand = "2.0.2"

src/client/mod.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use crate::record::{self, Record, StaticRecord};
4848
use crate::sasl::SaslOptions;
4949
use crate::session::StateReceiver;
5050
pub use crate::session::{EventType, SessionId, SessionInfo, SessionState, WatchedEvent};
51+
#[cfg(feature = "tls")]
5152
use crate::tls::TlsOptions;
5253
use crate::util;
5354

@@ -1538,6 +1539,7 @@ pub(crate) struct Version(u32, u32, u32);
15381539
/// A builder for [Client].
15391540
#[derive(Clone, Debug)]
15401541
pub struct Connector {
1542+
#[cfg(feature = "tls")]
15411543
tls: Option<TlsOptions>,
15421544
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
15431545
sasl: Option<SaslOptions>,
@@ -1555,6 +1557,7 @@ pub struct Connector {
15551557
impl Connector {
15561558
fn new() -> Self {
15571559
Self {
1560+
#[cfg(feature = "tls")]
15581561
tls: None,
15591562
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
15601563
sasl: None,
@@ -1624,6 +1627,7 @@ impl Connector {
16241627
}
16251628

16261629
/// Specifies tls options for connections to ZooKeeper.
1630+
#[cfg(feature = "tls")]
16271631
pub fn tls(&mut self, options: TlsOptions) -> &mut Self {
16281632
self.tls = Some(options);
16291633
self
@@ -1649,13 +1653,14 @@ impl Connector {
16491653
async fn connect_internally(&mut self, secure: bool, cluster: &str) -> Result<Client> {
16501654
let (endpoints, chroot) = endpoint::parse_connect_string(cluster, secure)?;
16511655
let builder = Session::builder()
1652-
.with_tls(self.tls.take())
16531656
.with_session(self.session.take())
16541657
.with_authes(&self.authes)
16551658
.with_readonly(self.readonly)
16561659
.with_detached(self.detached)
16571660
.with_session_timeout(self.session_timeout)
16581661
.with_connection_timeout(self.connection_timeout);
1662+
#[cfg(feature = "tls")]
1663+
let builder = builder.with_tls(self.tls.take());
16591664
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
16601665
let builder = builder.with_sasl(self.sasl.take());
16611666
let (mut session, state_receiver) = builder.build()?;
@@ -1685,6 +1690,7 @@ impl Connector {
16851690
///
16861691
/// Same to [Self::connect] except that `server1` will use tls encrypted protocol given
16871692
/// the connection string `server1,tcp://server2,tcp+tls://server3`.
1693+
#[cfg(feature = "tls")]
16881694
pub async fn secure_connect(&mut self, cluster: &str) -> Result<Client> {
16891695
self.connect_internally(true, cluster).await
16901696
}

src/error.rs

+1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ impl Error {
138138
Self::Other(OtherError { message: message.into(), source })
139139
}
140140

141+
#[allow(dead_code)]
141142
pub(crate) fn other(message: impl Into<String>, source: impl std::error::Error + Send + Sync + 'static) -> Self {
142143
Self::new_other(message.into(), Some(Arc::new(source)))
143144
}

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ mod record;
99
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
1010
mod sasl;
1111
mod session;
12+
#[cfg(feature = "tls")]
1213
mod tls;
1314
mod util;
1415

1516
pub use self::acl::{Acl, Acls, AuthId, AuthUser, Permission};
1617
pub use self::error::Error;
18+
#[cfg(feature = "tls")]
1719
pub use self::tls::TlsOptions;
1820
pub use crate::client::*;
1921
#[cfg(feature = "sasl-digest-md5")]

src/session/connection.rs

+75-12
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,47 @@
11
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
33
use std::ptr;
4-
use std::sync::Arc;
54
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
65
use std::time::Duration;
76

87
use bytes::buf::BufMut;
98
use ignore_result::Ignore;
10-
use rustls::pki_types::ServerName;
11-
use rustls::ClientConfig;
129
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
1310
use tokio::net::TcpStream;
1411
use tokio::{select, time};
15-
use tokio_rustls::client::TlsStream;
16-
use tokio_rustls::TlsConnector;
1712
use tracing::{debug, trace};
1813

14+
#[cfg(feature = "tls")]
15+
mod tls {
16+
pub use std::sync::Arc;
17+
18+
pub use rustls::pki_types::ServerName;
19+
pub use rustls::ClientConfig;
20+
pub use tokio_rustls::client::TlsStream;
21+
pub use tokio_rustls::TlsConnector;
22+
}
23+
#[cfg(feature = "tls")]
24+
use tls::*;
25+
1926
use crate::deadline::Deadline;
2027
use crate::endpoint::{EndpointRef, IterableEndpoints};
2128

2229
const NOOP_VTABLE: RawWakerVTable =
2330
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
2431
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);
2532

33+
#[derive(Debug)]
2634
pub enum Connection {
27-
Tls(TlsStream<TcpStream>),
2835
Raw(TcpStream),
36+
#[cfg(feature = "tls")]
37+
Tls(TlsStream<TcpStream>),
2938
}
3039

3140
impl AsyncRead for Connection {
3241
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
3342
match self.get_mut() {
3443
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
44+
#[cfg(feature = "tls")]
3545
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
3646
}
3747
}
@@ -41,20 +51,23 @@ impl AsyncWrite for Connection {
4151
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
4252
match self.get_mut() {
4353
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
54+
#[cfg(feature = "tls")]
4455
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
4556
}
4657
}
4758

4859
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
4960
match self.get_mut() {
5061
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
62+
#[cfg(feature = "tls")]
5163
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
5264
}
5365
}
5466

5567
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
5668
match self.get_mut() {
5769
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
70+
#[cfg(feature = "tls")]
5871
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
5972
}
6073
}
@@ -65,6 +78,7 @@ impl Connection {
6578
Self::Raw(stream)
6679
}
6780

81+
#[cfg(feature = "tls")]
6882
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
6983
Self::Tls(stream)
7084
}
@@ -97,6 +111,7 @@ impl Connection {
97111
pub async fn readable(&self) -> Result<()> {
98112
match self {
99113
Self::Raw(stream) => stream.readable().await,
114+
#[cfg(feature = "tls")]
100115
Self::Tls(stream) => {
101116
let (stream, session) = stream.get_ref();
102117
if session.wants_read() {
@@ -112,6 +127,7 @@ impl Connection {
112127
pub async fn writable(&self) -> Result<()> {
113128
match self {
114129
Self::Raw(stream) => stream.writable().await,
130+
#[cfg(feature = "tls")]
115131
Self::Tls(stream) => {
116132
let (stream, _session) = stream.get_ref();
117133
stream.writable().await
@@ -122,6 +138,7 @@ impl Connection {
122138
pub fn wants_write(&self) -> bool {
123139
match self {
124140
Self::Raw(_) => false,
141+
#[cfg(feature = "tls")]
125142
Self::Tls(stream) => {
126143
let (_stream, session) = stream.get_ref();
127144
session.wants_write()
@@ -160,13 +177,33 @@ impl Connection {
160177

161178
#[derive(Clone)]
162179
pub struct Connector {
163-
tls: TlsConnector,
180+
#[cfg(feature = "tls")]
181+
tls: Option<TlsConnector>,
164182
timeout: Duration,
165183
}
166184

167185
impl Connector {
168-
pub fn new(config: impl Into<Arc<ClientConfig>>) -> Self {
169-
Self { tls: TlsConnector::from(config.into()), timeout: Duration::from_secs(10) }
186+
#[cfg(feature = "tls")]
187+
#[allow(dead_code)]
188+
pub fn new() -> Self {
189+
Self { tls: None, timeout: Duration::from_secs(10) }
190+
}
191+
192+
#[cfg(not(feature = "tls"))]
193+
pub fn new() -> Self {
194+
Self { timeout: Duration::from_secs(10) }
195+
}
196+
197+
#[cfg(feature = "tls")]
198+
pub fn with_tls(config: ClientConfig) -> Self {
199+
Self { tls: Some(TlsConnector::from(Arc::new(config))), timeout: Duration::from_secs(10) }
200+
}
201+
202+
#[cfg(feature = "tls")]
203+
async fn connect_tls(&self, stream: TcpStream, host: &str) -> Result<Connection> {
204+
let domain = ServerName::try_from(host).unwrap().to_owned();
205+
let stream = self.tls.as_ref().unwrap().connect(domain, stream).await?;
206+
Ok(Connection::new_tls(stream))
170207
}
171208

172209
pub fn timeout(&self) -> Duration {
@@ -178,6 +215,14 @@ impl Connector {
178215
}
179216

180217
pub async fn connect(&self, endpoint: EndpointRef<'_>, deadline: &mut Deadline) -> Result<Connection> {
218+
if endpoint.tls {
219+
#[cfg(feature = "tls")]
220+
if self.tls.is_none() {
221+
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
222+
}
223+
#[cfg(not(feature = "tls"))]
224+
return Err(Error::new(ErrorKind::Unsupported, "tls not supported"));
225+
}
181226
select! {
182227
_ = unsafe { Pin::new_unchecked(deadline) } => Err(Error::new(ErrorKind::TimedOut, "deadline exceed")),
183228
_ = time::sleep(self.timeout) => Err(Error::new(ErrorKind::TimedOut, format!("connection timeout{:?} exceed", self.timeout))),
@@ -186,9 +231,10 @@ impl Connector {
186231
Err(err) => Err(err),
187232
Ok(sock) => {
188233
let connection = if endpoint.tls {
189-
let domain = ServerName::try_from(endpoint.host).unwrap().to_owned();
190-
let stream = self.tls.connect(domain, sock).await?;
191-
Connection::new_tls(stream)
234+
#[cfg(not(feature = "tls"))]
235+
unreachable!("tls not supported");
236+
#[cfg(feature = "tls")]
237+
self.connect_tls(sock, endpoint.host).await?
192238
} else {
193239
Connection::new_raw(sock)
194240
};
@@ -231,3 +277,20 @@ impl Connector {
231277
None
232278
}
233279
}
280+
281+
#[cfg(test)]
282+
mod tests {
283+
use std::io::ErrorKind;
284+
285+
use super::Connector;
286+
use crate::deadline::Deadline;
287+
use crate::endpoint::EndpointRef;
288+
289+
#[tokio::test]
290+
async fn raw() {
291+
let connector = Connector::new();
292+
let endpoint = EndpointRef::new("host1", 2181, true);
293+
let err = connector.connect(endpoint, &mut Deadline::never()).await.unwrap_err();
294+
assert_eq!(err.kind(), ErrorKind::Unsupported);
295+
}
296+
}

src/session/mod.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use crate::proto::{AuthPacket, ConnectRequest, ConnectResponse, ErrorCode, OpCod
3838
use crate::record;
3939
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
4040
use crate::sasl::{SaslInitiator, SaslOptions, SaslSession};
41+
#[cfg(feature = "tls")]
4142
use crate::tls::TlsOptions;
4243

4344
pub const PASSWORD_LEN: usize = 16;
@@ -61,6 +62,7 @@ impl RequestOperation for (WatcherId, StateResponser) {
6162

6263
#[derive(Default)]
6364
pub struct Builder {
65+
#[cfg(feature = "tls")]
6466
tls: Option<TlsOptions>,
6567
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
6668
sasl: Option<SaslOptions>,
@@ -73,6 +75,7 @@ pub struct Builder {
7375
}
7476

7577
impl Builder {
78+
#[cfg(feature = "tls")]
7679
pub fn with_tls(self, tls: Option<TlsOptions>) -> Self {
7780
Self { tls, ..self }
7881
}
@@ -125,7 +128,10 @@ impl Builder {
125128
} else if self.connection_timeout < Duration::ZERO {
126129
return Err(Error::BadArguments(&"connection timeout must not be negative"));
127130
}
128-
let tls_config = self.tls.unwrap_or_default().into_config()?;
131+
#[cfg(feature = "tls")]
132+
let connector = Connector::with_tls(self.tls.unwrap_or_default().into_config()?);
133+
#[cfg(not(feature = "tls"))]
134+
let connector = Connector::new();
129135
let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected);
130136
let now = Instant::now();
131137
let (watch_manager, unwatch_receiver) = WatchManager::new();
@@ -142,7 +148,7 @@ impl Builder {
142148
tick_timeout: Duration::ZERO,
143149
ping_timeout: Duration::ZERO,
144150
session_expired_timeout: Duration::ZERO,
145-
connector: Connector::new(tls_config),
151+
connector,
146152
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]
147153
sasl_options: self.sasl,
148154
#[cfg(any(feature = "sasl-digest-md5", feature = "sasl-gssapi"))]

0 commit comments

Comments
 (0)