Skip to content

Commit 8657d58

Browse files
committed
refactor: use noop waker for no blocking Connection io
1 parent 373d21a commit 8657d58

File tree

4 files changed

+108
-96
lines changed

4 files changed

+108
-96
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ rustls = "0.23.2"
3333
rustls-pemfile = "2"
3434
webpki-roots = "0.26.1"
3535
derive-where = "1.2.7"
36+
tokio-rustls = "0.26.0"
3637

3738
[dev-dependencies]
3839
test-log = "0.2.12"

src/session/connection.rs

+91-90
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,129 @@
1-
use std::io::{self, Read, Write};
2-
use std::sync::Arc;
1+
use std::io::{ErrorKind, IoSlice, Result};
2+
use std::pin::Pin;
3+
use std::ptr;
4+
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
35

4-
use rustls::pki_types::ServerName;
5-
use rustls::{ClientConfig, ClientConnection};
6+
use bytes::buf::BufMut;
7+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
68
use tokio::net::TcpStream;
9+
use tokio_rustls::client::TlsStream;
710

8-
use crate::error::Error;
11+
const NOOP_VTABLE: RawWakerVTable =
12+
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
13+
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);
914

10-
pub struct Connection {
11-
tls: Option<ClientConnection>,
12-
stream: TcpStream,
15+
pub enum Connection {
16+
Tls(TlsStream<TcpStream>),
17+
Raw(TcpStream),
1318
}
1419

15-
struct WrappingStream<'a> {
16-
stream: &'a TcpStream,
17-
}
18-
19-
impl io::Read for WrappingStream<'_> {
20-
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
21-
self.stream.try_read_buf(&mut buf)
20+
impl AsyncRead for Connection {
21+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
22+
match self.get_mut() {
23+
Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf),
24+
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
25+
}
2226
}
2327
}
2428

25-
impl io::Write for WrappingStream<'_> {
26-
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
27-
self.stream.try_write(buf)
29+
impl AsyncWrite for Connection {
30+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
31+
match self.get_mut() {
32+
Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf),
33+
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
34+
}
2835
}
2936

30-
fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
31-
self.stream.try_write_vectored(bufs)
37+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
38+
match self.get_mut() {
39+
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
40+
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
41+
}
3242
}
3343

34-
fn flush(&mut self) -> io::Result<()> {
35-
Ok(())
44+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
45+
match self.get_mut() {
46+
Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx),
47+
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
48+
}
3649
}
3750
}
3851

3952
impl Connection {
4053
pub fn new_raw(stream: TcpStream) -> Self {
41-
Self { tls: None, stream }
54+
Self::Raw(stream)
4255
}
4356

44-
pub fn new_tls(host: &str, config: Arc<ClientConfig>, stream: TcpStream) -> Result<Self, Error> {
45-
let name = match ServerName::try_from(host) {
46-
Err(_) => return Err(Error::BadArguments(&"invalid server dns name")),
47-
Ok(name) => name.to_owned(),
48-
};
49-
let client = match ClientConnection::new(config, name) {
50-
Err(err) => return Err(Error::other(format!("fail to create tls client for host({host}): {err}"), err)),
51-
Ok(client) => client,
52-
};
53-
Ok(Self { tls: Some(client), stream })
57+
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
58+
Self::Tls(stream)
5459
}
5560

56-
pub fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
57-
let Some(client) = self.tls.as_mut() else {
58-
return self.stream.try_write_vectored(bufs);
59-
};
60-
let n = client.writer().write_vectored(bufs)?;
61-
let mut stream = WrappingStream { stream: &self.stream };
62-
client.write_tls(&mut stream)?;
63-
Ok(n)
61+
pub fn try_write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
62+
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
63+
let mut context = Context::from_waker(&waker);
64+
match Pin::new(self).poll_write_vectored(&mut context, bufs) {
65+
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
66+
Poll::Ready(result) => result,
67+
}
6468
}
6569

