@@ -14,7 +14,7 @@ use futures_util::stream::{FuturesUnordered, StreamExt};
1414use mysql_common:: proto:: codec:: PacketCodec as PacketCodecInner ;
1515use native_tls:: { Certificate , Identity , TlsConnector } ;
1616use pin_project:: pin_project;
17- use tokio:: { net:: TcpStream , prelude:: * } ;
17+ use tokio:: { io :: ErrorKind :: Interrupted , net:: TcpStream , prelude:: * } ;
1818use tokio_util:: codec:: { Decoder , Encoder , Framed , FramedParts } ;
1919
2020use std:: {
@@ -37,6 +37,17 @@ use std::{
3737
3838use crate :: { error:: IoError , io:: socket:: Socket , opts:: SslOpts } ;
3939
40+ macro_rules! with_interrupted {
41+ ( $e: expr) => {
42+ loop {
43+ match $e {
44+ Poll :: Ready ( Err ( err) ) if err. kind( ) == Interrupted => continue ,
45+ x => break x,
46+ }
47+ }
48+ } ;
49+ }
50+
4051mod read_packet;
4152mod socket;
4253mod write_packet;
@@ -218,13 +229,14 @@ impl AsyncRead for Endpoint {
218229 cx : & mut Context ,
219230 buf : & mut [ u8 ] ,
220231 ) -> Poll < std:: result:: Result < usize , tokio:: io:: Error > > {
221- match self . project ( ) {
232+ let mut this = self . project ( ) ;
233+ with_interrupted ! ( match this {
222234 EndpointProj :: Plain ( ref mut stream) => {
223235 Pin :: new( stream. as_mut( ) . unwrap( ) ) . poll_read( cx, buf)
224236 }
225- EndpointProj :: Secure ( stream) => stream. poll_read ( cx, buf) ,
226- EndpointProj :: Socket ( stream) => stream. poll_read ( cx, buf) ,
227- }
237+ EndpointProj :: Secure ( ref mut stream) => stream. as_mut ( ) . poll_read( cx, buf) ,
238+ EndpointProj :: Socket ( ref mut stream) => stream. as_mut ( ) . poll_read( cx, buf) ,
239+ } )
228240 }
229241
230242 unsafe fn prepare_uninitialized_buffer ( & self , buf : & mut [ MaybeUninit < u8 > ] ) -> bool {
@@ -244,13 +256,14 @@ impl AsyncRead for Endpoint {
244256 where
245257 B : BufMut ,
246258 {
247- match self . project ( ) {
259+ let mut this = self . project ( ) ;
260+ with_interrupted ! ( match this {
248261 EndpointProj :: Plain ( ref mut stream) => {
249262 Pin :: new( stream. as_mut( ) . unwrap( ) ) . poll_read_buf( cx, buf)
250263 }
251- EndpointProj :: Secure ( stream) => stream. poll_read_buf ( cx, buf) ,
252- EndpointProj :: Socket ( stream) => stream. poll_read_buf ( cx, buf) ,
253- }
264+ EndpointProj :: Secure ( ref mut stream) => stream. as_mut ( ) . poll_read_buf( cx, buf) ,
265+ EndpointProj :: Socket ( ref mut stream) => stream. as_mut ( ) . poll_read_buf( cx, buf) ,
266+ } )
254267 }
255268}
256269
@@ -260,39 +273,42 @@ impl AsyncWrite for Endpoint {
260273 cx : & mut Context ,
261274 buf : & [ u8 ] ,
262275 ) -> Poll < std:: result:: Result < usize , tokio:: io:: Error > > {
263- match self . project ( ) {
276+ let mut this = self . project ( ) ;
277+ with_interrupted ! ( match this {
264278 EndpointProj :: Plain ( ref mut stream) => {
265279 Pin :: new( stream. as_mut( ) . unwrap( ) ) . poll_write( cx, buf)
266280 }
267- EndpointProj :: Secure ( stream) => stream. poll_write ( cx, buf) ,
268- EndpointProj :: Socket ( stream) => stream. poll_write ( cx, buf) ,
269- }
281+ EndpointProj :: Secure ( ref mut stream) => stream. as_mut ( ) . poll_write( cx, buf) ,
282+ EndpointProj :: Socket ( ref mut stream) => stream. as_mut ( ) . poll_write( cx, buf) ,
283+ } )
270284 }
271285
272286 fn poll_flush (
273287 self : Pin < & mut Self > ,
274288 cx : & mut Context ,
275289 ) -> Poll < std:: result:: Result < ( ) , tokio:: io:: Error > > {
276- match self . project ( ) {
290+ let mut this = self . project ( ) ;
291+ with_interrupted ! ( match this {
277292 EndpointProj :: Plain ( ref mut stream) => {
278293 Pin :: new( stream. as_mut( ) . unwrap( ) ) . poll_flush( cx)
279294 }
280- EndpointProj :: Secure ( stream) => stream. poll_flush ( cx) ,
281- EndpointProj :: Socket ( stream) => stream. poll_flush ( cx) ,
282- }
295+ EndpointProj :: Secure ( ref mut stream) => stream. as_mut ( ) . poll_flush( cx) ,
296+ EndpointProj :: Socket ( ref mut stream) => stream. as_mut ( ) . poll_flush( cx) ,
297+ } )
283298 }
284299
285300 fn poll_shutdown (
286301 self : Pin < & mut Self > ,
287302 cx : & mut Context ,
288303 ) -> Poll < std:: result:: Result < ( ) , tokio:: io:: Error > > {
289- match self . project ( ) {
304+ let mut this = self . project ( ) ;
305+ with_interrupted ! ( match this {
290306 EndpointProj :: Plain ( ref mut stream) => {
291307 Pin :: new( stream. as_mut( ) . unwrap( ) ) . poll_shutdown( cx)
292308 }
293- EndpointProj :: Secure ( stream) => stream. poll_shutdown ( cx) ,
294- EndpointProj :: Socket ( stream) => stream. poll_shutdown ( cx) ,
295- }
309+ EndpointProj :: Secure ( ref mut stream) => stream. as_mut ( ) . poll_shutdown( cx) ,
310+ EndpointProj :: Socket ( ref mut stream) => stream. as_mut ( ) . poll_shutdown( cx) ,
311+ } )
296312 }
297313}
298314
0 commit comments