11//! An implementation of SSL streams for ntex backed by OpenSSL
2- use std:: io:: { self , Read as IoRead , Write as IoWrite } ;
3- use std:: { any, cell:: RefCell , future:: poll_fn, sync:: Arc , task:: ready, task:: Poll } ;
2+ use std:: { any, cell:: RefCell , io, sync:: Arc } ;
43
5- use ntex_bytes:: BufMut ;
6- use ntex_io:: { types, Filter , FilterLayer , Io , Layer , ReadBuf , WriteBuf } ;
4+ use ntex_io:: { Filter , FilterLayer , Io , Layer , ReadBuf , WriteBuf } ;
75use tls_rust:: { pki_types:: ServerName , ClientConfig , ClientConnection } ;
86
9- use super :: { PeerCert , PeerCertChain , Wrapper } ;
7+ use super :: Stream ;
108
119#[ derive( Debug ) ]
1210/// An implementation of SSL streams
@@ -16,110 +14,15 @@ pub struct TlsClientFilter {
1614
1715impl FilterLayer for TlsClientFilter {
1816 fn query ( & self , id : any:: TypeId ) -> Option < Box < dyn any:: Any > > {
19- const H2 : & [ u8 ] = b"h2" ;
20-
21- if id == any:: TypeId :: of :: < types:: HttpProtocol > ( ) {
22- let h2 = self
23- . session
24- . borrow ( )
25- . alpn_protocol ( )
26- . map ( |protos| protos. windows ( 2 ) . any ( |w| w == H2 ) )
27- . unwrap_or ( false ) ;
28-
29- let proto = if h2 {
30- types:: HttpProtocol :: Http2
31- } else {
32- types:: HttpProtocol :: Http1
33- } ;
34- Some ( Box :: new ( proto) )
35- } else if id == any:: TypeId :: of :: < PeerCert < ' _ > > ( ) {
36- if let Some ( cert_chain) = self . session . borrow ( ) . peer_certificates ( ) {
37- if let Some ( cert) = cert_chain. first ( ) {
38- Some ( Box :: new ( PeerCert ( cert. to_owned ( ) ) ) )
39- } else {
40- None
41- }
42- } else {
43- None
44- }
45- } else if id == any:: TypeId :: of :: < PeerCertChain < ' _ > > ( ) {
46- if let Some ( cert_chain) = self . session . borrow ( ) . peer_certificates ( ) {
47- Some ( Box :: new ( PeerCertChain ( cert_chain. to_vec ( ) ) ) )
48- } else {
49- None
50- }
51- } else {
52- None
53- }
17+ Stream :: new ( & mut * self . session . borrow_mut ( ) ) . query ( id)
5418 }
5519
5620 fn process_read_buf ( & self , buf : & ReadBuf < ' _ > ) -> io:: Result < usize > {
57- let mut session = self . session . borrow_mut ( ) ;
58- let mut new_bytes = 0 ;
59-
60- // get processed buffer
61- buf. with_src ( |src| {
62- if let Some ( src) = src {
63- buf. with_dst ( |dst| {
64- loop {
65- let mut cursor = io:: Cursor :: new ( & src) ;
66- let n = match session. read_tls ( & mut cursor) {
67- Ok ( n) => n,
68- Err ( ref err) if err. kind ( ) == io:: ErrorKind :: WouldBlock => {
69- break
70- }
71- Err ( err) => return Err ( err) ,
72- } ;
73- src. split_to ( n) ;
74- let state = session
75- . process_new_packets ( )
76- . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
77-
78- let new_b = state. plaintext_bytes_to_read ( ) ;
79- if new_b > 0 {
80- dst. reserve ( new_b) ;
81- let chunk: & mut [ u8 ] =
82- unsafe { std:: mem:: transmute ( & mut * dst. chunk_mut ( ) ) } ;
83- let v = session. reader ( ) . read ( chunk) ?;
84- unsafe { dst. advance_mut ( v) } ;
85- new_bytes += v;
86- } else {
87- break ;
88- }
89- }
90- Ok :: < _ , io:: Error > ( ( ) )
91- } ) ?;
92- }
93- Ok ( new_bytes)
94- } )
21+ Stream :: new ( & mut * self . session . borrow_mut ( ) ) . process_read_buf ( buf)
9522 }
9623
9724 fn process_write_buf ( & self , buf : & WriteBuf < ' _ > ) -> io:: Result < ( ) > {
98- buf. with_src ( |src| {
99- if let Some ( src) = src {
100- let mut io = Wrapper ( buf) ;
101- let mut session = self . session . borrow_mut ( ) ;
102-
103- ' outer: loop {
104- if !src. is_empty ( ) {
105- src. split_to ( session. writer ( ) . write ( src) ?) ;
106-
107- loop {
108- match session. write_tls ( & mut io) {
109- Ok ( 0 ) => continue ' outer,
110- Ok ( _) => continue ,
111- Err ( ref err) if err. kind ( ) == io:: ErrorKind :: WouldBlock => {
112- break
113- }
114- Err ( err) => return Err ( err) ,
115- }
116- }
117- }
118- break ;
119- }
120- }
121- Ok ( ( ) )
122- } )
25+ Stream :: new ( & mut * self . session . borrow_mut ( ) ) . process_write_buf ( buf)
12326 }
12427}
12528
@@ -130,76 +33,10 @@ impl TlsClientFilter {
13033 domain : ServerName < ' static > ,
13134 ) -> Result < Io < Layer < TlsClientFilter , F > > , io:: Error > {
13235 let session = ClientConnection :: new ( cfg, domain) . map_err ( io:: Error :: other) ?;
133- let filter = TlsClientFilter {
36+ let io = io . add_filter ( TlsClientFilter {
13437 session : RefCell :: new ( session) ,
135- } ;
136- let io = io. add_filter ( filter) ;
137-
138- let filter = io. filter ( ) ;
139- loop {
140- let ( result, handshaking) = io. with_buf ( |buf| {
141- let mut wrp = Wrapper ( buf) ;
142- let mut session = filter. session . borrow_mut ( ) ;
143- let mut result = Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "" ) ) ;
144-
145- while session. wants_write ( ) {
146- result = session. write_tls ( & mut wrp) . map ( |_| ( ) ) ;
147- if result. is_err ( ) {
148- break ;
149- }
150- }
151- if session. wants_read ( ) {
152- let has_data = buf. with_read_buf ( |rbuf| {
153- rbuf. with_src ( |b| {
154- b. as_ref ( ) . map ( |b| !b. is_empty ( ) ) . unwrap_or_default ( )
155- } )
156- } ) ;
157-
158- if has_data {
159- result = match session. read_tls ( & mut wrp) {
160- Ok ( 0 ) => Err ( io:: Error :: new (
161- io:: ErrorKind :: NotConnected ,
162- "disconnected" ,
163- ) ) ,
164- Ok ( _) => Ok ( ( ) ) ,
165- Err ( e) => Err ( e) ,
166- } ;
167-
168- session. process_new_packets ( ) . map_err ( |err| {
169- // In case we have an alert to send describing this error,
170- // try a last-gasp write -- but don't predate the primary
171- // error.
172- let _ = session. write_tls ( & mut wrp) ;
173- io:: Error :: new ( io:: ErrorKind :: InvalidData , err)
174- } ) ?;
175- } else {
176- result = Err ( io:: Error :: new ( io:: ErrorKind :: WouldBlock , "" ) ) ;
177- }
178- }
179-
180- Ok :: < _ , io:: Error > ( ( result, session. is_handshaking ( ) ) )
181- } ) ??;
182-
183- match result {
184- Ok ( ( ) ) => return Ok ( io) ,
185- Err ( ref e) if e. kind ( ) == io:: ErrorKind :: WouldBlock => {
186- if !handshaking {
187- return Ok ( io) ;
188- }
189- poll_fn ( |cx| {
190- match ready ! ( io. poll_read_notify( cx) ) ? {
191- Some ( _) => Ok ( ( ) ) ,
192- None => Err ( io:: Error :: new (
193- io:: ErrorKind :: NotConnected ,
194- "disconnected" ,
195- ) ) ,
196- } ?;
197- Poll :: Ready ( Ok :: < _ , io:: Error > ( ( ) ) )
198- } )
199- . await ?;
200- }
201- Err ( e) => return Err ( e) ,
202- }
203- }
38+ } ) ;
39+ super :: stream:: handshake ( & io. filter ( ) . session , & io) . await ?;
40+ Ok ( io)
20441 }
20542}
0 commit comments