diff --git a/Cargo.toml b/Cargo.toml index f19f8c6d4..582818548 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -533,7 +533,7 @@ harness = false [[bench]] name = "e2e_http_client_server" path = "benches/e2e_http_client_server.rs" -required-features = ["http-full", "rustls", "boring"] +required-features = ["http-full", "rustls", "boring", "socks5"] harness = false [[example]] diff --git a/benches/e2e_http_client_server.rs b/benches/e2e_http_client_server.rs index a20759406..5ffb65777 100644 --- a/benches/e2e_http_client_server.rs +++ b/benches/e2e_http_client_server.rs @@ -1,9 +1,10 @@ //! ```sh -//! cargo bench --bench e2e_http_client_server --features http-full,rustls,boring +//! cargo bench --bench e2e_http_client_server --features http-full,rustls,boring,socks5 //! ``` use std::{ convert::Infallible, + slice, sync::{ Arc, atomic::{AtomicBool, Ordering}, @@ -16,35 +17,47 @@ use rama::{ error::{BoxError, extra::OpaqueError}, extensions::ExtensionsMut, http::{ - HeaderName, HeaderValue, Request, Response, Version, + Body, HeaderName, HeaderValue, Request, Response, StatusCode, Version, body::util::BodyExt, client::EasyHttpWebClient, + io::upgrade::Upgraded, layer::{ compression::CompressionLayer, cors::CorsLayer, decompression::DecompressionLayer, map_response_body::MapResponseBodyLayer, + remove_header::{RemoveRequestHeaderLayer, RemoveResponseHeaderLayer}, required_header::{AddRequiredRequestHeadersLayer, AddRequiredResponseHeadersLayer}, set_header::SetResponseHeaderLayer, trace::TraceLayer, + upgrade::UpgradeLayer, }, + matcher::MethodMatcher, server::HttpServer, service::{ client::HttpClientExt as _, web::{WebService, response::IntoResponse as _}, }, }, + layer::ConsumeErrLayer, net::{ - address::SocketAddress, + Protocol, + address::{ProxyAddress, SocketAddress}, + http::RequestContext, + proxy::ProxyTarget, tls::{ ApplicationProtocol, client::ServerVerifyMode, server::{SelfSignedData, ServerAuth, ServerConfig}, }, + user::credentials::{ProxyCredential, basic}, }, + proxy::socks5::{Socks5Acceptor, server::LazyConnector}, rt::Executor, - service::BoxService, + service::{BoxService, service_fn}, + tcp::client::service::Forwarder, tcp::server::TcpListener, + telemetry::tracing::{self}, tls::{boring, rustls}, }; @@ -93,24 +106,42 @@ enum Tls { Boring, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Proxy { + None, + Http, + HttpMitm, + Socks5, + Socks5Mitm, +} + #[derive(Copy, Clone, Debug, Eq, PartialEq)] struct TestParameters { version: HttpVersion, tls: Tls, + proxy: Proxy, server: Size, client: Size, } const VERSIONS: [HttpVersion; 2] = [HttpVersion::Http1, HttpVersion::Http2]; const TLSES: [Tls; 3] = [Tls::None, Tls::Rustls, Tls::Boring]; +const PROXIES: [Proxy; 5] = [ + Proxy::None, + Proxy::Http, + Proxy::HttpMitm, + Proxy::Socks5, + Proxy::Socks5Mitm, +]; const SIZES: [Size; 2] = [Size::Small, Size::Large]; -const N: usize = VERSIONS.len() * TLSES.len() * SIZES.len() * SIZES.len(); +const N: usize = VERSIONS.len() * TLSES.len() * PROXIES.len() * SIZES.len() * SIZES.len(); const fn build_test_matrix() -> [TestParameters; N] { let placeholder = TestParameters { version: VERSIONS[0], tls: TLSES[0], + proxy: PROXIES[0], server: SIZES[0], client: SIZES[0], }; @@ -122,20 +153,25 @@ const fn build_test_matrix() -> [TestParameters; N] { while vi < VERSIONS.len() { let mut ti = 0usize; while ti < TLSES.len() { - let mut si = 0usize; - while si < SIZES.len() { - let mut ci = 0usize; - while ci < SIZES.len() { - out[i] = TestParameters { - version: VERSIONS[vi], - tls: TLSES[ti], - server: SIZES[si], - client: SIZES[ci], - }; - i += 1; - ci += 1; + let mut pi = 0usize; + while pi < PROXIES.len() { + let mut si = 0usize; + while si < SIZES.len() { + let mut ci = 0usize; + while ci < SIZES.len() { + out[i] = TestParameters { + version: VERSIONS[vi], + tls: TLSES[ti], + proxy: PROXIES[pi], + server: SIZES[si], + client: SIZES[ci], + }; + i += 1; + ci += 1; + } + si += 1; } - si += 1; + pi += 1; } ti += 1; } @@ -196,7 +232,130 @@ where } } -fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddress { +fn get_rustls_tls_data(params: TestParameters) -> rustls::server::TlsAcceptorData { + let proto = match params.version { + HttpVersion::Http1 => ApplicationProtocol::HTTP_11, + HttpVersion::Http2 => ApplicationProtocol::HTTP_2, + }; + + rustls::server::TlsAcceptorDataBuilder::try_new_self_signed(SelfSignedData::default()) + .unwrap() + .with_alpn_protocols(&[proto]) + .build() +} + +fn get_boring_tls_data(params: TestParameters) -> boring::server::TlsAcceptorData { + let proto = match params.version { + HttpVersion::Http1 => ApplicationProtocol::HTTP_11, + HttpVersion::Http2 => ApplicationProtocol::HTTP_2, + }; + + let config = ServerConfig { + application_layer_protocol_negotiation: Some(vec![proto]), + ..ServerConfig::new(ServerAuth::SelfSigned(SelfSignedData::default())) + }; + boring::server::TlsAcceptorData::try_from(config).unwrap() +} + +async fn http_connect_accept(mut req: Request) -> Result<(Response, Request), Response> { + match RequestContext::try_from(&req).map(|ctx| ctx.host_with_port()) { + Ok(authority) => { + tracing::info!( + server.address = %authority.host, + server.port = authority.port, + "accept CONNECT (lazy): insert proxy target into context", + ); + req.extensions_mut().insert(ProxyTarget(authority)); + } + Err(err) => { + tracing::error!("error extracting authority: {err:?}"); + return Err(StatusCode::BAD_REQUEST.into_response()); + } + } + + Ok((StatusCode::OK.into_response(), req)) +} + +fn get_http_proxy_service_boxed(params: TestParameters) -> BoxService +where + Input: ExtensionsMut + AsyncRead + AsyncWrite + Send + 'static, +{ + let handler = move |req: Request| async move { + let client = get_inner_client(params.version, params.tls); + match client.serve(req).await { + Ok(resp) => { + tracing::info!(status_code = %resp.status(), "proxy received response"); + Ok(resp) + } + Err(err) => { + tracing::error!("error in client request: {err:?}"); + Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::empty()) + .unwrap()) + } + } + }; + + let connect_proxy = move |upgraded: Upgraded| async move { + let http_service = ( + MapResponseBodyLayer::new_boxed_streaming_body(), + TraceLayer::new_for_http(), + ConsumeErrLayer::default(), + RemoveResponseHeaderLayer::hop_by_hop(), + RemoveRequestHeaderLayer::hop_by_hop(), + CompressionLayer::new(), + AddRequiredRequestHeadersLayer::new(), + ) + .into_layer(service_fn(handler)); + let http_transport_service = HttpServer::auto(Executor::default()).service(http_service); + + match params.tls { + Tls::Rustls => { + let data = get_rustls_tls_data(params); + let https_service = + rustls::server::TlsAcceptorLayer::new(data).into_layer(http_transport_service); + https_service.serve(upgraded).await.expect("infallible"); + } + Tls::Boring => { + let data = get_boring_tls_data(params); + let https_service = + boring::server::TlsAcceptorLayer::new(data).into_layer(http_transport_service); + https_service.serve(upgraded).await.expect("infallible"); + } + Tls::None => panic!("Cannot be called with TLS none"), + }; + + Ok::<(), Infallible>(()) + }; + + let http_service = ( + TraceLayer::new_for_http(), + CompressionLayer::new(), + if matches!(params.proxy, Proxy::HttpMitm | Proxy::Socks5Mitm) { + UpgradeLayer::new( + Executor::default(), + MethodMatcher::CONNECT, + service_fn(http_connect_accept), + service_fn(connect_proxy), + ) + } else { + UpgradeLayer::new( + Executor::default(), + MethodMatcher::CONNECT, + service_fn(http_connect_accept), + ConsumeErrLayer::default().into_layer(Forwarder::ctx(Executor::default())), + ) + }, + ) + .layer(service_fn(handler)); + + HttpServer::auto(Executor::default()) + .service(http_service) + .boxed() +} + +fn spawn_http_proxy(params: TestParameters) -> SocketAddress { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); let ready = Arc::new(AtomicBool::new(false)); @@ -212,10 +371,80 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre let async_listener = TcpListener::try_from_std_tcp_listener(listener, Executor::default()).unwrap(); - let proto = match params.version { - HttpVersion::Http1 => ApplicationProtocol::HTTP_11, - HttpVersion::Http2 => ApplicationProtocol::HTTP_2, - }; + let socks5_acceptor_base = Socks5Acceptor::new(Executor::default()) + .with_authorizer(basic!("john", "secret").into_authorizer()); + + match params.tls { + Tls::None => { + let service = get_http_proxy_service_boxed(params); + + ready_worker.store(true, Ordering::Release); + + if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) { + let socks5_acceptor = + socks5_acceptor_base.with_connector(LazyConnector::new(service)); + async_listener.serve(socks5_acceptor).await + } else { + async_listener.serve(service).await + } + } + Tls::Rustls => { + let service = get_http_proxy_service_boxed(params); + + let data = get_rustls_tls_data(params); + ready_worker.store(true, Ordering::Release); + let tls_acceptor = + rustls::server::TlsAcceptorLayer::new(data).into_layer(service); + + if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) { + let socks5_acceptor = + socks5_acceptor_base.with_connector(LazyConnector::new(tls_acceptor)); + async_listener.serve(socks5_acceptor).await + } else { + async_listener.serve(tls_acceptor).await + } + } + Tls::Boring => { + let service = get_http_proxy_service_boxed(params); + + let data = get_boring_tls_data(params); + ready_worker.store(true, Ordering::Release); + let tls_acceptor = + boring::server::TlsAcceptorLayer::new(data).into_layer(service); + + if matches!(params.proxy, Proxy::Socks5 | Proxy::Socks5Mitm) { + let socks5_acceptor = + socks5_acceptor_base.with_connector(LazyConnector::new(tls_acceptor)); + async_listener.serve(socks5_acceptor).await + } else { + async_listener.serve(tls_acceptor).await + } + } + } + }); + }); + + while !ready.load(Ordering::Acquire) { + std::thread::yield_now(); + } + addr.into() +} + +fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddress { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + let ready = Arc::new(AtomicBool::new(false)); + let ready_worker = ready.clone(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .build() + .unwrap(); + rt.block_on(async move { + let async_listener = + TcpListener::try_from_std_tcp_listener(listener, Executor::default()).unwrap(); match params.tls { Tls::None => { @@ -228,12 +457,7 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre Tls::Rustls => { let service = get_http_service_boxed(params, body_content); - let data = rustls::server::TlsAcceptorDataBuilder::try_new_self_signed( - SelfSignedData::default(), - ) - .unwrap() - .with_alpn_protocols(&[proto]) - .build(); + let data = get_rustls_tls_data(params); ready_worker.store(true, Ordering::Release); @@ -244,11 +468,7 @@ fn spawn_http_server(params: TestParameters, body_content: Bytes) -> SocketAddre Tls::Boring => { let service = get_http_service_boxed(params, body_content); - let config = ServerConfig { - application_layer_protocol_negotiation: Some(vec![proto]), - ..ServerConfig::new(ServerAuth::SelfSigned(SelfSignedData::default())) - }; - let data = boring::server::TlsAcceptorData::try_from(config).unwrap(); + let data = get_boring_tls_data(params); ready_worker.store(true, Ordering::Release); @@ -270,10 +490,7 @@ fn get_inner_client( http: HttpVersion, tls: Tls, ) -> impl Service { - let b = EasyHttpWebClient::connector_builder() - .with_default_transport_connector() - .without_tls_proxy_support() - .without_proxy_support(); + let b = EasyHttpWebClient::connector_builder().with_default_transport_connector(); let proto = match http { HttpVersion::Http1 => ApplicationProtocol::HTTP_11, @@ -282,30 +499,49 @@ fn get_inner_client( match tls { Tls::None => b + .without_tls_proxy_support() + .with_proxy_support() .without_tls_support() .with_default_http_connector(Executor::default()) .build_client(), - Tls::Rustls => b - .with_tls_support_using_rustls(Some( - rustls::client::TlsConnectorDataBuilder::new() - .try_with_env_key_logger() - .unwrap() - .with_alpn_protocols(&[proto]) - .with_no_cert_verifier() - .build(), - )) - .with_default_http_connector(Executor::default()) - .build_client(), - Tls::Boring => b - .with_tls_support_using_boringssl(Some( - boring::client::TlsConnectorDataBuilder::new() - .try_with_rama_alpn_protos(&[proto]) - .unwrap() - .with_server_verify_mode(ServerVerifyMode::Disable) - .into_shared_builder(), - )) - .with_default_http_connector(Executor::default()) - .build_client(), + Tls::Rustls => { + let tls_config = rustls::client::TlsConnectorDataBuilder::new() + .try_with_env_key_logger() + .unwrap() + .with_alpn_protocols(slice::from_ref(&proto)) + .with_no_cert_verifier() + .with_store_server_certificate_chain(true) + .build(); + let proxy_tls_config = rustls::client::TlsConnectorDataBuilder::new() + .try_with_env_key_logger() + .unwrap() + .with_alpn_protocols(&[proto]) + .with_no_cert_verifier() + .build(); + b.with_tls_proxy_support_using_rustls_config(proxy_tls_config) + .with_proxy_support() + .with_tls_support_using_rustls(Some(tls_config)) + .with_default_http_connector(Executor::default()) + .build_client() + } + Tls::Boring => { + let tls_config = boring::client::TlsConnectorDataBuilder::new() + .try_with_rama_alpn_protos(slice::from_ref(&proto)) + .unwrap() + .with_server_verify_mode(ServerVerifyMode::Disable) + .with_store_server_certificate_chain(true) + .into_shared_builder(); + let proxy_tls_config = boring::client::TlsConnectorDataBuilder::new() + .try_with_rama_alpn_protos(&[proto]) + .unwrap() + .with_server_verify_mode(ServerVerifyMode::Disable) + .into_shared_builder(); + b.with_tls_proxy_support_using_boringssl_config(proxy_tls_config) + .with_proxy_support() + .with_tls_support_using_boringssl(Some(tls_config)) + .with_default_http_connector(Executor::default()) + .build_client() + } } } @@ -322,7 +558,6 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) { let client_bytes = params.client.rnd_bytes(); let client_bytes_count = client_bytes.len(); - let address = spawn_http_server(params, server_bytes); let scheme = if matches!(params.tls, Tls::None) { "http" } else { @@ -333,8 +568,15 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) { } else { "large" }; + + let address = spawn_http_server(params, server_bytes); let url = format!("{scheme}://{address}/{endpoint}"); + let mut address_proxy = SocketAddress::default_ipv4(0); + if params.proxy != Proxy::None { + address_proxy = spawn_http_proxy(params); + } + bencher .with_inputs(|| { let client = ( @@ -351,16 +593,28 @@ fn bench_http_transport(bencher: divan::Bencher, params: TestParameters) { }) .bench_local_values(|(client, body)| { rt.block_on(async { - let resp = client + let req = client .post(&url) .version(match params.version { HttpVersion::Http1 => Version::HTTP_11, HttpVersion::Http2 => Version::HTTP_2, }) - .body(body) - .send() - .await - .expect("Request failed"); + .body(body); + + let req_with_maybe_proxy = match params.proxy { + Proxy::None => req, + Proxy::Http | Proxy::HttpMitm => req.extension( + ProxyAddress::try_from(format!("{scheme}://{}", address_proxy.clone())) + .unwrap(), + ), + Proxy::Socks5 | Proxy::Socks5Mitm => req.extension(ProxyAddress { + protocol: Some(Protocol::SOCKS5), + address: address_proxy.into(), + credential: Some(ProxyCredential::Basic(basic!("john", "secret"))), + }), + }; + + let resp = req_with_maybe_proxy.send().await.expect("Request failed"); let _ = resp.into_body().collect().await; }); });