Skip to content

Commit fbb8892

Browse files
committed
Split connection to reader and writer to multiplex them
changelog: changed
1 parent 03c7ba1 commit fbb8892

File tree

3 files changed

+81
-125
lines changed

3 files changed

+81
-125
lines changed

src/session/connection.rs

+60-73
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
use std::io::{Error, ErrorKind, IoSlice, Result};
22
use std::pin::Pin;
3-
use std::ptr;
4-
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
3+
use std::task::{Context, Poll};
54
use std::time::Duration;
65

76
use bytes::buf::BufMut;
87
use ignore_result::Ignore;
9-
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
8+
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufStream, ReadBuf};
109
use tokio::net::TcpStream;
1110
use tokio::{select, time};
1211
use tracing::{debug, trace};
@@ -26,17 +25,31 @@ use tls::*;
2625
use crate::deadline::Deadline;
2726
use crate::endpoint::{EndpointRef, IterableEndpoints};
2827

29-
const NOOP_VTABLE: RawWakerVTable =
30-
RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {});
31-
const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE);
32-
3328
#[derive(Debug)]
3429
pub enum Connection {
3530
Raw(TcpStream),
3631
#[cfg(feature = "tls")]
3732
Tls(TlsStream<TcpStream>),
3833
}
3934

35+
pub trait AsyncReadToBuf: AsyncReadExt {
36+
async fn read_to_buf(&mut self, buf: &mut impl BufMut) -> Result<usize>
37+
where
38+
Self: Unpin, {
39+
let chunk = buf.chunk_mut();
40+
let read_to = unsafe { std::mem::transmute(chunk.as_uninit_slice_mut()) };
41+
let n = self.read(read_to).await?;
42+
if n != 0 {
43+
unsafe {
44+
buf.advance_mut(n);
45+
}
46+
}
47+
Ok(n)
48+
}
49+
}
50+
51+
impl<T> AsyncReadToBuf for T where T: AsyncReadExt {}
52+
4053
impl AsyncRead for Connection {
4154
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
4255
match self.get_mut() {
@@ -56,6 +69,14 @@ impl AsyncWrite for Connection {
5669
}
5770
}
5871

72+
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
73+
match self.get_mut() {
74+
Self::Raw(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
75+
#[cfg(feature = "tls")]
76+
Self::Tls(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
77+
}
78+
}
79+
5980
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
6081
match self.get_mut() {
6182
Self::Raw(stream) => Pin::new(stream).poll_flush(cx),
@@ -73,86 +94,52 @@ impl AsyncWrite for Connection {
7394
}
7495
}
7596

76-
impl Connection {
77-
pub fn new_raw(stream: TcpStream) -> Self {
78-
Self::Raw(stream)
97+
pub struct ConnReader<'a> {
98+
conn: &'a mut Connection,
99+
}
100+
101+
impl AsyncRead for ConnReader<'_> {
102+
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
103+
Pin::new(&mut self.get_mut().conn).poll_read(cx, buf)
79104
}
105+
}
80106

81-
#[cfg(feature = "tls")]
82-
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
83-
Self::Tls(stream)
107+
pub struct ConnWriter<'a> {
108+
conn: &'a mut Connection,
109+
}
110+
111+
impl AsyncWrite for ConnWriter<'_> {
112+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
113+
Pin::new(&mut self.get_mut().conn).poll_write(cx, buf)
84114
}
85115

86-
pub fn try_write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> {
87-
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
88-
let mut context = Context::from_waker(&waker);
89-
match Pin::new(self).poll_write_vectored(&mut context, bufs) {
90-
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
91-
Poll::Ready(result) => result,
92-
}
116+
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize>> {
117+
Pin::new(&mut self.get_mut().conn).poll_write_vectored(cx, bufs)
93118
}
94119

95-
pub fn try_read_buf(&mut self, buf: &mut impl BufMut) -> Result<usize> {
96-
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
97-
let mut context = Context::from_waker(&waker);
98-
let chunk = buf.chunk_mut();
99-
let mut read_buf = unsafe { ReadBuf::uninit(chunk.as_uninit_slice_mut()) };
100-
match Pin::new(self).poll_read(&mut context, &mut read_buf) {
101-
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
102-
Poll::Ready(Err(err)) => Err(err),
103-
Poll::Ready(Ok(())) => {
104-
let n = read_buf.filled().len();
105-
unsafe { buf.advance_mut(n) };
106-
Ok(n)
107-
},
108-
}
120+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
121+
Pin::new(&mut self.get_mut().conn).poll_flush(cx)
109122
}
110123

111-
pub async fn readable(&self) -> Result<()> {
112-
match self {
113-
Self::Raw(stream) => stream.readable().await,
114-
#[cfg(feature = "tls")]
115-
Self::Tls(stream) => {
116-
let (stream, session) = stream.get_ref();
117-
if session.wants_read() {
118-
stream.readable().await
119-
} else {
120-
// plaintext data are available for read
121-
std::future::ready(Ok(())).await
122-
}
123-
},
124-
}
124+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
125+
Pin::new(&mut self.get_mut().conn).poll_shutdown(cx)
125126
}
127+
}
126128

