1
1
use std:: io:: { Error , ErrorKind , IoSlice , Result } ;
2
2
use std:: pin:: Pin ;
3
3
use std:: ptr;
4
- use std:: sync:: Arc ;
5
4
use std:: task:: { Context , Poll , RawWaker , RawWakerVTable , Waker } ;
6
5
use std:: time:: Duration ;
7
6
8
7
use bytes:: buf:: BufMut ;
9
8
use ignore_result:: Ignore ;
10
- use rustls:: pki_types:: ServerName ;
11
- use rustls:: ClientConfig ;
12
9
use tokio:: io:: { AsyncBufReadExt , AsyncRead , AsyncWrite , AsyncWriteExt , BufStream , ReadBuf } ;
13
10
use tokio:: net:: TcpStream ;
14
11
use tokio:: { select, time} ;
15
- use tokio_rustls:: client:: TlsStream ;
16
- use tokio_rustls:: TlsConnector ;
17
12
use tracing:: { debug, trace} ;
18
13
14
+ #[ cfg( feature = "tls" ) ]
15
+ mod tls {
16
+ pub use std:: sync:: Arc ;
17
+
18
+ pub use rustls:: pki_types:: ServerName ;
19
+ pub use rustls:: ClientConfig ;
20
+ pub use tokio_rustls:: client:: TlsStream ;
21
+ pub use tokio_rustls:: TlsConnector ;
22
+ }
23
+ #[ cfg( feature = "tls" ) ]
24
+ use tls:: * ;
25
+
19
26
use crate :: deadline:: Deadline ;
20
27
use crate :: endpoint:: { EndpointRef , IterableEndpoints } ;
21
28
22
29
const NOOP_VTABLE : RawWakerVTable =
23
30
RawWakerVTable :: new ( |_| RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) , |_| { } , |_| { } , |_| { } ) ;
24
31
const NOOP_WAKER : RawWaker = RawWaker :: new ( ptr:: null ( ) , & NOOP_VTABLE ) ;
25
32
33
+ #[ derive( Debug ) ]
26
34
pub enum Connection {
27
- Tls ( TlsStream < TcpStream > ) ,
28
35
Raw ( TcpStream ) ,
36
+ #[ cfg( feature = "tls" ) ]
37
+ Tls ( TlsStream < TcpStream > ) ,
29
38
}
30
39
31
40
impl AsyncRead for Connection {
32
41
fn poll_read ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & mut ReadBuf < ' _ > ) -> Poll < Result < ( ) > > {
33
42
match self . get_mut ( ) {
34
43
Self :: Raw ( stream) => Pin :: new ( stream) . poll_read ( cx, buf) ,
44
+ #[ cfg( feature = "tls" ) ]
35
45
Self :: Tls ( stream) => Pin :: new ( stream) . poll_read ( cx, buf) ,
36
46
}
37
47
}
@@ -41,20 +51,23 @@ impl AsyncWrite for Connection {
41
51
fn poll_write ( self : Pin < & mut Self > , cx : & mut Context < ' _ > , buf : & [ u8 ] ) -> Poll < Result < usize > > {
42
52
match self . get_mut ( ) {
43
53
Self :: Raw ( stream) => Pin :: new ( stream) . poll_write ( cx, buf) ,
54
+ #[ cfg( feature = "tls" ) ]
44
55
Self :: Tls ( stream) => Pin :: new ( stream) . poll_write ( cx, buf) ,
45
56
}
46
57
}
47
58
48
59
fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
49
60
match self . get_mut ( ) {
50
61
Self :: Raw ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
62
+ #[ cfg( feature = "tls" ) ]
51
63
Self :: Tls ( stream) => Pin :: new ( stream) . poll_flush ( cx) ,
52
64
}
53
65
}
54
66
55
67
fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) > > {
56
68
match self . get_mut ( ) {
57
69
Self :: Raw ( stream) => Pin :: new ( stream) . poll_shutdown ( cx) ,
70
+ #[ cfg( feature = "tls" ) ]
58
71
Self :: Tls ( stream) => Pin :: new ( stream) . poll_shutdown ( cx) ,
59
72
}
60
73
}
@@ -65,6 +78,7 @@ impl Connection {
65
78
Self :: Raw ( stream)
66
79
}
67
80
81
+ #[ cfg( feature = "tls" ) ]
68
82
pub fn new_tls ( stream : TlsStream < TcpStream > ) -> Self {
69
83
Self :: Tls ( stream)
70
84
}
@@ -97,6 +111,7 @@ impl Connection {
97
111
pub async fn readable ( & self ) -> Result < ( ) > {
98
112
match self {
99
113
Self :: Raw ( stream) => stream. readable ( ) . await ,
114
+ #[ cfg( feature = "tls" ) ]
100
115
Self :: Tls ( stream) => {
101
116
let ( stream, session) = stream. get_ref ( ) ;
102
117
if session. wants_read ( ) {
@@ -112,6 +127,7 @@ impl Connection {
112
127
pub async fn writable ( & self ) -> Result < ( ) > {
113
128
match self {
114
129
Self :: Raw ( stream) => stream. writable ( ) . await ,
130
+ #[ cfg( feature = "tls" ) ]
115
131
Self :: Tls ( stream) => {
116
132
let ( stream, _session) = stream. get_ref ( ) ;
117
133
stream. writable ( ) . await
@@ -122,6 +138,7 @@ impl Connection {
122
138
pub fn wants_write ( & self ) -> bool {
123
139
match self {
124
140
Self :: Raw ( _) => false ,
141
+ #[ cfg( feature = "tls" ) ]
125
142
Self :: Tls ( stream) => {
126
143
let ( _stream, session) = stream. get_ref ( ) ;
127
144
session. wants_write ( )
@@ -160,13 +177,33 @@ impl Connection {
160
177
161
178
#[ derive( Clone ) ]
162
179
pub struct Connector {
163
- tls : TlsConnector ,
180
+ #[ cfg( feature = "tls" ) ]
181
+ tls : Option < TlsConnector > ,
164
182
timeout : Duration ,
165
183
}
166
184
167
185
impl Connector {
168
- pub fn new ( config : impl Into < Arc < ClientConfig > > ) -> Self {
169
- Self { tls : TlsConnector :: from ( config. into ( ) ) , timeout : Duration :: from_secs ( 10 ) }
186
+ #[ cfg( feature = "tls" ) ]
187
+ #[ allow( dead_code) ]
188
+ pub fn new ( ) -> Self {
189
+ Self { tls : None , timeout : Duration :: from_secs ( 10 ) }
190
+ }
191
+
192
+ #[ cfg( not( feature = "tls" ) ) ]
193
+ pub fn new ( ) -> Self {
194
+ Self { timeout : Duration :: from_secs ( 10 ) }
195
+ }
196
+
197
+ #[ cfg( feature = "tls" ) ]
198
+ pub fn with_tls ( config : ClientConfig ) -> Self {
199
+ Self { tls : Some ( TlsConnector :: from ( Arc :: new ( config) ) ) , timeout : Duration :: from_secs ( 10 ) }
200
+ }
201
+
202
+ #[ cfg( feature = "tls" ) ]
203
+ async fn connect_tls ( & self , stream : TcpStream , host : & str ) -> Result < Connection > {
204
+ let domain = ServerName :: try_from ( host) . unwrap ( ) . to_owned ( ) ;
205
+ let stream = self . tls . as_ref ( ) . unwrap ( ) . connect ( domain, stream) . await ?;
206
+ Ok ( Connection :: new_tls ( stream) )
170
207
}
171
208
172
209
pub fn timeout ( & self ) -> Duration {
@@ -178,6 +215,14 @@ impl Connector {
178
215
}
179
216
180
217
pub async fn connect ( & self , endpoint : EndpointRef < ' _ > , deadline : & mut Deadline ) -> Result < Connection > {
218
+ if endpoint. tls {
219
+ #[ cfg( feature = "tls" ) ]
220
+ if self . tls . is_none ( ) {
221
+ return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not supported" ) ) ;
222
+ }
223
+ #[ cfg( not( feature = "tls" ) ) ]
224
+ return Err ( Error :: new ( ErrorKind :: Unsupported , "tls not supported" ) ) ;
225
+ }
181
226
select ! {
182
227
_ = unsafe { Pin :: new_unchecked( deadline) } => Err ( Error :: new( ErrorKind :: TimedOut , "deadline exceed" ) ) ,
183
228
_ = time:: sleep( self . timeout) => Err ( Error :: new( ErrorKind :: TimedOut , format!( "connection timeout{:?} exceed" , self . timeout) ) ) ,
@@ -186,9 +231,10 @@ impl Connector {
186
231
Err ( err) => Err ( err) ,
187
232
Ok ( sock) => {
188
233
let connection = if endpoint. tls {
189
- let domain = ServerName :: try_from( endpoint. host) . unwrap( ) . to_owned( ) ;
190
- let stream = self . tls. connect( domain, sock) . await ?;
191
- Connection :: new_tls( stream)
234
+ #[ cfg( not( feature = "tls" ) ) ]
235
+ unreachable!( "tls not supported" ) ;
236
+ #[ cfg( feature = "tls" ) ]
237
+ self . connect_tls( sock, endpoint. host) . await ?
192
238
} else {
193
239
Connection :: new_raw( sock)
194
240
} ;
@@ -231,3 +277,20 @@ impl Connector {
231
277
None
232
278
}
233
279
}
280
+
281
+ #[ cfg( test) ]
282
+ mod tests {
283
+ use std:: io:: ErrorKind ;
284
+
285
+ use super :: Connector ;
286
+ use crate :: deadline:: Deadline ;
287
+ use crate :: endpoint:: EndpointRef ;
288
+
289
+ #[ tokio:: test]
290
+ async fn raw ( ) {
291
+ let connector = Connector :: new ( ) ;
292
+ let endpoint = EndpointRef :: new ( "host1" , 2181 , true ) ;
293
+ let err = connector. connect ( endpoint, & mut Deadline :: never ( ) ) . await . unwrap_err ( ) ;
294
+ assert_eq ! ( err. kind( ) , ErrorKind :: Unsupported ) ;
295
+ }
296
+ }
0 commit comments