Skip to content

Commit 664e950

Browse files
committed
error
1 parent 760fb7e commit 664e950

File tree

1 file changed

+93
-33
lines changed

1 file changed

+93
-33
lines changed

compio-btls/src/lib.rs

Lines changed: 93 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ use btls::{
1313
use compio::buf::{IoBuf, IoBufMut};
1414
use compio::BufResult;
1515
use compio_io::{compat::SyncStream, AsyncRead, AsyncWrite};
16-
use std::io;
17-
use std::mem::MaybeUninit;
16+
use std::error::Error;
1817
use std::pin::Pin;
1918
use std::task::Context;
2019
use std::task::Poll;
20+
use std::{fmt, io};
2121

2222
fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
2323
match r {
@@ -45,25 +45,30 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
4545
pub fn poll_connect(
4646
self: Pin<&mut Self>,
4747
cx: &mut Context<'_>,
48-
) -> Poll<Result<(), ssl::Error>> {
48+
) -> Poll<Result<(), HandshakeError>> {
4949
self.with_context(cx, |s| cvt_ossl(s.connect()))
50+
.map_err(HandshakeError::Ssl)
5051
}
5152

5253
#[inline]
5354
/// A convenience method wrapping [`poll_connect`](Self::poll_connect).
54-
pub async fn connect(self: Pin<&mut Self>) -> Result<(), ssl::Error> {
55+
pub async fn connect(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
5556
self.drive_handshake(|s| s.connect()).await
5657
}
5758

5859
#[inline]
5960
/// Like [`SslStream::accept`](ssl::SslStream::accept).
60-
pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
61+
pub fn poll_accept(
62+
self: Pin<&mut Self>,
63+
cx: &mut Context<'_>,
64+
) -> Poll<Result<(), HandshakeError>> {
6165
self.with_context(cx, |s| cvt_ossl(s.accept()))
66+
.map_err(HandshakeError::Ssl)
6267
}
6368

6469
#[inline]
6570
/// A convenience method wrapping [`poll_accept`](Self::poll_accept).
66-
pub async fn accept(self: Pin<&mut Self>) -> Result<(), ssl::Error> {
71+
pub async fn accept(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
6772
self.drive_handshake(|s| s.accept()).await
6873
}
6974

@@ -72,17 +77,18 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
7277
pub fn poll_do_handshake(
7378
self: Pin<&mut Self>,
7479
cx: &mut Context<'_>,
75-
) -> Poll<Result<(), ssl::Error>> {
80+
) -> Poll<Result<(), HandshakeError>> {
7681
self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
82+
.map_err(HandshakeError::Ssl)
7783
}
7884

7985
#[inline]
8086
/// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
81-
pub async fn do_handshake(self: Pin<&mut Self>) -> Result<(), ssl::Error> {
87+
pub async fn do_handshake(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
8288
self.drive_handshake(|s| s.do_handshake()).await
8389
}
8490

85-
async fn drive_handshake<F>(mut self: Pin<&mut Self>, mut f: F) -> Result<(), ssl::Error>
91+
async fn drive_handshake<F>(mut self: Pin<&mut Self>, mut f: F) -> Result<(), HandshakeError>
8692
where
8793
F: FnMut(&mut SslStreamCore<SyncStream<S>>) -> Result<(), ssl::Error>,
8894
{
@@ -95,26 +101,32 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
95101
match res {
96102
Ok(()) => {
97103
// Ensure handshake records are pushed out before returning.
98-
if self.as_mut().flush_write_buf().await.is_err() {
99-
// Keep API compatibility: this method reports ssl::Error.
100-
}
104+
self.as_mut()
105+
.flush_write_buf()
106+
.await
107+
.map_err(HandshakeError::Io)?;
108+
101109
return Ok(());
102110
}
103111
Err(e) => match e.code() {
104112
ErrorCode::WANT_WRITE => {
105-
if self.as_mut().flush_write_buf().await.is_err() {
106-
return Err(e);
107-
}
113+
self.as_mut()
114+
.flush_write_buf()
115+
.await
116+
.map_err(HandshakeError::Io)?;
108117
}
109118
ErrorCode::WANT_READ => {
110-
if self.as_mut().flush_write_buf().await.is_err() {
111-
return Err(e);
112-
}
113-
if self.as_mut().fill_read_buf().await.is_err() {
114-
return Err(e);
115-
}
119+
self.as_mut()
120+
.flush_write_buf()
121+
.await
122+
.map_err(HandshakeError::Io)?;
123+
124+
self.as_mut()
125+
.fill_read_buf()
126+
.await
127+
.map_err(HandshakeError::Io)?;
116128
}
117-
_ => return Err(e),
129+
_ => return Err(HandshakeError::Ssl(e)),
118130
},
119131
}
120132
}
@@ -179,19 +191,12 @@ where
179191
{
180192
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
181193
let slice = buf.as_uninit();
182-
183-
let mut f = {
184-
slice.fill(MaybeUninit::new(0));
185-
// SAFETY: The memory has been initialized.
186-
let slice =
187-
unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) };
188-
|s: &mut _| std::io::Read::read(s, slice)
189-
};
190-
191194
loop {
192-
match f(&mut self.0) {
195+
// SAFETY: read_uninit does not de-initialize the buffer.
196+
match self.0.read_uninit(slice) {
193197
Ok(res) => {
194-
unsafe { buf.set_len(res) };
198+
// SAFETY: read_uninit guarantees that nread bytes have been initialized.
199+
unsafe { buf.advance_to(res) };
195200
return BufResult(Ok(res), buf);
196201
}
197202
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
@@ -243,3 +248,58 @@ where
243248
self.0.get_mut().get_mut().shutdown().await
244249
}
245250
}
251+
252+
/// The error type returned after a failed handshake.
253+
pub enum HandshakeError {
254+
/// An error that occurred during the SSL handshake.
255+
Ssl(ssl::Error),
256+
/// An I/O error that occurred during the handshake.
257+
Io(io::Error),
258+
}
259+
260+
impl HandshakeError {
261+
/// Returns the error code, if any.
262+
#[must_use]
263+
pub fn code(&self) -> Option<ErrorCode> {
264+
match self {
265+
HandshakeError::Ssl(e) => Some(e.code()),
266+
_ => None,
267+
}
268+
}
269+
270+
/// Returns a reference to the inner I/O error, if any.
271+
#[must_use]
272+
pub fn as_io_error(&self) -> Option<&io::Error> {
273+
match self {
274+
HandshakeError::Ssl(e) => e.io_error(),
275+
HandshakeError::Io(e) => Some(e),
276+
}
277+
}
278+
}
279+
280+
impl fmt::Debug for HandshakeError {
281+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
282+
match self {
283+
HandshakeError::Ssl(e) => fmt::Debug::fmt(e, fmt),
284+
HandshakeError::Io(e) => fmt::Debug::fmt(e, fmt),
285+
}
286+
}
287+
}
288+
289+
impl fmt::Display for HandshakeError {
290+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
291+
match self {
292+
HandshakeError::Ssl(e) => fmt::Display::fmt(e, fmt),
293+
HandshakeError::Io(e) => fmt::Display::fmt(e, fmt),
294+
}
295+
}
296+
}
297+
298+
impl Error for HandshakeError {
299+
fn source(&self) -> Option<&(dyn Error + 'static)> {
300+
match self {
301+
HandshakeError::Ssl(e) => e.source(),
302+
HandshakeError::Io(e) => Some(e),
303+
}
304+
}
305+
}

0 commit comments

Comments
 (0)