Skip to content

Commit 2a41619

Browse files
authored
Refactor rustls support (#621)
1 parent cf187b5 commit 2a41619

File tree

6 files changed

+265
-379
lines changed

6 files changed

+265
-379
lines changed

ntex-tls/CHANGES.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changes
22

3+
## [2.6.0] - 2025-07-29
4+
5+
* Fix the rustls handshake negotiation #620
6+
7+
* Refactor rustls support
8+
39
## [2.5.1] - 2025-07-28
410

511
* Fix rustls filter impls

ntex-tls/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "ntex-tls"
3-
version = "2.5.1"
3+
version = "2.6.0"
44
authors = ["ntex contributors <team@ntex.rs>"]
55
description = "An implementation of SSL streams for ntex backed by OpenSSL"
66
keywords = ["network", "framework", "async", "futures"]
@@ -27,7 +27,7 @@ rustls-ring = ["tls_rust", "tls_rust/ring", "tls_rust/std"]
2727

2828
[dependencies]
2929
ntex-bytes = "0.1"
30-
ntex-io = "2.13"
30+
ntex-io = "2.14"
3131
ntex-util = "2.5"
3232
ntex-service = "3.5"
3333
ntex-net = "2"

ntex-tls/src/rustls/client.rs

Lines changed: 10 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
//! An implementation of SSL streams for ntex backed by OpenSSL
2-
use std::io::{self, Read as IoRead, Write as IoWrite};
3-
use std::{any, cell::RefCell, future::poll_fn, sync::Arc, task::ready, task::Poll};
2+
use std::{any, cell::RefCell, io, sync::Arc};
43

5-
use ntex_bytes::BufMut;
6-
use ntex_io::{types, Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
4+
use ntex_io::{Filter, FilterLayer, Io, Layer, ReadBuf, WriteBuf};
75
use tls_rust::{pki_types::ServerName, ClientConfig, ClientConnection};
86

9-
use super::{PeerCert, PeerCertChain, Wrapper};
7+
use super::Stream;
108

119
#[derive(Debug)]
1210
/// An implementation of SSL streams
@@ -16,110 +14,15 @@ pub struct TlsClientFilter {
1614

1715
impl FilterLayer for TlsClientFilter {
1816
fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
19-
const H2: &[u8] = b"h2";
20-
21-
if id == any::TypeId::of::<types::HttpProtocol>() {
22-
let h2 = self
23-
.session
24-
.borrow()
25-
.alpn_protocol()
26-
.map(|protos| protos.windows(2).any(|w| w == H2))
27-
.unwrap_or(false);
28-
29-
let proto = if h2 {
30-
types::HttpProtocol::Http2
31-
} else {
32-
types::HttpProtocol::Http1
33-
};
34-
Some(Box::new(proto))
35-
} else if id == any::TypeId::of::<PeerCert<'_>>() {
36-
if let Some(cert_chain) = self.session.borrow().peer_certificates() {
37-
if let Some(cert) = cert_chain.first() {
38-
Some(Box::new(PeerCert(cert.to_owned())))
39-
} else {
40-
None
41-
}
42-
} else {
43-
None
44-
}
45-
} else if id == any::TypeId::of::<PeerCertChain<'_>>() {
46-
if let Some(cert_chain) = self.session.borrow().peer_certificates() {
47-
Some(Box::new(PeerCertChain(cert_chain.to_vec())))
48-
} else {
49-
None
50-
}
51-
} else {
52-
None
53-
}
17+
Stream::new(&mut *self.session.borrow_mut()).query(id)
5418
}
5519

5620
fn process_read_buf(&self, buf: &ReadBuf<'_>) -> io::Result<usize> {
57-
let mut session = self.session.borrow_mut();
58-
let mut new_bytes = 0;
59-
60-
// get processed buffer
61-
buf.with_src(|src| {
62-
if let Some(src) = src {
63-
buf.with_dst(|dst| {
64-
loop {
65-
let mut cursor = io::Cursor::new(&src);
66-
let n = match session.read_tls(&mut cursor) {
67-
Ok(n) => n,
68-
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
69-
break
70-
}
71-
Err(err) => return Err(err),
72-
};
73-
src.split_to(n);
74-
let state = session
75-
.process_new_packets()
76-
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
77-
78-
let new_b = state.plaintext_bytes_to_read();
79-
if new_b > 0 {
80-
dst.reserve(new_b);
81-
let chunk: &mut [u8] =
82-
unsafe { std::mem::transmute(&mut *dst.chunk_mut()) };
83-
let v = session.reader().read(chunk)?;
84-
unsafe { dst.advance_mut(v) };
85-
new_bytes += v;
86-
} else {
87-
break;
88-
}
89-
}
90-
Ok::<_, io::Error>(())
91-
})?;
92-
}
93-
Ok(new_bytes)
94-
})
21+
Stream::new(&mut *self.session.borrow_mut()).process_read_buf(buf)
9522
}
9623

