@@ -13,13 +13,16 @@ const HW: usize = 8 * 1024;
1313
1414bitflags:: bitflags! {
1515 struct Flags : u8 {
16- const EOF = 0b0001 ;
17- const READABLE = 0b0010 ;
16+ const EOF = 0b0001 ;
17+ const READABLE = 0b0010 ;
18+ const DISCONNECTED = 0b0100 ;
19+ const SHUTDOWN = 0b1000 ;
1820 }
1921}
2022
2123/// A unified `Stream` and `Sink` interface to an underlying I/O object, using
2224/// the `Encoder` and `Decoder` traits to encode and decode frames.
25+ /// `Framed` is heavily optimized for streaming io.
2326pub struct Framed < T , U > {
2427 io : T ,
2528 codec : U ,
@@ -28,8 +31,6 @@ pub struct Framed<T, U> {
2831 write_buf : BytesMut ,
2932}
3033
31- impl < T , U > Unpin for Framed < T , U > { }
32-
3334impl < T , U > Framed < T , U >
3435where
3536 T : AsyncRead + AsyncWrite ,
@@ -123,6 +124,18 @@ impl<T, U> Framed<T, U> {
123124 & mut self . io
124125 }
125126
127+ #[ inline]
128+ /// Get read buffer.
129+ pub fn read_buf_mut ( & mut self ) -> & mut BytesMut {
130+ & mut self . read_buf
131+ }
132+
133+ #[ inline]
134+ /// Get write buffer.
135+ pub fn write_buf_mut ( & mut self ) -> & mut BytesMut {
136+ & mut self . write_buf
137+ }
138+
126139 #[ inline]
127140 /// Check if write buffer is empty.
128141 pub fn is_write_buf_empty ( & self ) -> bool {
@@ -135,6 +148,12 @@ impl<T, U> Framed<T, U> {
135148 self . write_buf . len ( ) >= HW
136149 }
137150
151+ #[ inline]
152+ /// Check if framed object is closed
153+ pub fn is_closed ( & self ) -> bool {
154+ self . flags . contains ( Flags :: DISCONNECTED )
155+ }
156+
138157 #[ inline]
139158 /// Consume the `Frame`, returning `Frame` with different codec.
140159 pub fn into_framed < U2 > ( self , codec : U2 ) -> Framed < T , U2 > {
@@ -227,34 +246,87 @@ where
227246 pub fn flush ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , U :: Error > > {
228247 log:: trace!( "flushing framed transport" ) ;
229248
230- while !self . write_buf . is_empty ( ) {
231- log:: trace!( "writing; remaining={}" , self . write_buf. len( ) ) ;
249+ let len = self . write_buf . len ( ) ;
250+ if len == 0 {
251+ return Poll :: Ready ( Ok ( ( ) ) ) ;
252+ }
232253
233- let n = ready ! ( Pin :: new( & mut self . io) . poll_write( cx, & self . write_buf) ) ?;
234- if n == 0 {
235- return Poll :: Ready ( Err ( io:: Error :: new (
236- io:: ErrorKind :: WriteZero ,
237- "failed to write frame to transport" ,
238- )
239- . into ( ) ) ) ;
254+ let mut written = 0 ;
255+ while written < len {
256+ match Pin :: new ( & mut self . io ) . poll_write ( cx, & self . write_buf [ written..] ) {
257+ Poll :: Pending => break ,
258+ Poll :: Ready ( Ok ( n) ) => {
259+ if n == 0 {
260+ log:: trace!( "Disconnected during flush, written {}" , written) ;
261+ self . flags . insert ( Flags :: DISCONNECTED ) ;
262+ return Poll :: Ready ( Err ( io:: Error :: new (
263+ io:: ErrorKind :: WriteZero ,
264+ "failed to write frame to transport" ,
265+ )
266+ . into ( ) ) ) ;
267+ } else {
268+ written += n
269+ }
270+ }
271+ Poll :: Ready ( Err ( e) ) => {
272+ log:: trace!( "Error during flush: {}" , e) ;
273+ self . flags . insert ( Flags :: DISCONNECTED ) ;
274+ return Poll :: Ready ( Err ( e. into ( ) ) ) ;
275+ }
240276 }
241-
242- // remove written data
243- self . write_buf . advance ( n) ;
244277 }
245278
246- // Try flushing the underlying IO
247- ready ! ( Pin :: new( & mut self . io) . poll_flush( cx) ) ?;
248-
249- log:: trace!( "framed transport flushed" ) ;
250- Poll :: Ready ( Ok ( ( ) ) )
279+ // remove written data
280+ if written == len {
281+ // flushed same amount as in buffer, we dont need to reallocate
282+ unsafe { self . write_buf . set_len ( 0 ) }
283+ } else {
284+ self . write_buf . advance ( written) ;
285+ }
286+ if self . write_buf . is_empty ( ) {
287+ Poll :: Ready ( Ok ( ( ) ) )
288+ } else {
289+ Poll :: Pending
290+ }
251291 }
292+ }
252293
294+ impl < T , U > Framed < T , U >
295+ where
296+ T : AsyncRead + AsyncWrite + Unpin ,
297+ {
253298 #[ inline]
254299 /// Flush write buffer and shutdown underlying I/O stream.
255- pub fn close ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , U :: Error > > {
256- ready ! ( Pin :: new( & mut self . io) . poll_flush( cx) ) ?;
257- ready ! ( Pin :: new( & mut self . io) . poll_shutdown( cx) ) ?;
300+ ///
301+ /// Close method shutdown write side of a io object and
302+ /// then reads until disconnect or error, high level code must use
303+ /// timeout for close operation.
304+ pub fn close ( & mut self , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
305+ if !self . flags . contains ( Flags :: DISCONNECTED ) {
306+ // flush write buffer
307+ ready ! ( Pin :: new( & mut self . io) . poll_flush( cx) ) ?;
308+
309+ if !self . flags . contains ( Flags :: SHUTDOWN ) {
310+ // shutdown WRITE side
311+ ready ! ( Pin :: new( & mut self . io) . poll_shutdown( cx) ) . map_err ( |e| {
312+ self . flags . insert ( Flags :: DISCONNECTED ) ;
313+ e
314+ } ) ?;
315+ self . flags . insert ( Flags :: SHUTDOWN ) ;
316+ }
317+
318+ // read until 0 or err
319+ let mut buf = [ 0u8 ; 512 ] ;
320+ loop {
321+ match ready ! ( Pin :: new( & mut self . io) . poll_read( cx, & mut buf) ) {
322+ Err ( _) | Ok ( 0 ) => {
323+ break ;
324+ }
325+ _ => ( ) ,
326+ }
327+ }
328+ self . flags . insert ( Flags :: DISCONNECTED ) ;
329+ }
258330 log:: trace!( "framed transport flushed and closed" ) ;
259331 Poll :: Ready ( Ok ( ( ) ) )
260332 }
@@ -269,11 +341,9 @@ where
269341 pub fn next_item (
270342 & mut self ,
271343 cx : & mut Context < ' _ > ,
272- ) -> Poll < Option < Result < U :: Item , U :: Error > > >
273- where
274- T : AsyncRead ,
275- U : Decoder ,
276- {
344+ ) -> Poll < Option < Result < U :: Item , U :: Error > > > {
345+ let mut done_read = false ;
346+
277347 loop {
278348 // Repeatedly call `decode` or `decode_eof` as long as it is
279349 // "readable". Readable is defined as not having returned `None`. If
@@ -302,34 +372,53 @@ where
302372 }
303373
304374 self . flags . remove ( Flags :: READABLE ) ;
375+ if done_read {
376+ return Poll :: Pending ;
377+ }
305378 }
306379
307380 debug_assert ! ( !self . flags. contains( Flags :: EOF ) ) ;
308381
309- // Otherwise, try to read more data and try again. Make sure we've got room
310- let remaining = self . read_buf . capacity ( ) - self . read_buf . len ( ) ;
311- if remaining < LW {
312- self . read_buf . reserve ( HW - remaining)
313- }
314- let cnt = match Pin :: new ( & mut self . io ) . poll_read_buf ( cx, & mut self . read_buf )
315- {
316- Poll :: Pending => return Poll :: Pending ,
317- Poll :: Ready ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e. into ( ) ) ) ) ,
318- Poll :: Ready ( Ok ( cnt) ) => cnt,
319- } ;
320-
321- if cnt == 0 {
322- self . flags . insert ( Flags :: EOF ) ;
382+ // read all data from socket
383+ let mut updated = false ;
384+ loop {
385+ // Otherwise, try to read more data and try again. Make sure we've got room
386+ let remaining = self . read_buf . capacity ( ) - self . read_buf . len ( ) ;
387+ if remaining < LW {
388+ self . read_buf . reserve ( HW - remaining)
389+ }
390+ match Pin :: new ( & mut self . io ) . poll_read_buf ( cx, & mut self . read_buf ) {
391+ Poll :: Pending => {
392+ if updated {
393+ done_read = true ;
394+ self . flags . insert ( Flags :: READABLE ) ;
395+ break ;
396+ } else {
397+ return Poll :: Pending ;
398+ }
399+ }
400+ Poll :: Ready ( Ok ( n) ) => {
401+ if n == 0 {
402+ self . flags . insert ( Flags :: EOF | Flags :: READABLE ) ;
403+ if updated {
404+ done_read = true ;
405+ }
406+ break ;
407+ } else {
408+ updated = true ;
409+ }
410+ }
411+ Poll :: Ready ( Err ( e) ) => return Poll :: Ready ( Some ( Err ( e. into ( ) ) ) ) ,
412+ }
323413 }
324- self . flags . insert ( Flags :: READABLE ) ;
325414 }
326415 }
327416}
328417
329418impl < T , U > Stream for Framed < T , U >
330419where
331420 T : AsyncRead + Unpin ,
332- U : Decoder ,
421+ U : Decoder + Unpin ,
333422{
334423 type Item = Result < U :: Item , U :: Error > ;
335424
@@ -344,8 +433,8 @@ where
344433
345434impl < T , U > Sink < U :: Item > for Framed < T , U >
346435where
347- T : AsyncWrite + Unpin ,
348- U : Encoder ,
436+ T : AsyncRead + AsyncWrite + Unpin ,
437+ U : Encoder + Unpin ,
349438 U :: Error : From < io:: Error > ,
350439{
351440 type Error = U :: Error ;
@@ -383,7 +472,7 @@ where
383472 mut self : Pin < & mut Self > ,
384473 cx : & mut Context < ' _ > ,
385474 ) -> Poll < Result < ( ) , Self :: Error > > {
386- self . close ( cx)
475+ self . close ( cx) . map_err ( |e| e . into ( ) )
387476 }
388477}
389478
@@ -443,3 +532,77 @@ impl<T, U> FramedParts<T, U> {
443532 }
444533 }
445534}
535+
536+ #[ cfg( test) ]
537+ mod tests {
538+ use bytes:: Bytes ;
539+ use futures:: future:: lazy;
540+ use futures:: Sink ;
541+ use ntex:: testing:: Io ;
542+
543+ use super :: * ;
544+ use crate :: BytesCodec ;
545+
546+ #[ ntex:: test]
547+ async fn test_sink ( ) {
548+ let ( client, server) = Io :: create ( ) ;
549+ client. remote_buffer_cap ( 1024 ) ;
550+ let mut server = Framed :: new ( server, BytesCodec ) ;
551+
552+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_ready( cx) )
553+ . await
554+ . is_ready( ) ) ;
555+
556+ let data = Bytes :: from_static ( b"GET /test HTTP/1.1\r \n \r \n " ) ;
557+ Pin :: new ( & mut server) . start_send ( data) . unwrap ( ) ;
558+ assert_eq ! ( client. read_any( ) , b"" . as_ref( ) ) ;
559+
560+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_flush( cx) )
561+ . await
562+ . is_ready( ) ) ;
563+ assert_eq ! ( client. read_any( ) , b"GET /test HTTP/1.1\r \n \r \n " . as_ref( ) ) ;
564+
565+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_close( cx) )
566+ . await
567+ . is_pending( ) ) ;
568+ client. close ( ) . await ;
569+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_close( cx) )
570+ . await
571+ . is_ready( ) ) ;
572+ assert ! ( client. is_closed( ) ) ;
573+ }
574+
575+ #[ ntex:: test]
576+ async fn test_write_pending ( ) {
577+ let ( client, server) = Io :: create ( ) ;
578+ let mut server = Framed :: new ( server, BytesCodec ) ;
579+
580+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_ready( cx) )
581+ . await
582+ . is_ready( ) ) ;
583+ let data = Bytes :: from_static ( b"GET /test HTTP/1.1\r \n \r \n " ) ;
584+ Pin :: new ( & mut server) . start_send ( data) . unwrap ( ) ;
585+
586+ client. remote_buffer_cap ( 3 ) ;
587+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_flush( cx) )
588+ . await
589+ . is_pending( ) ) ;
590+ assert_eq ! ( client. read_any( ) , b"GET" . as_ref( ) ) ;
591+
592+ client. remote_buffer_cap ( 1024 ) ;
593+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_flush( cx) )
594+ . await
595+ . is_ready( ) ) ;
596+ assert_eq ! ( client. read_any( ) , b" /test HTTP/1.1\r \n \r \n " . as_ref( ) ) ;
597+
598+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_close( cx) )
599+ . await
600+ . is_pending( ) ) ;
601+ client. close ( ) . await ;
602+ assert ! ( lazy( |cx| Pin :: new( & mut server) . poll_close( cx) )
603+ . await
604+ . is_ready( ) ) ;
605+ assert ! ( client. is_closed( ) ) ;
606+ assert ! ( server. is_closed( ) ) ;
607+ }
608+ }
0 commit comments