Skip to content

Commit 4602469

Browse files
committed
io: Retry on Interrupted
1 parent 042c16b commit 4602469

File tree

2 files changed

+53
-27
lines changed

2 files changed

+53
-27
lines changed

src/io/mod.rs

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use futures_util::stream::{FuturesUnordered, StreamExt};
1414
use mysql_common::proto::codec::PacketCodec as PacketCodecInner;
1515
use native_tls::{Certificate, Identity, TlsConnector};
1616
use pin_project::pin_project;
17-
use tokio::{net::TcpStream, prelude::*};
17+
use tokio::{io::ErrorKind::Interrupted, net::TcpStream, prelude::*};
1818
use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts};
1919

2020
use std::{
@@ -37,6 +37,17 @@ use std::{
3737

3838
use crate::{error::IoError, io::socket::Socket, opts::SslOpts};
3939

40+
macro_rules! with_interrupted {
41+
($e:expr) => {
42+
loop {
43+
match $e {
44+
Poll::Ready(Err(err)) if err.kind() == Interrupted => continue,
45+
x => break x,
46+
}
47+
}
48+
};
49+
}
50+
4051
mod read_packet;
4152
mod socket;
4253
mod write_packet;
@@ -218,13 +229,14 @@ impl AsyncRead for Endpoint {
218229
cx: &mut Context,
219230
buf: &mut [u8],
220231
) -> Poll<std::result::Result<usize, tokio::io::Error>> {
221-
match self.project() {
232+
let mut this = self.project();
233+
with_interrupted!(match this {
222234
EndpointProj::Plain(ref mut stream) => {
223235
Pin::new(stream.as_mut().unwrap()).poll_read(cx, buf)
224236
}
225-
EndpointProj::Secure(stream) => stream.poll_read(cx, buf),
226-
EndpointProj::Socket(stream) => stream.poll_read(cx, buf),
227-
}
237+
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read(cx, buf),
238+
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read(cx, buf),
239+
})
228240
}
229241

230242
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
@@ -244,13 +256,14 @@ impl AsyncRead for Endpoint {
244256
where
245257
B: BufMut,
246258
{
247-
match self.project() {
259+
let mut this = self.project();
260+
with_interrupted!(match this {
248261
EndpointProj::Plain(ref mut stream) => {
249262
Pin::new(stream.as_mut().unwrap()).poll_read_buf(cx, buf)
250263
}
251-
EndpointProj::Secure(stream) => stream.poll_read_buf(cx, buf),
252-
EndpointProj::Socket(stream) => stream.poll_read_buf(cx, buf),
253-
}
264+
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_read_buf(cx, buf),
265+
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_read_buf(cx, buf),
266+
})
254267
}
255268
}
256269

@@ -260,39 +273,42 @@ impl AsyncWrite for Endpoint {
260273
cx: &mut Context,
261274
buf: &[u8],
262275
) -> Poll<std::result::Result<usize, tokio::io::Error>> {
263-
match self.project() {
276+
let mut this = self.project();
277+
with_interrupted!(match this {
264278
EndpointProj::Plain(ref mut stream) => {
265279
Pin::new(stream.as_mut().unwrap()).poll_write(cx, buf)
266280
}
267-
EndpointProj::Secure(stream) => stream.poll_write(cx, buf),
268-
EndpointProj::Socket(stream) => stream.poll_write(cx, buf),
269-
}
281+
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_write(cx, buf),
282+
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_write(cx, buf),
283+
})
270284
}
271285

272286
fn poll_flush(
273287
self: Pin<&mut Self>,
274288
cx: &mut Context,
275289
) -> Poll<std::result::Result<(), tokio::io::Error>> {
276-
match self.project() {
290+
let mut this = self.project();
291+
with_interrupted!(match this {
277292
EndpointProj::Plain(ref mut stream) => {
278293
Pin::new(stream.as_mut().unwrap()).poll_flush(cx)
279294
}
280-
EndpointProj::Secure(stream) => stream.poll_flush(cx),
281-
EndpointProj::Socket(stream) => stream.poll_flush(cx),
282-
}
295+
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_flush(cx),
296+
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_flush(cx),
297+
})
283298
}
284299

285300
fn poll_shutdown(
286301
self: Pin<&mut Self>,
287302
cx: &mut Context,
288303
) -> Poll<std::result::Result<(), tokio::io::Error>> {
289-
match self.project() {
304+
let mut this = self.project();
305+
with_interrupted!(match this {
290306
EndpointProj::Plain(ref mut stream) => {
291307
Pin::new(stream.as_mut().unwrap()).poll_shutdown(cx)
292308
}
293-
EndpointProj::Secure(stream) => stream.poll_shutdown(cx),
294-
EndpointProj::Socket(stream) => stream.poll_shutdown(cx),
295-
}
309+
EndpointProj::Secure(ref mut stream) => stream.as_mut().poll_shutdown(cx),
310+
EndpointProj::Socket(ref mut stream) => stream.as_mut().poll_shutdown(cx),
311+
})
296312
}
297313
}
298314

src/io/socket.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
use bytes::BufMut;
1010
use pin_project::pin_project;
11-
use tokio::{io::Error, prelude::*};
11+
use tokio::{
12+
io::{Error, ErrorKind::Interrupted},
13+
prelude::*,
14+
};
1215

1316
use std::{
1417
io,
@@ -56,7 +59,8 @@ impl AsyncRead for Socket {
5659
cx: &mut Context,
5760
buf: &mut [u8],
5861
) -> Poll<Result<usize, Error>> {
59-
self.project().inner.poll_read(cx, buf)
62+
let mut this = self.project();
63+
with_interrupted!(this.inner.as_mut().poll_read(cx, buf))
6064
}
6165

6266
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
@@ -71,7 +75,8 @@ impl AsyncRead for Socket {
7175
where
7276
B: BufMut,
7377
{
74-
self.project().inner.poll_read_buf(cx, buf)
78+
let mut this = self.project();
79+
with_interrupted!(this.inner.as_mut().poll_read_buf(cx, buf))
7580
}
7681
}
7782

@@ -81,12 +86,17 @@ impl AsyncWrite for Socket {
8186
cx: &mut Context,
8287
buf: &[u8],
8388
) -> Poll<Result<usize, Error>> {
84-
self.project().inner.poll_write(cx, buf)
89+
let mut this = self.project();
90+
with_interrupted!(this.inner.as_mut().poll_write(cx, buf))
8591
}
92+
8693
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
87-
self.project().inner.poll_flush(cx)
94+
let mut this = self.project();
95+
with_interrupted!(this.inner.as_mut().poll_flush(cx))
8896
}
97+
8998
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
90-
self.project().inner.poll_shutdown(cx)
99+
let mut this = self.project();
100+
with_interrupted!(this.inner.as_mut().poll_shutdown(cx))
91101
}
92102
}

0 commit comments

Comments
 (0)