66-
pub fn read_buf(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
67-
let Some(client) = self.tls.as_mut() else {
68-
return self.stream.try_read_buf(buf);
69-
};
70-
let mut stream = WrappingStream { stream: &self.stream };
71-
let mut read_bytes = 0;
72-
loop {
73-
match client.read_tls(&mut stream) {
74-
// We may have plaintext to return though tcp stream has been closed.
75-
// If not, read_bytes should be zero.
76-
Ok(0) => break,
77-
Ok(_) => {},
78-
Err(err) => match err.kind() {
79-
// backpressure: tls buffer is full, let's process_new_packets.
80-
io::ErrorKind::Other => {},
81-
io::ErrorKind::WouldBlock if read_bytes == 0 => {
82-
return Err(err);
83-
},
84-
_ => break,
85-
},
86-
}
87-
let state = client.process_new_packets().map_err(io::Error::other)?;
88-
let n = state.plaintext_bytes_to_read();
89-
buf.reserve(n);
90-
let slice = unsafe { &mut std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len() + n)[buf.len()..] };
91-
client.reader().read_exact(slice).unwrap();
92-
unsafe { buf.set_len(buf.len() + n) };
93-
read_bytes += n;
70+
pub fn try_read_buf(&mut self, buf: &mut impl BufMut) -> Result<usize> {
71+
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
72+
let mut context = Context::from_waker(&waker);
73+
let chunk = buf.chunk_mut();
74+
let mut read_buf = unsafe { ReadBuf::uninit(chunk.as_uninit_slice_mut()) };
75+
match Pin::new(self).poll_read(&mut context, &mut read_buf) {
76+
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
77+
Poll::Ready(Err(err)) => Err(err),
78+
Poll::Ready(Ok(())) => {
79+
let n = read_buf.filled().len();
80+
unsafe { buf.advance_mut(n) };
81+
Ok(n)
82+
},
9483
}
95-
Ok(read_bytes)
9684
}
9785

98-
pub async fn readable(&self) -> io::Result<()> {
99-
let Some(client) = self.tls.as_ref() else {
100-
return self.stream.readable().await;
101-
};
102-
if client.wants_read() {
103-
self.stream.readable().await
104-
} else {
105-
// plaintext data are available for read
106-
std::future::ready(Ok(())).await
86+
pub async fn readable(&self) -> Result<()> {
87+
match self {
88+
Self::Raw(stream) => stream.readable().await,
89+
Self::Tls(stream) => {
90+
let (stream, session) = stream.get_ref();
91+
if session.wants_read() {
92+
stream.readable().await
93+
} else {
94+
// plaintext data are available for read
95+
std::future::ready(Ok(())).await
96+
}
97+
},
10798
}
10899
}
109100

110-
pub async fn writable(&self) -> io::Result<()> {
111-
self.stream.writable().await
101+
pub async fn writable(&self) -> Result<()> {
102+
match self {
103+
Self::Raw(stream) => stream.writable().await,
104+
Self::Tls(stream) => {
105+
let (stream, _session) = stream.get_ref();
106+
stream.writable().await
107+
},
108+
}
112109
}
113110

114111
pub fn wants_write(&self) -> bool {
115-
self.tls.as_ref().map(|tls| tls.wants_write()).unwrap_or(false)
112+
match self {
113+
Self::Raw(_) => false,
114+
Self::Tls(stream) => {
115+
let (_stream, session) = stream.get_ref();
116+
session.wants_write()
117+
},
118+
}
116119
}
117120

118-
pub fn flush(&mut self) -> io::Result<()> {
119-
let Some(client) = self.tls.as_mut() else {
120-
return Ok(());
121-
};
122-
let mut stream = WrappingStream { stream: &self.stream };
123-
while client.wants_write() {
124-
client.write_tls(&mut stream)?;
121+
pub fn try_flush(&mut self) -> Result<()> {
122+
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
123+
let mut context = Context::from_waker(&waker);
124+
match Pin::new(self).poll_flush(&mut context) {
125+
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
126+
Poll::Ready(result) => result,
125127
}
126-
Ok(())
127128
}
128129
}

src/session/depot.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,15 @@ impl Depot {
206206

207207
pub fn write_operations(&mut self, conn: &mut Connection) -> Result<(), Error> {
208208
if !self.has_pending_writes() {
209-
if let Err(err) = conn.flush() {
209+
if let Err(err) = conn.try_flush() {
210210
if err.kind() == io::ErrorKind::WouldBlock {
211211
return Ok(());
212212
}
213213
return Err(Error::other_from(err));
214214
}
215215
return Ok(());
216216
}
217-
let result = conn.write_vectored(self.writing_slices.as_slice());
217+
let result = conn.try_write_vectored(self.writing_slices.as_slice());
218218
let mut written_bytes = match result {
219219
Err(err) => {
220220
if err.kind() == io::ErrorKind::WouldBlock {

src/session/mod.rs

+14-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ use std::sync::Arc;
1111
use std::time::Duration;
1212

1313
use ignore_result::Ignore;
14+
use rustls::pki_types::ServerName;
1415
use rustls::ClientConfig;
1516
use tokio::net::TcpStream;
1617
use tokio::select;
1718
use tokio::sync::mpsc;
1819
use tokio::time::{self, Instant, Sleep};
20+
use tokio_rustls::TlsConnector;
1921

2022
use self::connection::Connection;
2123
pub use self::depot::Depot;
@@ -60,7 +62,7 @@ pub struct Session {
6062
readonly: bool,
6163
detached: bool,
6264

63-
tls_config: Arc<ClientConfig>,
65+
tls_connector: TlsConnector,
6466

6567
configured_connection_timeout: Duration,
6668

@@ -104,7 +106,7 @@ impl Session {
104106
let mut session = Session {
105107
readonly,
106108
detached,
107-
tls_config: Arc::new(tls_config),
109+
tls_connector: TlsConnector::from(Arc::new(tls_config)),
108110

109111
configured_connection_timeout: connection_timeout,
110112

@@ -367,7 +369,7 @@ impl Session {
367369
}
368370

369371
fn read_connection(&mut self, conn: &mut Connection, buf: &mut Vec<u8>) -> Result<(), Error> {
370-
match conn.read_buf(buf) {
372+
match conn.try_read_buf(buf) {
371373
Ok(0) => {
372374
return Err(Error::ConnectionLoss);
373375
},
@@ -505,7 +507,15 @@ impl Session {
505507
},
506508
Ok(sock) => {
507509
let connection = if tls {
508-
Connection::new_tls(host, self.tls_config.clone(), sock)?
510+
let domain = ServerName::try_from(host).unwrap().to_owned();
511+
let stream = match self.tls_connector.connect(domain, sock).await {
512+
Err(err) => {
513+
log::debug!("ZooKeeper fails to complete tls session to {}:{} due to {}", host, port, err);
514+
return Err(Error::ConnectionLoss);
515+
},
516+
Ok(stream) => stream,
517+
};
518+
Connection::new_tls(stream)
509519
} else {
510520
Connection::new_raw(sock)
511521
};

0 commit comments

Comments
 (0)