diff --git a/Cargo.toml b/Cargo.toml index 6ad44e4be..351a56c10 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ futures-util = { version = "0.3", default-features = false, features = ["sink"] futures-channel = { version = "0.3.17", features = ["sink"]} headers = "0.3.5" http = "0.2" -hyper = { version = "0.14", features = ["stream", "server", "http1", "http2", "tcp", "client"] } +hyper = { version = "0.14.19", features = ["stream", "server", "http1", "http2", "tcp", "client"] } log = "0.4" mime = "0.3" mime_guess = "2.0.0" diff --git a/src/filter/service.rs b/src/filter/service.rs index 3de12a02e..fa1a36f6f 100644 --- a/src/filter/service.rs +++ b/src/filter/service.rs @@ -73,11 +73,12 @@ where pub(crate) fn call_with_addr( &self, req: Request, + local_addr: Option, remote_addr: Option, ) -> FilteredFuture { debug_assert!(!route::is_set(), "nested route::set calls"); - let route = Route::new(req, remote_addr); + let route = Route::new(req, local_addr, remote_addr); let fut = route::set(&route, || self.filter.filter(super::Internal)); FilteredFuture { future: fut, route } } @@ -99,7 +100,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - self.call_with_addr(req, None) + self.call_with_addr(req, None, None) } } diff --git a/src/filters/addr.rs b/src/filters/addr.rs index 3d630705a..6dbe69000 100644 --- a/src/filters/addr.rs +++ b/src/filters/addr.rs @@ -5,6 +5,26 @@ use std::net::SocketAddr; use crate::filter::{filter_fn_one, Filter}; +/// Creates a `Filter` to get the local address of the connection. +/// +/// If the underlying transport doesn't use socket addresses, this will yield +/// `None`. +/// +/// # Example +/// +/// ``` +/// use std::net::SocketAddr; +/// use warp::Filter; +/// +/// let route = warp::addr::local() +/// .map(|addr: Option| { +/// println!("local address = {:?}", addr); +/// }); +/// ``` +pub fn local() -> impl Filter,), Error = Infallible> + Copy { + filter_fn_one(|route| futures_util::future::ok(route.local_addr())) +} + /// Creates a `Filter` to get the remote address of the connection. /// /// If the underlying transport doesn't use socket addresses, this will yield diff --git a/src/filters/log.rs b/src/filters/log.rs index 3790fd8a8..49252d69f 100644 --- a/src/filters/log.rs +++ b/src/filters/log.rs @@ -109,6 +109,11 @@ where } impl<'a> Info<'a> { + /// View the local `SocketAddr` of the request. + pub fn local_addr(&self) -> Option { + self.route.local_addr() + } + /// View the remote `SocketAddr` of the request. pub fn remote_addr(&self) -> Option { self.route.remote_addr() diff --git a/src/filters/trace.rs b/src/filters/trace.rs index 60686b430..9bc3be23a 100644 --- a/src/filters/trace.rs +++ b/src/filters/trace.rs @@ -157,6 +157,11 @@ where } impl<'a> Info<'a> { + /// View the local `SocketAddr` of the request. + pub fn local_addr(&self) -> Option { + self.route.local_addr() + } + /// View the remote `SocketAddr` of the request. pub fn remote_addr(&self) -> Option { self.route.remote_addr() diff --git a/src/route.rs b/src/route.rs index afbac4d8b..819225c12 100644 --- a/src/route.rs +++ b/src/route.rs @@ -30,6 +30,7 @@ where #[derive(Debug)] pub(crate) struct Route { body: BodyState, + local_addr: Option, remote_addr: Option, req: Request, segments_index: usize, @@ -42,7 +43,9 @@ enum BodyState { } impl Route { - pub(crate) fn new(req: Request, remote_addr: Option) -> RefCell { + pub(crate) fn new(req: Request, + local_addr: Option, + remote_addr: Option) -> RefCell { let segments_index = if req.uri().path().starts_with('/') { // Skip the beginning slash. 1 @@ -52,6 +55,7 @@ impl Route { RefCell::new(Route { body: BodyState::Ready, + local_addr, remote_addr, req, segments_index, @@ -123,6 +127,10 @@ impl Route { self.segments_index = index; } + pub(crate) fn local_addr(&self) -> Option { + self.local_addr + } + pub(crate) fn remote_addr(&self) -> Option { self.remote_addr } diff --git a/src/server.rs b/src/server.rs index 929d96eb3..0b7c6174f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -55,9 +55,10 @@ macro_rules! into_service { let inner = crate::service($into); make_service_fn(move |transport| { let inner = inner.clone(); + let local_addr = Transport::local_addr(transport); let remote_addr = Transport::remote_addr(transport); future::ok::<_, Infallible>(service_fn(move |req| { - inner.call_with_addr(req, remote_addr) + inner.call_with_addr(req, local_addr, remote_addr) })) }) }}; diff --git a/src/test.rs b/src/test.rs index ca2710fae..496d8ac85 100644 --- a/src/test.rs +++ b/src/test.rs @@ -123,6 +123,7 @@ use self::inner::OneOrTuple; /// Starts a new test `RequestBuilder`. pub fn request() -> RequestBuilder { RequestBuilder { + local_addr: None, remote_addr: None, req: Request::default(), } @@ -140,6 +141,7 @@ pub fn ws() -> WsBuilder { #[must_use = "RequestBuilder does nothing on its own"] #[derive(Debug)] pub struct RequestBuilder { + local_addr: Option, remote_addr: Option, req: Request, } @@ -237,6 +239,22 @@ impl RequestBuilder { self } + /// Set the local address of this request + /// + /// Default is no local address. + /// + /// # Example + /// ``` + /// use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + /// + /// let req = warp::test::request() + /// .local_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)); + /// ``` + pub fn local_addr(mut self, addr: SocketAddr) -> Self { + self.local_addr = Some(addr); + self + } + /// Set the remote address of this request /// /// Default is no remote address. @@ -376,7 +394,7 @@ impl RequestBuilder { // TODO: de-duplicate this and apply_filter() assert!(!route::is_set(), "nested test filter calls"); - let route = Route::new(self.req, self.remote_addr); + let route = Route::new(self.req, self.local_addr, self.remote_addr); let mut fut = Box::pin( route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| { let res = match result { @@ -405,7 +423,7 @@ impl RequestBuilder { { assert!(!route::is_set(), "nested test filter calls"); - let route = Route::new(self.req, self.remote_addr); + let route = Route::new(self.req, self.local_addr, self.remote_addr); let mut fut = Box::pin(route::set(&route, move || { f.filter(crate::filter::Internal) })); diff --git a/src/tls.rs b/src/tls.rs index aa7438752..d1eaf6284 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -281,6 +281,10 @@ impl Read for LazyFile { } impl Transport for TlsStream { + fn local_addr(&self) -> Option { + Some(self.local_addr) + } + fn remote_addr(&self) -> Option { Some(self.remote_addr) } @@ -296,15 +300,18 @@ enum State { // TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first pub(crate) struct TlsStream { state: State, + local_addr: SocketAddr, remote_addr: SocketAddr, } impl TlsStream { fn new(stream: AddrStream, config: Arc) -> TlsStream { + let local_addr = stream.local_addr(); let remote_addr = stream.remote_addr(); let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); TlsStream { state: State::Handshaking(accept), + local_addr, remote_addr, } } diff --git a/src/transport.rs b/src/transport.rs index be553e706..0ce84b034 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -7,10 +7,15 @@ use hyper::server::conn::AddrStream; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub trait Transport: AsyncRead + AsyncWrite { + fn local_addr(&self) -> Option; fn remote_addr(&self) -> Option; } impl Transport for AddrStream { + fn local_addr(&self) -> Option { + Some(self.local_addr()) + } + fn remote_addr(&self) -> Option { Some(self.remote_addr()) } @@ -47,6 +52,10 @@ impl AsyncWrite for LiftIo { } impl Transport for LiftIo { + fn local_addr(&self) -> Option { + None + } + fn remote_addr(&self) -> Option { None } diff --git a/tests/addr.rs b/tests/addr.rs index 12fc46936..67b4de436 100644 --- a/tests/addr.rs +++ b/tests/addr.rs @@ -2,6 +2,27 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +#[tokio::test] +async fn local_addr_missing() { + let extract_local_addr = warp::addr::local(); + + let req = warp::test::request(); + let resp = req.filter(&extract_local_addr).await.unwrap(); + assert_eq!(resp, None) +} + +#[tokio::test] +async fn local_addr_present() { + let extract_local_addr = warp::addr::local(); + + let req = warp::test::request().local_addr("1.2.3.4:5678".parse().unwrap()); + let resp = req.filter(&extract_local_addr).await.unwrap(); + assert_eq!( + resp, + Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 5678)) + ) +} + #[tokio::test] async fn remote_addr_missing() { let extract_remote_addr = warp::addr::remote();