diff --git a/Cargo.toml b/Cargo.toml
index f19f8c6d4..582818548 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -533,7 +533,7 @@ harness = false
[[bench]]
name = "e2e_http_client_server"
path = "benches/e2e_http_client_server.rs"
-required-features = ["http-full", "rustls", "boring"]
+required-features = ["http-full", "rustls", "boring", "socks5"]
harness = false
[[example]]
diff --git a/benches/e2e_http_client_server.rs b/benches/e2e_http_client_server.rs
index a20759406..5ffb65777 100644
--- a/benches/e2e_http_client_server.rs
+++ b/benches/e2e_http_client_server.rs
@@ -1,9 +1,10 @@
//! ```sh
-//! cargo bench --bench e2e_http_client_server --features http-full,rustls,boring
+//! cargo bench --bench e2e_http_client_server --features http-full,rustls,boring,socks5
//! ```
use std::{
convert::Infallible,
+ slice,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
@@ -16,35 +17,47 @@ use rama::{
error::{BoxError, extra::OpaqueError},
extensions::ExtensionsMut,
http::{
- HeaderName, HeaderValue, Request, Response, Version,
+ Body, HeaderName, HeaderValue, Request, Response, StatusCode, Version,
body::util::BodyExt,
client::EasyHttpWebClient,
+ io::upgrade::Upgraded,
layer::{
compression::CompressionLayer,
cors::CorsLayer,
decompression::DecompressionLayer,
map_response_body::MapResponseBodyLayer,
+ remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer},
required_header::{AddRequiredRequestHeadersLayer, AddRequiredResponseHeadersLayer},
set_header::SetResponseHeaderLayer,
trace::TraceLayer,
+ upgrade::UpgradeLayer,
},
+ matcher::MethodMatcher,
server::HttpServer,
service::{
client::HttpClientExt as _,
web::{WebService, response::IntoResponse as _},
},
},
+ layer::ConsumeErrLayer,
net::{
- address::SocketAddress,
+ Protocol,
+ address::{ProxyAddress, SocketAddress},
+ http::RequestContext,
+ proxy::ProxyTarget,
tls::{
ApplicationProtocol,
client::ServerVerifyMode,
server::{SelfSignedData, ServerAuth, ServerConfig},
},
+ user::credentials::{ProxyCredential, basic},
},
+ proxy::socks5::{Socks5Acceptor, server::LazyConnector},
rt::Executor,
- service::BoxService,
+ service::{BoxService, service_fn},
+ tcp::client::service::Forwarder,
tcp::server::TcpListener,
+ telemetry::tracing::{self},
tls::{boring, rustls},
};
@@ -93,24 +106,42 @@ enum Tls {
Boring,
}
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Proxy {
+ None,
+ Http,
+ HttpMitm,
+ Socks5,
+ Socks5Mitm,
+}
+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct TestParameters {
version: HttpVersion,
tls: Tls,
+ proxy: Proxy,
server: Size,
client: Size,
}
const VERSIONS: [HttpVersion; 2] = [HttpVersion::Http1, HttpVersion::Http2];
const TLSES: [Tls; 3] = [Tls::None, Tls::Rustls, Tls::Boring];
+const PROXIES: [Proxy; 5] = [
+ Proxy::None,
+ Proxy::Http,
+ Proxy::HttpMitm,
+ Proxy::Socks5,
+ Proxy::Socks5Mitm,
+];
const SIZES: [Size; 2] = [Size::Small, Size::Large];
-const N: usize = VERSIONS.len() * TLSES.len() * SIZES.len() * SIZES.len();
+const N: usize = VERSIONS.len() * TLSES.len() * PROXIES.len() * SIZES.len() * SIZES.len();
const fn build_test_matrix() -> [TestParameters; N] {
let placeholder = TestParameters {
version: VERSIONS[0],
tls: TLSES[0],
+ proxy: PROXIES[0],
server: SIZES[0],
client: SIZES[0],
};
@@ -122,20 +153,25 @@ const fn build_test_matrix() -> [TestParameters; N] {
while vi < VERSIONS.len() {
let mut ti = 0usize;
while ti < TLSES.len() {
- let mut si = 0usize;
- while si < SIZES.len() {
- let mut ci = 0usize;
- while ci < SIZES.len() {
- out[i] = TestParameters {
- version: VERSIONS[vi],
- tls: TLSES[ti],
- server: SIZES[si],
- client: SIZES[ci],
- };
- i += 1;
- ci += 1;
+ let mut pi = 0usize;
+ while pi < PROXIES.len() {
+ let mut si = 0usize;
+ while si < SIZES.len() {
+ let mut ci = 0usize;
+ while ci < SIZES.len() {
+ out[i] = TestParameters {
+ version: VERSIONS[vi],
+ tls: TLSES[ti],
+ proxy: PROXIES[pi],
+ server: SIZES[si],
+ client: SIZES[ci],
+ };
+ i += 1;
+ ci += 1;
+ }
+ si += 1;
}
- si += 1;
+ pi += 1;
}
ti += 1;
}
@@ -196,7 +232,130 @@ where
}
}
-fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddress {
+fn get_rustls_tls_data(params: TestParameters) -> rustls::server::TlsAcceptorData {
+ let proto = match params.version {
+ HttpVersion::Http1 => ApplicationProtocol::HTTP_11,
+ HttpVersion::Http2 => ApplicationProtocol::HTTP_2,
+ };
+
+ rustls::server::TlsAcceptorDataBuilder::try_new_self_signed(SelfSignedData::default())
+ .unwrap()
+ .with_alpn_protocols(&[proto])
+ .build()
+}
+
+fn get_boring_tls_data(params: TestParameters) -> boring::server::TlsAcceptorData {
+ let proto = match params.version {
+ HttpVersion::Http1 => ApplicationProtocol::HTTP_11,
+ HttpVersion::Http2 => ApplicationProtocol::HTTP_2,
+ };
+
+ let config = ServerConfig {
+ application_layer_protocol_negotiation: Some(vec![proto]),
+ ..ServerConfig::new(ServerAuth::SelfSigned(SelfSignedData::default()))
+ };
+ boring::server::TlsAcceptorData::try_from(config).unwrap()
+}
+
+async fn http_connect_accept(mut req: Request) -> Result<(Response, Request), Response> {
+ match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) {
+ Ok(authority) => {
+ tracing::info!(
+ server.address = %authority.host,
+ server.port = authority.port,
+ "accept CONNECT (lazy): insert proxy target into context",
+ );
+ req.extensions_mut().insert(ProxyTarget(authority));
+ }
+ Err(err) => {
+ tracing::error!("error extracting authority: {err:?}");
+ return Err(StatusCode::BAD_REQUEST.into_response());
+ }
+ }
+
+ Ok((StatusCode::OK.into_response(), req))
+}
+
+fn get_http_proxy_service_boxed(params: TestParameters) -> BoxService
+where
+ Input: ExtensionsMut + AsyncRead + AsyncWrite + Send + 'static,
+{
+ let handler = move |req: Request| async move {
+ let client = get_inner_client(params.version, params.tls);
+ match client.serve(req).await {
+ Ok(resp) => {
+ tracing::info!(status_code = %resp.status(), "proxy received response");
+ Ok(resp)
+ }
+ Err(err) => {
+ tracing::error!("error in client request: {err:?}");
+ Ok(Response::builder()
+ .status(StatusCode::INTERNAL_SERVER_ERROR)
+ .body(Body::empty())
+ .unwrap())
+ }
+ }
+ };
+
+ let connect_proxy = move |upgraded: Upgraded| async move {
+ let http_service = (
+ MapResponseBodyLayer::new_boxed_streaming_body(),
+ TraceLayer::new_for_http(),
+ ConsumeErrLayer::default(),
+ RemoveResponseHeaderLayer::hop_by_hop(),
+ RemoveRequestHeaderLayer::hop_by_hop(),
+ CompressionLayer::new(),
+ AddRequiredRequestHeadersLayer::new(),
+ )
+ .into_layer(service_fn(handler));
+ let http_transport_service = HttpServer::auto(Executor::default()).service(http_service);
+
+ match params.tls {
+ Tls::Rustls => {
+ let data = get_rustls_tls_data(params);
+ let https_service =
+ rustls::server::TlsAcceptorLayer::new(data).into_layer(http_transport_service);
+ https_service.serve(upgraded).await.expect("infallible");
+ }
+ Tls::Boring => {
+ let data = get_boring_tls_data(params);
+ let https_service =
+ boring::server::TlsAcceptorLayer::new(data).into_layer(http_transport_service);
+ https_service.serve(upgraded).await.expect("infallible");
+ }
+ Tls::None => panic!("Cannot be called with TLS none"),
+ };
+
+ Ok::<(), Infallible>(())
+ };
+
+ let http_service = (
+ TraceLayer::new_for_http(),
+ CompressionLayer::new(),
+ if matches!(params.proxy, Proxy::HttpMitm | Proxy::Socks5Mitm) {
+ UpgradeLayer::new(
+ Executor::default(),
+ MethodMatcher::CONNECT,
+ service_fn(http_connect_accept),
+ service_fn(connect_proxy),
+ )
+ } else {
+ UpgradeLayer::new(
+ Executor::default(),
+ MethodMatcher::CONNECT,
+ service_fn(http_connect_accept),
+ ConsumeErrLayer::default().into_layer(Forwarder::ctx(Executor::default())),
+ )
+ },
+ )
+ .layer(service_fn(handler));
+
+ HttpServer::auto(Executor::default())
+ .service(http_service)
+ .boxed()
+}
+
+fn spawn_http_proxy(params: TestParameters) -> SocketAddress {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let ready = Arc::new(AtomicBool::new(false));
@@ -212,10 +371,80 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre
let async_listener =
TcpListener::try_from_std_tcp_listener(listener, Executor::default()).unwrap();
- let proto = match params.version {
- HttpVersion::Http1 => ApplicationProtocol::HTTP_11,
- HttpVersion::Http2 => ApplicationProtocol::HTTP_2,
- };
+ let socks5_acceptor_base = Socks5Acceptor::new(Executor::default())
+ .with_authorizer(basic!("john", "secret").into_authorizer());
+
+ match params.tls {
+ Tls::None => {
+ let service = get_http_proxy_service_boxed(params);
+
+ ready_worker.store(true, Ordering::Release);
+
+ if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) {
+ let socks5_acceptor =
+ socks5_acceptor_base.with_connector(LazyConnector::new(service));
+ async_listener.serve(socks5_acceptor).await
+ } else {
+ async_listener.serve(service).await
+ }
+ }
+ Tls::Rustls => {
+ let service = get_http_proxy_service_boxed(params);
+
+ let data = get_rustls_tls_data(params);
+ ready_worker.store(true, Ordering::Release);
+ let tls_acceptor =
+ rustls::server::TlsAcceptorLayer::new(data).into_layer(service);
+
+ if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) {
+ let socks5_acceptor =
+ socks5_acceptor_base.with_connector(LazyConnector::new(tls_acceptor));
+ async_listener.serve(socks5_acceptor).await
+ } else {
+ async_listener.serve(tls_acceptor).await
+ }
+ }
+ Tls::Boring => {
+ let service = get_http_proxy_service_boxed(params);
+
+ let data = get_boring_tls_data(params);
+ ready_worker.store(true, Ordering::Release);
+ let tls_acceptor =
+ boring::server::TlsAcceptorLayer::new(data).into_layer(service);
+
+ if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) {
+ let socks5_acceptor =
+ socks5_acceptor_base.with_connector(LazyConnector::new(tls_acceptor));
+ async_listener.serve(socks5_acceptor).await
+ } else {
+ async_listener.serve(tls_acceptor).await
+ }
+ }
+ }
+ });
+ });
+
+ while !ready.load(Ordering::Acquire) {
+ std::thread::yield_now();
+ }
+ addr.into()
+}
+
+fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddress {
+ let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
+ let addr = listener.local_addr().unwrap();
+ let ready = Arc::new(AtomicBool::new(false));
+ let ready_worker = ready.clone();
+
+ std::thread::spawn(move || {
+ let rt = tokio::runtime::Builder::new_multi_thread()
+ .worker_threads(2)
+ .enable_all()
+ .build()
+ .unwrap();
+ rt.block_on(async move {
+ let async_listener =
+ TcpListener::try_from_std_tcp_listener(listener, Executor::default()).unwrap();
match params.tls {
Tls::None => {
@@ -228,12 +457,7 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre
Tls::Rustls => {
let service = get_http_service_boxed(params, body_content);
- let data = rustls::server::TlsAcceptorDataBuilder::try_new_self_signed(
- SelfSignedData::default(),
- )
- .unwrap()
- .with_alpn_protocols(&[proto])
- .build();
+ let data = get_rustls_tls_data(params);
ready_worker.store(true, Ordering::Release);
@@ -244,11 +468,7 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre
Tls::Boring => {
let service = get_http_service_boxed(params, body_content);
- let config = ServerConfig {
- application_layer_protocol_negotiation: Some(vec![proto]),
- ..ServerConfig::new(ServerAuth::SelfSigned(SelfSignedData::default()))
- };
- let data = boring::server::TlsAcceptorData::try_from(config).unwrap();
+ let data = get_boring_tls_data(params);
ready_worker.store(true, Ordering::Release);
@@ -270,10 +490,7 @@ fn get_inner_client(
http: HttpVersion,
tls: Tls,
) -> impl Service {
- let b = EasyHttpWebClient::connector_builder()
- .with_default_transport_connector()
- .without_tls_proxy_support()
- .without_proxy_support();
+ let b = EasyHttpWebClient::connector_builder().with_default_transport_connector();
let proto = match http {
HttpVersion::Http1 => ApplicationProtocol::HTTP_11,
@@ -282,30 +499,49 @@ fn get_inner_client(
match tls {
Tls::None => b
+ .without_tls_proxy_support()
+ .with_proxy_support()
.without_tls_support()
.with_default_http_connector(Executor::default())
.build_client(),
- Tls::Rustls => b
- .with_tls_support_using_rustls(Some(
- rustls::client::TlsConnectorDataBuilder::new()
- .try_with_env_key_logger()
- .unwrap()
- .with_alpn_protocols(&[proto])
- .with_no_cert_verifier()
- .build(),
- ))
- .with_default_http_connector(Executor::default())
- .build_client(),
- Tls::Boring => b
- .with_tls_support_using_boringssl(Some(
- boring::client::TlsConnectorDataBuilder::new()
- .try_with_rama_alpn_protos(&[proto])
- .unwrap()
- .with_server_verify_mode(ServerVerifyMode::Disable)
- .into_shared_builder(),
- ))
- .with_default_http_connector(Executor::default())
- .build_client(),
+ Tls::Rustls => {
+ let tls_config = rustls::client::TlsConnectorDataBuilder::new()
+ .try_with_env_key_logger()
+ .unwrap()
+ .with_alpn_protocols(slice::from_ref(&proto))
+ .with_no_cert_verifier()
+ .with_store_server_certificate_chain(true)
+ .build();
+ let proxy_tls_config = rustls::client::TlsConnectorDataBuilder::new()
+ .try_with_env_key_logger()
+ .unwrap()
+ .with_alpn_protocols(&[proto])
+ .with_no_cert_verifier()
+ .build();
+ b.with_tls_proxy_support_using_rustls_config(proxy_tls_config)
+ .with_proxy_support()
+ .with_tls_support_using_rustls(Some(tls_config))
+ .with_default_http_connector(Executor::default())
+ .build_client()
+ }
+ Tls::Boring => {
+ let tls_config = boring::client::TlsConnectorDataBuilder::new()
+ .try_with_rama_alpn_protos(slice::from_ref(&proto))
+ .unwrap()
+ .with_server_verify_mode(ServerVerifyMode::Disable)
+ .with_store_server_certificate_chain(true)
+ .into_shared_builder();
+ let proxy_tls_config = boring::client::TlsConnectorDataBuilder::new()
+ .try_with_rama_alpn_protos(&[proto])
+ .unwrap()
+ .with_server_verify_mode(ServerVerifyMode::Disable)
+ .into_shared_builder();
+ b.with_tls_proxy_support_using_boringssl_config(proxy_tls_config)
+ .with_proxy_support()
+ .with_tls_support_using_boringssl(Some(tls_config))
+ .with_default_http_connector(Executor::default())
+ .build_client()
+ }
}
}
@@ -322,7 +558,6 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) {
let client_bytes = params.client.rnd_bytes();
let client_bytes_count = client_bytes.len();
- let address = spawn_http_server(params, server_bytes);
let scheme = if matches!(params.tls, Tls::None) {
"http"
} else {
@@ -333,8 +568,15 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) {
} else {
"large"
};
+
+ let address = spawn_http_server(params, server_bytes);
let url = format!("{scheme}://{address}/{endpoint}");
+ let mut address_proxy = SocketAddress::default_ipv4(0);
+ if params.proxy != Proxy::None {
+ address_proxy = spawn_http_proxy(params);
+ }
+
bencher
.with_inputs(|| {
let client = (
@@ -351,16 +593,28 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) {
})
.bench_local_values(|(client, body)| {
rt.block_on(async {
- let resp = client
+ let req = client
.post(&url)
.version(match params.version {
HttpVersion::Http1 => Version::HTTP_11,
HttpVersion::Http2 => Version::HTTP_2,
})
- .body(body)
- .send()
- .await
- .expect("Request failed");
+ .body(body);
+
+ let req_with_maybe_proxy = match params.proxy {
+ Proxy::None => req,
+ Proxy::Http | Proxy::HttpMitm => req.extension(
+ ProxyAddress::try_from(format!("{scheme}://{}", address_proxy.clone()))
+ .unwrap(),
+ ),
+ Proxy::Socks5 | Proxy::Socks5Mitm => req.extension(ProxyAddress {
+ protocol: Some(Protocol::SOCKS5),
+ address: address_proxy.into(),
+ credential: Some(ProxyCredential::Basic(basic!("john", "secret"))),
+ }),
+ };
+
+ let resp = req_with_maybe_proxy.send().await.expect("Request failed");
let _ = resp.into_body().collect().await;
});
});