1
1
use std:: io:: { Error , ErrorKind , IoSlice , Result } ;
2
2
use std:: pin:: Pin ;
3
- use std:: ptr;
4
- use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
3
+ use std:: task:: { Context , Poll } ;
5
4
use std:: time:: Duration ;
6
5
7
6
use bytes:: buf:: BufMut ;
8
7
use ignore_result:: Ignore ;
9
- use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
8
+ use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
10
9
use tokio:: net:: TcpStream ;
11
10
use tokio:: { select, time} ;
12
11
use tracing:: { debug, trace} ;
@@ -26,17 +25,31 @@ use tls::*;
26
25
use crate :: deadline:: Deadline ;
27
26
use crate :: endpoint:: { EndpointRef , IterableEndpoints } ;
28
27
29
- const NOOP_VTABLE : RawWakerVTable =
30
- RawWakerVTable :: new ( |_| RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) , |_| { } , |_| { } , |_| { } ) ;
31
- const NOOP_WAKER : RawWaker = RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) ;
32
-
33
28
#[ derive( Debug ) ]
34
29
pub enum Connection {
35
30
Raw ( TcpStream ) ,
36
31
#[ cfg( feature = "tls" ) ]
37
32
Tls ( TlsStream < TcpStream > ) ,
38
33
}
39
34
35
+ pub trait AsyncReadToBuf : AsyncReadExt {
36
+ async fn read_to_buf ( & mut self , buf : & mut impl BufMut ) -> Result < usize >
37
+ where
38
+ Self : Unpin , {
39
+ let chunk = buf. chunk_mut ( ) ;
40
+ let read_to = unsafe { std:: mem:: transmute ( chunk. as_uninit_slice_mut ( ) ) } ;
41
+ let n = self . read ( read_to) . await ?;
42
+ if n != 0 {
43
+ unsafe {
44
+ buf. advance_mut ( n) ;
45
+ }
46
+ }
47
+ Ok ( n)
48
+ }
49
+ }
50
+
51
+ impl < T > AsyncReadToBuf for T where T : AsyncReadExt { }
52
+
40
53
impl AsyncRead for Connection {
41
54
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
42
55
match self . get_mut ( ) {
@@ -56,6 +69,14 @@ impl AsyncWrite for Connection {
56
69
}
57
70
}
58
71
72
+ fn poll_write_vectored ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , bufs : & [ IoSlice < ' _ > ] ) -> Poll < Result < usize > > {
73
+ match self . get_mut ( ) {
74
+ Self :: Raw ( stream) => Pin :: new ( stream) . poll_write_vectored ( cx, bufs) ,
75
+ #[ cfg( feature = "tls" ) ]
76
+ Self :: Tls ( stream) => Pin :: new ( stream) . poll_write_vectored ( cx, bufs) ,
77
+ }
78
+ }
79
+
59
80
fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
60
81
match self . get_mut ( ) {
61
82
Self :: Raw ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
@@ -73,86 +94,52 @@ impl AsyncWrite for Connection {
73
94
}
74
95
}
75
96
76
- impl Connection {
77
- pub fn new_raw ( stream : TcpStream ) -> Self {
78
- Self :: Raw ( stream)
97
+ pub struct ConnReader < ' a > {
98
+ conn : & ' a mut Connection ,
99
+ }
100
+
101
+ impl AsyncRead for ConnReader < ' _ > {
102
+ fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
103
+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_read ( cx, buf)
79
104
}
105
+ }
80
106
81
- #[ cfg( feature = "tls" ) ]
82
- pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
83
- Self :: Tls ( stream)
107
+ pub struct ConnWriter < ' a > {
108
+ conn : & ' a mut Connection ,
109
+ }
110
+
111
+ impl AsyncWrite for ConnWriter < ' _ > {
112
+ fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize > > {
113
+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_write ( cx, buf)
84
114
}
85
115
86
- pub fn try_write_vectored ( & mut self , bufs : & [ IoSlice < ' _ > ] ) -> Result < usize > {
87
- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
88
- let mut context = Context :: from_waker ( & waker) ;
89
- match Pin :: new ( self ) . poll_write_vectored ( & mut context, bufs) {
90
- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
91
- Poll :: Ready ( result) => result,
92
- }
116
+ fn poll_write_vectored ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , bufs : & [ IoSlice < ' _ > ] ) -> Poll < Result < usize > > {
117
+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_write_vectored ( cx, bufs)
93
118
}
94
119
95
- pub fn try_read_buf ( & mut self , buf : & mut impl BufMut ) -> Result < usize > {
96
- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
97
- let mut context = Context :: from_waker ( & waker) ;
98
- let chunk = buf. chunk_mut ( ) ;
99
- let mut read_buf = unsafe { ReadBuf :: uninit ( chunk. as_uninit_slice_mut ( ) ) } ;
100
- match Pin :: new ( self ) . poll_read ( & mut context, & mut read_buf) {
101
- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
102
- Poll :: Ready ( Err ( err) ) => Err ( err) ,
103
- Poll :: Ready ( Ok ( ( ) ) ) => {
104
- let n = read_buf. filled ( ) . len ( ) ;
105
- unsafe { buf. advance_mut ( n) } ;
106
- Ok ( n)
107
- } ,
108
- }
120
+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
121
+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_flush ( cx)
109
122
}
110
123
111
- pub async fn readable ( & self ) -> Result < ( ) > {
112
- match self {
113
- Self :: Raw ( stream) => stream. readable ( ) . await ,
114
- #[ cfg( feature = "tls" ) ]
115
- Self :: Tls ( stream) => {
116
- let ( stream, session) = stream. get_ref ( ) ;
117
- if session. wants_read ( ) {
118
- stream. readable ( ) . await
119
- } else {
120
- // plaintext data are available for read
121
- std:: future:: ready ( Ok ( ( ) ) ) . await
122
- }
123
- } ,
124
- }
124
+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
125
+ Pin :: new ( & mut self . get_mut ( ) . conn ) . poll_shutdown ( cx)
125
126
}
127
+ }
126
128
127
- pub async fn writable ( & self ) -> Result < ( ) > {
128
- match self {
129
- Self :: Raw ( stream) => stream. writable ( ) . await ,
130
- #[ cfg( feature = "tls" ) ]
131
- Self :: Tls ( stream) => {
132
- let ( stream, _session) = stream. get_ref ( ) ;
133
- stream. writable ( ) . await
134
- } ,
135
- }
129
+ impl Connection {
130
+ pub fn new_raw ( stream : TcpStream ) -> Self {
131
+ Self :: Raw ( stream)
136
132
}
137
133
138
- pub fn wants_write ( & self ) -> bool {
139
- match self {
140
- Self :: Raw ( _) => false ,
141
- #[ cfg( feature = "tls" ) ]
142
- Self :: Tls ( stream) => {
143
- let ( _stream, session) = stream. get_ref ( ) ;
144
- session. wants_write ( )
145
- } ,
146
- }
134
+ pub fn split ( & mut self ) -> ( ConnReader < ' _ > , ConnWriter < ' _ > ) {
135
+ let reader = ConnReader { conn : self } ;
136
+ let writer = ConnWriter { conn : unsafe { std:: ptr:: read ( & reader. conn ) } } ;
137
+ ( reader, writer)
147
138
}
148
139
149
- pub fn try_flush ( & mut self ) -> Result < ( ) > {
150
- let waker = unsafe { Waker :: from_raw ( NOOP_WAKER ) } ;
151
- let mut context = Context :: from_waker ( & waker) ;
152
- match Pin :: new ( self ) . poll_flush ( & mut context) {
153
- Poll :: Pending => Err ( ErrorKind :: WouldBlock . into ( ) ) ,
154
- Poll :: Ready ( result) => result,
155
- }
140
+ #[ cfg( feature = "tls" ) ]
141
+ pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
142
+ Self :: Tls ( stream)
156
143
}
157
144
158
145
pub async fn command ( self , cmd : & str ) -> Result < String > {
0 commit comments