|
1 |
| -use std::io::{self, Read, Write}; |
2 |
| -use std::sync::Arc; |
| 1 | +use std::io::{ErrorKind, IoSlice, Result}; |
| 2 | +use std::pin::Pin; |
| 3 | +use std::ptr; |
| 4 | +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; |
3 | 5 |
|
4 |
| -use rustls::pki_types::ServerName; |
5 |
| -use rustls::{ClientConfig, ClientConnection}; |
| 6 | +use bytes::buf::BufMut; |
| 7 | +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
6 | 8 | use tokio::net::TcpStream;
|
| 9 | +use tokio_rustls::client::TlsStream; |
7 | 10 |
|
8 |
| -use crate::error::Error; |
| 11 | +const NOOP_VTABLE: RawWakerVTable = |
| 12 | + RawWakerVTable::new(|_| RawWaker::new(ptr::null(), &NOOP_VTABLE), |_| {}, |_| {}, |_| {}); |
| 13 | +const NOOP_WAKER: RawWaker = RawWaker::new(ptr::null(), &NOOP_VTABLE); |
9 | 14 |
|
10 |
| -pub struct Connection { |
11 |
| - tls: Option<ClientConnection>, |
12 |
| - stream: TcpStream, |
| 15 | +pub enum Connection { |
| 16 | + Tls(TlsStream<TcpStream>), |
| 17 | + Raw(TcpStream), |
13 | 18 | }
|
14 | 19 |
|
15 |
| -struct WrappingStream<'a> { |
16 |
| - stream: &'a TcpStream, |
17 |
| -} |
18 |
| - |
19 |
| -impl io::Read for WrappingStream<'_> { |
20 |
| - fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> { |
21 |
| - self.stream.try_read_buf(&mut buf) |
| 20 | +impl AsyncRead for Connection { |
| 21 | + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> { |
| 22 | + match self.get_mut() { |
| 23 | + Self::Raw(stream) => Pin::new(stream).poll_read(cx, buf), |
| 24 | + Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf), |
| 25 | + } |
22 | 26 | }
|
23 | 27 | }
|
24 | 28 |
|
25 |
| -impl io::Write for WrappingStream<'_> { |
26 |
| - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { |
27 |
| - self.stream.try_write(buf) |
| 29 | +impl AsyncWrite for Connection { |
| 30 | + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> { |
| 31 | + match self.get_mut() { |
| 32 | + Self::Raw(stream) => Pin::new(stream).poll_write(cx, buf), |
| 33 | + Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf), |
| 34 | + } |
28 | 35 | }
|
29 | 36 |
|
30 |
| - fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { |
31 |
| - self.stream.try_write_vectored(bufs) |
| 37 | + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { |
| 38 | + match self.get_mut() { |
| 39 | + Self::Raw(stream) => Pin::new(stream).poll_flush(cx), |
| 40 | + Self::Tls(stream) => Pin::new(stream).poll_flush(cx), |
| 41 | + } |
32 | 42 | }
|
33 | 43 |
|
34 |
| - fn flush(&mut self) -> io::Result<()> { |
35 |
| - Ok(()) |
| 44 | + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> { |
| 45 | + match self.get_mut() { |
| 46 | + Self::Raw(stream) => Pin::new(stream).poll_shutdown(cx), |
| 47 | + Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx), |
| 48 | + } |
36 | 49 | }
|
37 | 50 | }
|
38 | 51 |
|
39 | 52 | impl Connection {
|
40 | 53 | pub fn new_raw(stream: TcpStream) -> Self {
|
41 |
| - Self { tls: None, stream } |
| 54 | + Self::Raw(stream) |
42 | 55 | }
|
43 | 56 |
|
44 |
| - pub fn new_tls(host: &str, config: Arc<ClientConfig>, stream: TcpStream) -> Result<Self, Error> { |
45 |
| - let name = match ServerName::try_from(host) { |
46 |
| - Err(_) => return Err(Error::BadArguments(&"invalid server dns name")), |
47 |
| - Ok(name) => name.to_owned(), |
48 |
| - }; |
49 |
| - let client = match ClientConnection::new(config, name) { |
50 |
| - Err(err) => return Err(Error::other(format!("fail to create tls client for host({host}): {err}"), err)), |
51 |
| - Ok(client) => client, |
52 |
| - }; |
53 |
| - Ok(Self { tls: Some(client), stream }) |
| 57 | + pub fn new_tls(stream: TlsStream<TcpStream>) -> Self { |
| 58 | + Self::Tls(stream) |
54 | 59 | }
|
55 | 60 |
|
56 |
| - pub fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { |
57 |
| - let Some(client) = self.tls.as_mut() else { |
58 |
| - return self.stream.try_write_vectored(bufs); |
59 |
| - }; |
60 |
| - let n = client.writer().write_vectored(bufs)?; |
61 |
| - let mut stream = WrappingStream { stream: &self.stream }; |
62 |
| - client.write_tls(&mut stream)?; |
63 |
| - Ok(n) |
| 61 | + pub fn try_write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> Result<usize> { |
| 62 | + let waker = unsafe { Waker::from_raw(NOOP_WAKER) }; |
| 63 | + let mut context = Context::from_waker(&waker); |
| 64 | + match Pin::new(self).poll_write_vectored(&mut context, bufs) { |
| 65 | + Poll::Pending => Err(ErrorKind::WouldBlock.into()), |
| 66 | + Poll::Ready(result) => result, |
| 67 | + } |
64 | 68 | }
|
65 | 69 |
|
66 |
| - pub fn read_buf(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> { |
67 |
| - let Some(client) = self.tls.as_mut() else { |
68 |
| - return self.stream.try_read_buf(buf); |
69 |
| - }; |
70 |
| - let mut stream = WrappingStream { stream: &self.stream }; |
71 |
| - let mut read_bytes = 0; |
72 |
| - loop { |
73 |
| - match client.read_tls(&mut stream) { |
74 |
| - // We may have plaintext to return though tcp stream has been closed. |
75 |
| - // If not, read_bytes should be zero. |
76 |
| - Ok(0) => break, |
77 |
| - Ok(_) => {}, |
78 |
| - Err(err) => match err.kind() { |
79 |
| - // backpressure: tls buffer is full, let's process_new_packets. |
80 |
| - io::ErrorKind::Other => {}, |
81 |
| - io::ErrorKind::WouldBlock if read_bytes == 0 => { |
82 |
| - return Err(err); |
83 |
| - }, |
84 |
| - _ => break, |
85 |
| - }, |
86 |
| - } |
87 |
| - let state = client.process_new_packets().map_err(io::Error::other)?; |
88 |
| - let n = state.plaintext_bytes_to_read(); |
89 |
| - buf.reserve(n); |
90 |
| - let slice = unsafe { &mut std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len() + n)[buf.len()..] }; |
91 |
| - client.reader().read_exact(slice).unwrap(); |
92 |
| - unsafe { buf.set_len(buf.len() + n) }; |
93 |
| - read_bytes += n; |
| 70 | + pub fn try_read_buf(&mut self, buf: &mut impl BufMut) -> Result<usize> { |
| 71 | + let waker = unsafe { Waker::from_raw(NOOP_WAKER) }; |
| 72 | + let mut context = Context::from_waker(&waker); |
| 73 | + let chunk = buf.chunk_mut(); |
| 74 | + let mut read_buf = unsafe { ReadBuf::uninit(chunk.as_uninit_slice_mut()) }; |
| 75 | + match Pin::new(self).poll_read(&mut context, &mut read_buf) { |
| 76 | + Poll::Pending => Err(ErrorKind::WouldBlock.into()), |
| 77 | + Poll::Ready(Err(err)) => Err(err), |
| 78 | + Poll::Ready(Ok(())) => { |
| 79 | + let n = read_buf.filled().len(); |
| 80 | + unsafe { buf.advance_mut(n) }; |
| 81 | + Ok(n) |
| 82 | + }, |
94 | 83 | }
|
95 |
| - Ok(read_bytes) |
96 | 84 | }
|
97 | 85 |
|
98 |
| - pub async fn readable(&self) -> io::Result<()> { |
99 |
| - let Some(client) = self.tls.as_ref() else { |
100 |
| - return self.stream.readable().await; |
101 |
| - }; |
102 |
| - if client.wants_read() { |
103 |
| - self.stream.readable().await |
104 |
| - } else { |
105 |
| - // plaintext data are available for read |
106 |
| - std::future::ready(Ok(())).await |
| 86 | + pub async fn readable(&self) -> Result<()> { |
| 87 | + match self { |
| 88 | + Self::Raw(stream) => stream.readable().await, |
| 89 | + Self::Tls(stream) => { |
| 90 | + let (stream, session) = stream.get_ref(); |
| 91 | + if session.wants_read() { |
| 92 | + stream.readable().await |
| 93 | + } else { |
| 94 | + // plaintext data are available for read |
| 95 | + std::future::ready(Ok(())).await |
| 96 | + } |
| 97 | + }, |
107 | 98 | }
|
108 | 99 | }
|
109 | 100 |
|
110 |
| - pub async fn writable(&self) -> io::Result<()> { |
111 |
| - self.stream.writable().await |
| 101 | + pub async fn writable(&self) -> Result<()> { |
| 102 | + match self { |
| 103 | + Self::Raw(stream) => stream.writable().await, |
| 104 | + Self::Tls(stream) => { |
| 105 | + let (stream, _session) = stream.get_ref(); |
| 106 | + stream.writable().await |
| 107 | + }, |
| 108 | + } |
112 | 109 | }
|
113 | 110 |
|
114 | 111 | pub fn wants_write(&self) -> bool {
|
115 |
| - self.tls.as_ref().map(|tls| tls.wants_write()).unwrap_or(false) |
| 112 | + match self { |
| 113 | + Self::Raw(_) => false, |
| 114 | + Self::Tls(stream) => { |
| 115 | + let (_stream, session) = stream.get_ref(); |
| 116 | + session.wants_write() |
| 117 | + }, |
| 118 | + } |
116 | 119 | }
|
117 | 120 |
|
118 |
| - pub fn flush(&mut self) -> io::Result<()> { |
119 |
| - let Some(client) = self.tls.as_mut() else { |
120 |
| - return Ok(()); |
121 |
| - }; |
122 |
| - let mut stream = WrappingStream { stream: &self.stream }; |
123 |
| - while client.wants_write() { |
124 |
| - client.write_tls(&mut stream)?; |
| 121 | + pub fn try_flush(&mut self) -> Result<()> { |
| 122 | + let waker = unsafe { Waker::from_raw(NOOP_WAKER) }; |
| 123 | + let mut context = Context::from_waker(&waker); |
| 124 | + match Pin::new(self).poll_flush(&mut context) { |
| 125 | + Poll::Pending => Err(ErrorKind::WouldBlock.into()), |
| 126 | + Poll::Ready(result) => result, |
125 | 127 | }
|
126 |
| - Ok(()) |
127 | 128 | }
|
128 | 129 | }
|
0 commit comments