127-
pub async fn writable(&self) -> Result<()> {
128-
match self {
129-
Self::Raw(stream) => stream.writable().await,
130-
#[cfg(feature = "tls")]
131-
Self::Tls(stream) => {
132-
let (stream, _session) = stream.get_ref();
133-
stream.writable().await
134-
},
135-
}
129+
impl Connection {
130+
pub fn new_raw(stream: TcpStream) -> Self {
131+
Self::Raw(stream)
136132
}
137133

138-
pub fn wants_write(&self) -> bool {
139-
match self {
140-
Self::Raw(_) => false,
141-
#[cfg(feature = "tls")]
142-
Self::Tls(stream) => {
143-
let (_stream, session) = stream.get_ref();
144-
session.wants_write()
145-
},
146-
}
134+
pub fn split(&mut self) -> (ConnReader<'_>, ConnWriter<'_>) {
135+
let reader = ConnReader { conn: self };
136+
let writer = ConnWriter { conn: unsafe { std::ptr::read(&reader.conn) } };
137+
(reader, writer)
147138
}
148139

149-
pub fn try_flush(&mut self) -> Result<()> {
150-
let waker = unsafe { Waker::from_raw(NOOP_WAKER) };
151-
let mut context = Context::from_waker(&waker);
152-
match Pin::new(self).poll_flush(&mut context) {
153-
Poll::Pending => Err(ErrorKind::WouldBlock.into()),
154-
Poll::Ready(result) => result,
155-
}
140+
#[cfg(feature = "tls")]
141+
pub fn new_tls(stream: TlsStream<TcpStream>) -> Self {
142+
Self::Tls(stream)
156143
}
157144

158145
pub async fn command(self, cmd: &str) -> Result<String> {

src/session/depot.rs

+5-20
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use std::collections::VecDeque;
2-
use std::io::{self, IoSlice};
2+
use std::io::IoSlice;
33

44
use hashbrown::HashMap;
55
use strum::IntoEnumIterator;
6+
use tokio::io::AsyncWriteExt;
67
use tracing::debug;
78

8-
use super::connection::Connection;
99
use super::request::{MarshalledRequest, OpStat, Operation, SessionOperation, StateResponser};
1010
use super::types::WatchMode;
1111
use super::xid::Xid;
@@ -229,26 +229,11 @@ impl Depot {
229229
.any(|mode| self.watching_paths.contains_key(&(path, mode)))
230230
}
231231

232-
pub fn write_operations(&mut self, conn: &mut Connection) -> Result<(), Error> {
232+
pub async fn write_to(&mut self, write: &mut (impl AsyncWriteExt + Unpin)) -> Result<(), Error> {
233233
if !self.has_pending_writes() {
234-
if let Err(err) = conn.try_flush() {
235-
if err.kind() == io::ErrorKind::WouldBlock {
236-
return Ok(());
237-
}
238-
return Err(Error::other(err));
239-
}
240-
return Ok(());
234+
return std::future::pending().await;
241235
}
242-
let result = conn.try_write_vectored(self.writing_slices.as_slice());
243-
let mut written_bytes = match result {
244-
Err(err) => {
245-
if err.kind() == io::ErrorKind::WouldBlock {
246-
return Ok(());
247-
}
248-
return Err(Error::other(err));
249-
},
250-
Ok(written_bytes) => written_bytes,
251-
};
236+
let mut written_bytes = write.write_vectored(self.writing_slices.as_slice()).await.map_err(Error::other)?;
252237
let written_slices = self
253238
.writing_slices
254239
.iter()

src/session/mod.rs

+16-32
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ mod types;
66
mod watch;
77
mod xid;
88

9-
use std::io;
109
use std::time::Duration;
1110

1211
use ignore_result::Ignore;
12+
use tokio::io::AsyncWriteExt;
1313
use tokio::select;
1414
use tokio::sync::mpsc;
1515
use tokio::time::{self, Instant};
1616
use tracing::field::display;
1717
use tracing::{debug, info, instrument, warn, Span};
1818

19-
use self::connection::{Connection, Connector};
19+
use self::connection::{AsyncReadToBuf, Connection, Connector};
2020
pub use self::depot::Depot;
2121
use self::event::WatcherEvent;
2222
pub use self::request::{
@@ -478,21 +478,6 @@ impl Session {
478478
Ok(())
479479
}
480480

481-
fn read_connection(&mut self, conn: &mut Connection, buf: &mut Vec<u8>) -> Result<(), Error> {
482-
match conn.try_read_buf(buf) {
483-
Ok(0) => {
484-
return Err(Error::ConnectionLoss);
485-
},
486-
Err(err) => {
487-
if err.kind() != io::ErrorKind::WouldBlock {
488-
return Err(Error::other(err));
489-
}
490-
},
491-
_ => {},
492-
}
493-
Ok(())
494-
}
495-
496481
fn handle_recv_buf(&mut self, recved: &mut Vec<u8>, depot: &mut Depot) -> Result<(), Error> {
497482
let mut reading = recved.as_slice();
498483
if self.session_state == SessionState::Disconnected {
@@ -522,14 +507,15 @@ impl Session {
522507
let mut pinged = false;
523508
let mut tick = time::interval(self.tick_timeout);
524509
tick.set_missed_tick_behavior(time::MissedTickBehavior::Skip);
510+
let (mut reader, mut writer) = conn.split();
525511
while !(self.session_state.is_connected() && depot.is_empty()) {
526512
select! {
527-
_ = conn.readable() => {
528-
self.read_connection(conn, buf)?;
529-
self.handle_recv_buf(buf, depot)?;
513+
r = reader.read_to_buf(buf) => match r.map_err(Error::other)? {
514+
0 => return Err(Error::ConnectionLoss),
515+
_ => self.handle_recv_buf(buf, depot)?,
530516
},
531-
_ = conn.writable(), if depot.has_pending_writes() || conn.wants_write() => {
532-
depot.write_operations(conn)?;
517+
r = depot.write_to(&mut writer) => {
518+
r?;
533519
self.last_send = Instant::now();
534520
},
535521
now = tick.tick() => {
@@ -543,7 +529,6 @@ impl Session {
543529
// "zookeeper.enforce.auth.enabled".
544530
pinged = true;
545531
self.send_ping(depot, Instant::now());
546-
depot.write_operations(conn)?;
547532
}
548533
}
549534
Ok(())
@@ -574,19 +559,20 @@ impl Session {
574559
let mut err = None;
575560
let mut channel_halted = false;
576561
depot.start();
577-
while !(channel_halted && depot.is_empty() && !conn.wants_write()) {
562+
let (mut reader, mut writer) = conn.split();
563+
while !(channel_halted && depot.is_empty()) {
578564
select! {
579565
Some(endpoint) = Self::poll(&mut seek_for_writable), if seek_for_writable.is_some() => {
580566
seek_for_writable = None;
581567
err = Some(Error::with_message(format!("encounter writable server {}", endpoint)));
582568
channel_halted = true;
583569
},
584-
_ = conn.readable() => {
585-
self.read_connection(conn, buf)?;
586-
self.handle_recv_buf(buf, depot)?;
570+
r = reader.read_to_buf(buf) => match r.map_err(Error::other)? {
571+
0 => return Err(Error::ConnectionLoss),
572+
_ => self.handle_recv_buf(buf, depot)?,
587573
},
588-
_ = conn.writable(), if depot.has_pending_writes() || conn.wants_write() => {
589-
depot.write_operations(conn)?;
574+
r = depot.write_to(&mut writer) => {
575+
r?;
590576
self.last_send = Instant::now();
591577
},
592578
r = requester.recv(), if !channel_halted => {
@@ -600,8 +586,6 @@ impl Session {
600586
continue;
601587
};
602588
depot.push_session(operation);
603-
depot.write_operations(conn)?;
604-
self.last_send = Instant::now();
605589
},
606590
r = unwatch_requester.recv() => if let Some((watcher_id, responser)) = r {
607591
self.watch_manager.remove_watcher(watcher_id, responser, depot);
@@ -612,11 +596,11 @@ impl Session {
612596
}
613597
if self.last_ping.is_none() && now >= self.last_send + self.ping_timeout {
614598
self.send_ping(depot, now);
615-
depot.write_operations(conn)?;
616599
}
617600
},
618601
}
619602
}
603+
writer.flush().await.map_err(Error::other)?;
620604
Err(err.unwrap_or(Error::ClientClosed))
621605
}
622606

0 commit comments

Comments
 (0)