@@ -13,11 +13,11 @@ use btls::{
1313use compio:: buf:: { IoBuf , IoBufMut } ;
1414use compio:: BufResult ;
1515use compio_io:: { compat:: SyncStream , AsyncRead , AsyncWrite } ;
16- use std:: io;
17- use std:: mem:: MaybeUninit ;
16+ use std:: error:: Error ;
1817use std:: pin:: Pin ;
1918use std:: task:: Context ;
2019use std:: task:: Poll ;
20+ use std:: { fmt, io} ;
2121
2222fn cvt_ossl < T > ( r : Result < T , ssl:: Error > ) -> Poll < Result < T , ssl:: Error > > {
2323 match r {
@@ -45,25 +45,30 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
4545 pub fn poll_connect (
4646 self : Pin < & mut Self > ,
4747 cx : & mut Context < ' _ > ,
48- ) -> Poll < Result < ( ) , ssl :: Error > > {
48+ ) -> Poll < Result < ( ) , HandshakeError > > {
4949 self . with_context ( cx, |s| cvt_ossl ( s. connect ( ) ) )
50+ . map_err ( HandshakeError :: Ssl )
5051 }
5152
5253 #[ inline]
5354 /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
54- pub async fn connect ( self : Pin < & mut Self > ) -> Result < ( ) , ssl :: Error > {
55+ pub async fn connect ( self : Pin < & mut Self > ) -> Result < ( ) , HandshakeError > {
5556 self . drive_handshake ( |s| s. connect ( ) ) . await
5657 }
5758
5859 #[ inline]
5960 /// Like [`SslStream::accept`](ssl::SslStream::accept).
60- pub fn poll_accept ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , ssl:: Error > > {
61+ pub fn poll_accept (
62+ self : Pin < & mut Self > ,
63+ cx : & mut Context < ' _ > ,
64+ ) -> Poll < Result < ( ) , HandshakeError > > {
6165 self . with_context ( cx, |s| cvt_ossl ( s. accept ( ) ) )
66+ . map_err ( HandshakeError :: Ssl )
6267 }
6368
6469 #[ inline]
6570 /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
66- pub async fn accept ( self : Pin < & mut Self > ) -> Result < ( ) , ssl :: Error > {
71+ pub async fn accept ( self : Pin < & mut Self > ) -> Result < ( ) , HandshakeError > {
6772 self . drive_handshake ( |s| s. accept ( ) ) . await
6873 }
6974
@@ -72,17 +77,18 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
7277 pub fn poll_do_handshake (
7378 self : Pin < & mut Self > ,
7479 cx : & mut Context < ' _ > ,
75- ) -> Poll < Result < ( ) , ssl :: Error > > {
80+ ) -> Poll < Result < ( ) , HandshakeError > > {
7681 self . with_context ( cx, |s| cvt_ossl ( s. do_handshake ( ) ) )
82+ . map_err ( HandshakeError :: Ssl )
7783 }
7884
7985 #[ inline]
8086 /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
81- pub async fn do_handshake ( self : Pin < & mut Self > ) -> Result < ( ) , ssl :: Error > {
87+ pub async fn do_handshake ( self : Pin < & mut Self > ) -> Result < ( ) , HandshakeError > {
8288 self . drive_handshake ( |s| s. do_handshake ( ) ) . await
8389 }
8490
85- async fn drive_handshake < F > ( mut self : Pin < & mut Self > , mut f : F ) -> Result < ( ) , ssl :: Error >
91+ async fn drive_handshake < F > ( mut self : Pin < & mut Self > , mut f : F ) -> Result < ( ) , HandshakeError >
8692 where
8793 F : FnMut ( & mut SslStreamCore < SyncStream < S > > ) -> Result < ( ) , ssl:: Error > ,
8894 {
@@ -95,26 +101,32 @@ impl<S: AsyncRead + AsyncWrite> SslStream<S> {
95101 match res {
96102 Ok ( ( ) ) => {
97103 // Ensure handshake records are pushed out before returning.
98- if self . as_mut ( ) . flush_write_buf ( ) . await . is_err ( ) {
99- // Keep API compatibility: this method reports ssl::Error.
100- }
104+ self . as_mut ( )
105+ . flush_write_buf ( )
106+ . await
107+ . map_err ( HandshakeError :: Io ) ?;
108+
101109 return Ok ( ( ) ) ;
102110 }
103111 Err ( e) => match e. code ( ) {
104112 ErrorCode :: WANT_WRITE => {
105- if self . as_mut ( ) . flush_write_buf ( ) . await . is_err ( ) {
106- return Err ( e) ;
107- }
113+ self . as_mut ( )
114+ . flush_write_buf ( )
115+ . await
116+ . map_err ( HandshakeError :: Io ) ?;
108117 }
109118 ErrorCode :: WANT_READ => {
110- if self . as_mut ( ) . flush_write_buf ( ) . await . is_err ( ) {
111- return Err ( e) ;
112- }
113- if self . as_mut ( ) . fill_read_buf ( ) . await . is_err ( ) {
114- return Err ( e) ;
115- }
119+ self . as_mut ( )
120+ . flush_write_buf ( )
121+ . await
122+ . map_err ( HandshakeError :: Io ) ?;
123+
124+ self . as_mut ( )
125+ . fill_read_buf ( )
126+ . await
127+ . map_err ( HandshakeError :: Io ) ?;
116128 }
117- _ => return Err ( e ) ,
129+ _ => return Err ( HandshakeError :: Ssl ( e ) ) ,
118130 } ,
119131 }
120132 }
@@ -179,19 +191,12 @@ where
179191{
180192 async fn read < B : IoBufMut > ( & mut self , mut buf : B ) -> BufResult < usize , B > {
181193 let slice = buf. as_uninit ( ) ;
182-
183- let mut f = {
184- slice. fill ( MaybeUninit :: new ( 0 ) ) ;
185- // SAFETY: The memory has been initialized.
186- let slice =
187- unsafe { std:: slice:: from_raw_parts_mut ( slice. as_mut_ptr ( ) . cast ( ) , slice. len ( ) ) } ;
188- |s : & mut _ | std:: io:: Read :: read ( s, slice)
189- } ;
190-
191194 loop {
192- match f ( & mut self . 0 ) {
195+ // SAFETY: read_uninit does not de-initialize the buffer.
196+ match self . 0 . read_uninit ( slice) {
193197 Ok ( res) => {
194- unsafe { buf. set_len ( res) } ;
198+ // SAFETY: read_uninit guarantees that nread bytes have been initialized.
199+ unsafe { buf. advance_to ( res) } ;
195200 return BufResult ( Ok ( res) , buf) ;
196201 }
197202 Err ( e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => {
@@ -243,3 +248,58 @@ where
243248 self . 0 . get_mut ( ) . get_mut ( ) . shutdown ( ) . await
244249 }
245250}
251+
252+ /// The error type returned after a failed handshake.
253+ pub enum HandshakeError {
254+ /// An error that occurred during the SSL handshake.
255+ Ssl ( ssl:: Error ) ,
256+ /// An I/O error that occurred during the handshake.
257+ Io ( io:: Error ) ,
258+ }
259+
260+ impl HandshakeError {
261+ /// Returns the error code, if any.
262+ #[ must_use]
263+ pub fn code ( & self ) -> Option < ErrorCode > {
264+ match self {
265+ HandshakeError :: Ssl ( e) => Some ( e. code ( ) ) ,
266+ _ => None ,
267+ }
268+ }
269+
270+ /// Returns a reference to the inner I/O error, if any.
271+ #[ must_use]
272+ pub fn as_io_error ( & self ) -> Option < & io:: Error > {
273+ match self {
274+ HandshakeError :: Ssl ( e) => e. io_error ( ) ,
275+ HandshakeError :: Io ( e) => Some ( e) ,
276+ }
277+ }
278+ }
279+
280+ impl fmt:: Debug for HandshakeError {
281+ fn fmt ( & self , fmt : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
282+ match self {
283+ HandshakeError :: Ssl ( e) => fmt:: Debug :: fmt ( e, fmt) ,
284+ HandshakeError :: Io ( e) => fmt:: Debug :: fmt ( e, fmt) ,
285+ }
286+ }
287+ }
288+
289+ impl fmt:: Display for HandshakeError {
290+ fn fmt ( & self , fmt : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
291+ match self {
292+ HandshakeError :: Ssl ( e) => fmt:: Display :: fmt ( e, fmt) ,
293+ HandshakeError :: Io ( e) => fmt:: Display :: fmt ( e, fmt) ,
294+ }
295+ }
296+ }
297+
298+ impl Error for HandshakeError {
299+ fn source ( & self ) -> Option < & ( dyn Error + ' static ) > {
300+ match self {
301+ HandshakeError :: Ssl ( e) => e. source ( ) ,
302+ HandshakeError :: Io ( e) => Some ( e) ,
303+ }
304+ }
305+ }
0 commit comments