9724
fn process_write_buf(&self, buf: &WriteBuf<'_>) -> io::Result<()> {
98-
buf.with_src(|src| {
99-
if let Some(src) = src {
100-
let mut io = Wrapper(buf);
101-
let mut session = self.session.borrow_mut();
102-
103-
'outer: loop {
104-
if !src.is_empty() {
105-
src.split_to(session.writer().write(src)?);
106-
107-
loop {
108-
match session.write_tls(&mut io) {
109-
Ok(0) => continue 'outer,
110-
Ok(_) => continue,
111-
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
112-
break
113-
}
114-
Err(err) => return Err(err),
115-
}
116-
}
117-
}
118-
break;
119-
}
120-
}
121-
Ok(())
122-
})
25+
Stream::new(&mut *self.session.borrow_mut()).process_write_buf(buf)
12326
}
12427
}
12528

@@ -130,76 +33,10 @@ impl TlsClientFilter {
13033
domain: ServerName<'static>,
13134
) -> Result<Io<Layer<TlsClientFilter, F>>, io::Error> {
13235
let session = ClientConnection::new(cfg, domain).map_err(io::Error::other)?;
133-
let filter = TlsClientFilter {
36+
let io = io.add_filter(TlsClientFilter {
13437
session: RefCell::new(session),
135-
};
136-
let io = io.add_filter(filter);
137-
138-
let filter = io.filter();
139-
loop {
140-
let (result, handshaking) = io.with_buf(|buf| {
141-
let mut wrp = Wrapper(buf);
142-
let mut session = filter.session.borrow_mut();
143-
let mut result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
144-
145-
while session.wants_write() {
146-
result = session.write_tls(&mut wrp).map(|_| ());
147-
if result.is_err() {
148-
break;
149-
}
150-
}
151-
if session.wants_read() {
152-
let has_data = buf.with_read_buf(|rbuf| {
153-
rbuf.with_src(|b| {
154-
b.as_ref().map(|b| !b.is_empty()).unwrap_or_default()
155-
})
156-
});
157-
158-
if has_data {
159-
result = match session.read_tls(&mut wrp) {
160-
Ok(0) => Err(io::Error::new(
161-
io::ErrorKind::NotConnected,
162-
"disconnected",
163-
)),
164-
Ok(_) => Ok(()),
165-
Err(e) => Err(e),
166-
};
167-
168-
session.process_new_packets().map_err(|err| {
169-
// In case we have an alert to send describing this error,
170-
// try a last-gasp write -- but don't predate the primary
171-
// error.
172-
let _ = session.write_tls(&mut wrp);
173-
io::Error::new(io::ErrorKind::InvalidData, err)
174-
})?;
175-
} else {
176-
result = Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
177-
}
178-
}
179-
180-
Ok::<_, io::Error>((result, session.is_handshaking()))
181-
})??;
182-
183-
match result {
184-
Ok(()) => return Ok(io),
185-
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
186-
if !handshaking {
187-
return Ok(io);
188-
}
189-
poll_fn(|cx| {
190-
match ready!(io.poll_read_notify(cx))? {
191-
Some(_) => Ok(()),
192-
None => Err(io::Error::new(
193-
io::ErrorKind::NotConnected,
194-
"disconnected",
195-
)),
196-
}?;
197-
Poll::Ready(Ok::<_, io::Error>(()))
198-
})
199-
.await?;
200-
}
201-
Err(e) => return Err(e),
202-
}
203-
}
38+
});
39+
super::stream::handshake(&io.filter().session, &io).await?;
40+
Ok(io)
20441
}
20542
}

ntex-tls/src/rustls/mod.rs

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,23 @@
11
//! An implementation of SSL streams for ntex backed by OpenSSL
2-
use std::{cmp, io};
3-
4-
use ntex_io::WriteBuf;
52
use tls_rust::pki_types::CertificateDer;
63

74
mod accept;
85
mod client;
96
mod connect;
107
mod server;
8+
mod stream;
119

1210
pub use self::accept::{TlsAcceptor, TlsAcceptorService};
1311
pub use self::client::TlsClientFilter;
1412
pub use self::connect::TlsConnector;
1513
pub use self::server::TlsServerFilter;
1614

15+
use self::stream::Stream;
16+
1717
/// Connection's peer cert
1818
#[derive(Debug)]
1919
pub struct PeerCert<'a>(pub CertificateDer<'a>);
2020

2121
/// Connection's peer cert chain
2222
#[derive(Debug)]
2323
pub struct PeerCertChain<'a>(pub Vec<CertificateDer<'a>>);
24-
25-
pub(crate) struct Wrapper<'a, 'b>(&'a WriteBuf<'b>);
26-
27-
impl io::Read for Wrapper<'_, '_> {
28-
fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
29-
self.0.with_read_buf(|buf| {
30-
buf.with_src(|buf| {
31-
if let Some(buf) = buf {
32-
let len = cmp::min(buf.len(), dst.len());
33-
if len > 0 {
34-
dst[..len].copy_from_slice(&buf.split_to(len));
35-
return Ok(len);
36-
}
37-
}
38-
Err(io::Error::new(io::ErrorKind::WouldBlock, ""))
39-
})
40-
})
41-
}
42-
}
43-
44-
impl io::Write for Wrapper<'_, '_> {
45-
fn write(&mut self, src: &[u8]) -> io::Result<usize> {
46-
self.0.with_dst(|buf| buf.extend_from_slice(src));
47-
Ok(src.len())
48-
}
49-
50-
fn flush(&mut self) -> io::Result<()> {
51-
Ok(())
52-
}
53-
}

0 commit comments

Comments
 (0)