Skip to content

Commit 3b12a77

Browse files
authored
Fix io close (#12)
* Fix io close for Framed * Fix connection shutdown for h1 dispatcher * Enable client disconnect for http server by default * Add connection disconnect timeout to framed service
1 parent 8a753a7 commit 3b12a77

File tree

21 files changed

+528
-168
lines changed

21 files changed

+528
-168
lines changed

ntex-codec/CHANGES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changes
22

3+
## [0.1.1] - 2020-04-07
4+
5+
* Optimize io operations
6+
7+
* Fix framed close method
8+
39
## [0.1.0] - 2020-03-31
410

511
* Fork crate to ntex namespace

ntex-codec/Cargo.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
[package]
22
name = "ntex-codec"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
authors = ["Nikolay Kim <fafhrd91@gmail.com>"]
55
description = "Utilities for encoding and decoding frames"
66
keywords = ["network", "framework", "async", "futures"]
77
homepage = "https://ntex.rs"
88
repository = "https://github.com/ntex-rs/ntex.git"
99
documentation = "https://docs.rs/ntex-codec/"
1010
categories = ["network-programming", "asynchronous"]
11-
license = "MIT/Apache-2.0"
11+
license = "MIT"
1212
edition = "2018"
1313

1414
[lib]
@@ -20,6 +20,10 @@ bitflags = "1.2.1"
2020
bytes = "0.5.4"
2121
futures-core = "0.3.4"
2222
futures-sink = "0.3.4"
23-
tokio = { version = "0.2.4", default-features=false }
23+
tokio = { version = "0.2.6", default-features=false }
2424
tokio-util = { version = "0.2.0", default-features=false, features=["codec"] }
25-
log = "0.4"
25+
log = "0.4"
26+
27+
[dev-dependencies]
28+
ntex = "0.1.4"
29+
futures = "0.3.4"

ntex-codec/src/framed.rs

Lines changed: 211 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@ const HW: usize = 8 * 1024;
1313

1414
bitflags::bitflags! {
1515
struct Flags: u8 {
16-
const EOF = 0b0001;
17-
const READABLE = 0b0010;
16+
const EOF = 0b0001;
17+
const READABLE = 0b0010;
18+
const DISCONNECTED = 0b0100;
19+
const SHUTDOWN = 0b1000;
1820
}
1921
}
2022

2123
/// A unified `Stream` and `Sink` interface to an underlying I/O object, using
2224
/// the `Encoder` and `Decoder` traits to encode and decode frames.
25+
/// `Framed` is heavily optimized for streaming io.
2326
pub struct Framed<T, U> {
2427
io: T,
2528
codec: U,
@@ -28,8 +31,6 @@ pub struct Framed<T, U> {
2831
write_buf: BytesMut,
2932
}
3033

31-
impl<T, U> Unpin for Framed<T, U> {}
32-
3334
impl<T, U> Framed<T, U>
3435
where
3536
T: AsyncRead + AsyncWrite,
@@ -123,6 +124,18 @@ impl<T, U> Framed<T, U> {
123124
&mut self.io
124125
}
125126

127+
#[inline]
128+
/// Get read buffer.
129+
pub fn read_buf_mut(&mut self) -> &mut BytesMut {
130+
&mut self.read_buf
131+
}
132+
133+
#[inline]
134+
/// Get write buffer.
135+
pub fn write_buf_mut(&mut self) -> &mut BytesMut {
136+
&mut self.write_buf
137+
}
138+
126139
#[inline]
127140
/// Check if write buffer is empty.
128141
pub fn is_write_buf_empty(&self) -> bool {
@@ -135,6 +148,12 @@ impl<T, U> Framed<T, U> {
135148
self.write_buf.len() >= HW
136149
}
137150

151+
#[inline]
152+
/// Check if framed object is closed
153+
pub fn is_closed(&self) -> bool {
154+
self.flags.contains(Flags::DISCONNECTED)
155+
}
156+
138157
#[inline]
139158
/// Consume the `Frame`, returning `Frame` with different codec.
140159
pub fn into_framed<U2>(self, codec: U2) -> Framed<T, U2> {
@@ -227,34 +246,87 @@ where
227246
pub fn flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
228247
log::trace!("flushing framed transport");
229248

230-
while !self.write_buf.is_empty() {
231-
log::trace!("writing; remaining={}", self.write_buf.len());
249+
let len = self.write_buf.len();
250+
if len == 0 {
251+
return Poll::Ready(Ok(()));
252+
}
232253

233-
let n = ready!(Pin::new(&mut self.io).poll_write(cx, &self.write_buf))?;
234-
if n == 0 {
235-
return Poll::Ready(Err(io::Error::new(
236-
io::ErrorKind::WriteZero,
237-
"failed to write frame to transport",
238-
)
239-
.into()));
254+
let mut written = 0;
255+
while written < len {
256+
match Pin::new(&mut self.io).poll_write(cx, &self.write_buf[written..]) {
257+
Poll::Pending => break,
258+
Poll::Ready(Ok(n)) => {
259+
if n == 0 {
260+
log::trace!("Disconnected during flush, written {}", written);
261+
self.flags.insert(Flags::DISCONNECTED);
262+
return Poll::Ready(Err(io::Error::new(
263+
io::ErrorKind::WriteZero,
264+
"failed to write frame to transport",
265+
)
266+
.into()));
267+
} else {
268+
written += n
269+
}
270+
}
271+
Poll::Ready(Err(e)) => {
272+
log::trace!("Error during flush: {}", e);
273+
self.flags.insert(Flags::DISCONNECTED);
274+
return Poll::Ready(Err(e.into()));
275+
}
240276
}
241-
242-
// remove written data
243-
self.write_buf.advance(n);
244277
}
245278

246-
// Try flushing the underlying IO
247-
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
248-
249-
log::trace!("framed transport flushed");
250-
Poll::Ready(Ok(()))
279+
// remove written data
280+
if written == len {
281+
// flushed same amount as in buffer, we dont need to reallocate
282+
unsafe { self.write_buf.set_len(0) }
283+
} else {
284+
self.write_buf.advance(written);
285+
}
286+
if self.write_buf.is_empty() {
287+
Poll::Ready(Ok(()))
288+
} else {
289+
Poll::Pending
290+
}
251291
}
292+
}
252293

294+
impl<T, U> Framed<T, U>
295+
where
296+
T: AsyncRead + AsyncWrite + Unpin,
297+
{
253298
#[inline]
254299
/// Flush write buffer and shutdown underlying I/O stream.
255-
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), U::Error>> {
256-
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
257-
ready!(Pin::new(&mut self.io).poll_shutdown(cx))?;
300+
///
301+
/// Close method shutdown write side of a io object and
302+
/// then reads until disconnect or error, high level code must use
303+
/// timeout for close operation.
304+
pub fn close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
305+
if !self.flags.contains(Flags::DISCONNECTED) {
306+
// flush write buffer
307+
ready!(Pin::new(&mut self.io).poll_flush(cx))?;
308+
309+
if !self.flags.contains(Flags::SHUTDOWN) {
310+
// shutdown WRITE side
311+
ready!(Pin::new(&mut self.io).poll_shutdown(cx)).map_err(|e| {
312+
self.flags.insert(Flags::DISCONNECTED);
313+
e
314+
})?;
315+
self.flags.insert(Flags::SHUTDOWN);
316+
}
317+
318+
// read until 0 or err
319+
let mut buf = [0u8; 512];
320+
loop {
321+
match ready!(Pin::new(&mut self.io).poll_read(cx, &mut buf)) {
322+
Err(_) | Ok(0) => {
323+
break;
324+
}
325+
_ => (),
326+
}
327+
}
328+
self.flags.insert(Flags::DISCONNECTED);
329+
}
258330
log::trace!("framed transport flushed and closed");
259331
Poll::Ready(Ok(()))
260332
}
@@ -269,11 +341,9 @@ where
269341
pub fn next_item(
270342
&mut self,
271343
cx: &mut Context<'_>,
272-
) -> Poll<Option<Result<U::Item, U::Error>>>
273-
where
274-
T: AsyncRead,
275-
U: Decoder,
276-
{
344+
) -> Poll<Option<Result<U::Item, U::Error>>> {
345+
let mut done_read = false;
346+
277347
loop {
278348
// Repeatedly call `decode` or `decode_eof` as long as it is
279349
// "readable". Readable is defined as not having returned `None`. If
@@ -302,34 +372,53 @@ where
302372
}
303373

304374
self.flags.remove(Flags::READABLE);
375+
if done_read {
376+
return Poll::Pending;
377+
}
305378
}
306379

307380
debug_assert!(!self.flags.contains(Flags::EOF));
308381

309-
// Otherwise, try to read more data and try again. Make sure we've got room
310-
let remaining = self.read_buf.capacity() - self.read_buf.len();
311-
if remaining < LW {
312-
self.read_buf.reserve(HW - remaining)
313-
}
314-
let cnt = match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf)
315-
{
316-
Poll::Pending => return Poll::Pending,
317-
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
318-
Poll::Ready(Ok(cnt)) => cnt,
319-
};
320-
321-
if cnt == 0 {
322-
self.flags.insert(Flags::EOF);
382+
// read all data from socket
383+
let mut updated = false;
384+
loop {
385+
// Otherwise, try to read more data and try again. Make sure we've got room
386+
let remaining = self.read_buf.capacity() - self.read_buf.len();
387+
if remaining < LW {
388+
self.read_buf.reserve(HW - remaining)
389+
}
390+
match Pin::new(&mut self.io).poll_read_buf(cx, &mut self.read_buf) {
391+
Poll::Pending => {
392+
if updated {
393+
done_read = true;
394+
self.flags.insert(Flags::READABLE);
395+
break;
396+
} else {
397+
return Poll::Pending;
398+
}
399+
}
400+
Poll::Ready(Ok(n)) => {
401+
if n == 0 {
402+
self.flags.insert(Flags::EOF | Flags::READABLE);
403+
if updated {
404+
done_read = true;
405+
}
406+
break;
407+
} else {
408+
updated = true;
409+
}
410+
}
411+
Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
412+
}
323413
}
324-
self.flags.insert(Flags::READABLE);
325414
}
326415
}
327416
}
328417

329418
impl<T, U> Stream for Framed<T, U>
330419
where
331420
T: AsyncRead + Unpin,
332-
U: Decoder,
421+
U: Decoder + Unpin,
333422
{
334423
type Item = Result<U::Item, U::Error>;
335424

@@ -344,8 +433,8 @@ where
344433

345434
impl<T, U> Sink<U::Item> for Framed<T, U>
346435
where
347-
T: AsyncWrite + Unpin,
348-
U: Encoder,
436+
T: AsyncRead + AsyncWrite + Unpin,
437+
U: Encoder + Unpin,
349438
U::Error: From<io::Error>,
350439
{
351440
type Error = U::Error;
@@ -383,7 +472,7 @@ where
383472
mut self: Pin<&mut Self>,
384473
cx: &mut Context<'_>,
385474
) -> Poll<Result<(), Self::Error>> {
386-
self.close(cx)
475+
self.close(cx).map_err(|e| e.into())
387476
}
388477
}
389478

@@ -443,3 +532,77 @@ impl<T, U> FramedParts<T, U> {
443532
}
444533
}
445534
}
535+
536+
#[cfg(test)]
537+
mod tests {
538+
use bytes::Bytes;
539+
use futures::future::lazy;
540+
use futures::Sink;
541+
use ntex::testing::Io;
542+
543+
use super::*;
544+
use crate::BytesCodec;
545+
546+
#[ntex::test]
547+
async fn test_sink() {
548+
let (client, server) = Io::create();
549+
client.remote_buffer_cap(1024);
550+
let mut server = Framed::new(server, BytesCodec);
551+
552+
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
553+
.await
554+
.is_ready());
555+
556+
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
557+
Pin::new(&mut server).start_send(data).unwrap();
558+
assert_eq!(client.read_any(), b"".as_ref());
559+
560+
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
561+
.await
562+
.is_ready());
563+
assert_eq!(client.read_any(), b"GET /test HTTP/1.1\r\n\r\n".as_ref());
564+
565+
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
566+
.await
567+
.is_pending());
568+
client.close().await;
569+
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
570+
.await
571+
.is_ready());
572+
assert!(client.is_closed());
573+
}
574+
575+
#[ntex::test]
576+
async fn test_write_pending() {
577+
let (client, server) = Io::create();
578+
let mut server = Framed::new(server, BytesCodec);
579+
580+
assert!(lazy(|cx| Pin::new(&mut server).poll_ready(cx))
581+
.await
582+
.is_ready());
583+
let data = Bytes::from_static(b"GET /test HTTP/1.1\r\n\r\n");
584+
Pin::new(&mut server).start_send(data).unwrap();
585+
586+
client.remote_buffer_cap(3);
587+
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
588+
.await
589+
.is_pending());
590+
assert_eq!(client.read_any(), b"GET".as_ref());
591+
592+
client.remote_buffer_cap(1024);
593+
assert!(lazy(|cx| Pin::new(&mut server).poll_flush(cx))
594+
.await
595+
.is_ready());
596+
assert_eq!(client.read_any(), b" /test HTTP/1.1\r\n\r\n".as_ref());
597+
598+
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
599+
.await
600+
.is_pending());
601+
client.close().await;
602+
assert!(lazy(|cx| Pin::new(&mut server).poll_close(cx))
603+
.await
604+
.is_ready());
605+
assert!(client.is_closed());
606+
assert!(server.is_closed());
607+
}
608+
}

0 commit comments

Comments
 (0)