Skip to content

Commit 4925a08

Browse files
author
=
committed
Add local() filter to get the local address of the connection
1 parent 7b07043 commit 4925a08

File tree

11 files changed

+102
-7
lines changed

11 files changed

+102
-7
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ futures-util = { version = "0.3", default-features = false, features = ["sink"]
2323
futures-channel = { version = "0.3.17", features = ["sink"]}
2424
headers = "0.3.5"
2525
http = "0.2"
26-
hyper = { version = "0.14", features = ["stream", "server", "http1", "http2", "tcp", "client"] }
26+
hyper = { version = "0.14.19", features = ["stream", "server", "http1", "http2", "tcp", "client"] }
2727
log = "0.4"
2828
mime = "0.3"
2929
mime_guess = "2.0.0"

src/filter/service.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@ where
7373
pub(crate) fn call_with_addr(
7474
&self,
7575
req: Request,
76+
local_addr: Option<SocketAddr>,
7677
remote_addr: Option<SocketAddr>,
7778
) -> FilteredFuture<F::Future> {
7879
debug_assert!(!route::is_set(), "nested route::set calls");
7980

80-
let route = Route::new(req, remote_addr);
81+
let route = Route::new(req, local_addr, remote_addr);
8182
let fut = route::set(&route, || self.filter.filter(super::Internal));
8283
FilteredFuture { future: fut, route }
8384
}
@@ -99,7 +100,7 @@ where
99100

100101
#[inline]
101102
fn call(&mut self, req: Request) -> Self::Future {
102-
self.call_with_addr(req, None)
103+
self.call_with_addr(req, None, None)
103104
}
104105
}
105106

src/filters/addr.rs

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,26 @@ use std::net::SocketAddr;
55

66
use crate::filter::{filter_fn_one, Filter};
77

8+
/// Creates a `Filter` to get the local address of the connection.
9+
///
10+
/// If the underlying transport doesn't use socket addresses, this will yield
11+
/// `None`.
12+
///
13+
/// # Example
14+
///
15+
/// ```
16+
/// use std::net::SocketAddr;
17+
/// use warp::Filter;
18+
///
19+
/// let route = warp::addr::local()
20+
/// .map(|addr: Option<SocketAddr>| {
21+
/// println!("local address = {:?}", addr);
22+
/// });
23+
/// ```
24+
pub fn local() -> impl Filter<Extract = (Option<SocketAddr>,), Error = Infallible> + Copy {
25+
filter_fn_one(|route| futures_util::future::ok(route.local_addr()))
26+
}
27+
828
/// Creates a `Filter` to get the remote address of the connection.
929
///
1030
/// If the underlying transport doesn't use socket addresses, this will yield

src/filters/log.rs

+5
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ where
109109
}
110110

111111
impl<'a> Info<'a> {
112+
/// View the local `SocketAddr` of the request.
113+
pub fn local_addr(&self) -> Option<SocketAddr> {
114+
self.route.local_addr()
115+
}
116+
112117
/// View the remote `SocketAddr` of the request.
113118
pub fn remote_addr(&self) -> Option<SocketAddr> {
114119
self.route.remote_addr()

src/filters/trace.rs

+5
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ where
157157
}
158158

159159
impl<'a> Info<'a> {
160+
/// View the local `SocketAddr` of the request.
161+
pub fn local_addr(&self) -> Option<SocketAddr> {
162+
self.route.local_addr()
163+
}
164+
160165
/// View the remote `SocketAddr` of the request.
161166
pub fn remote_addr(&self) -> Option<SocketAddr> {
162167
self.route.remote_addr()

src/route.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ where
3030
#[derive(Debug)]
3131
pub(crate) struct Route {
3232
body: BodyState,
33+
local_addr: Option<SocketAddr>,
3334
remote_addr: Option<SocketAddr>,
3435
req: Request,
3536
segments_index: usize,
@@ -42,7 +43,9 @@ enum BodyState {
4243
}
4344

4445
impl Route {
45-
pub(crate) fn new(req: Request, remote_addr: Option<SocketAddr>) -> RefCell<Route> {
46+
pub(crate) fn new(req: Request,
47+
local_addr: Option<SocketAddr>,
48+
remote_addr: Option<SocketAddr>) -> RefCell<Route> {
4649
let segments_index = if req.uri().path().starts_with('/') {
4750
// Skip the beginning slash.
4851
1
@@ -52,6 +55,7 @@ impl Route {
5255

5356
RefCell::new(Route {
5457
body: BodyState::Ready,
58+
local_addr,
5559
remote_addr,
5660
req,
5761
segments_index,
@@ -123,6 +127,10 @@ impl Route {
123127
self.segments_index = index;
124128
}
125129

130+
pub(crate) fn local_addr(&self) -> Option<SocketAddr> {
131+
self.local_addr
132+
}
133+
126134
pub(crate) fn remote_addr(&self) -> Option<SocketAddr> {
127135
self.remote_addr
128136
}

src/server.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ macro_rules! into_service {
5555
let inner = crate::service($into);
5656
make_service_fn(move |transport| {
5757
let inner = inner.clone();
58+
let local_addr = Transport::local_addr(transport);
5859
let remote_addr = Transport::remote_addr(transport);
5960
future::ok::<_, Infallible>(service_fn(move |req| {
60-
inner.call_with_addr(req, remote_addr)
61+
inner.call_with_addr(req, local_addr, remote_addr)
6162
}))
6263
})
6364
}};

src/test.rs

+20-2
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ use self::inner::OneOrTuple;
123123
/// Starts a new test `RequestBuilder`.
124124
pub fn request() -> RequestBuilder {
125125
RequestBuilder {
126+
local_addr: None,
126127
remote_addr: None,
127128
req: Request::default(),
128129
}
@@ -140,6 +141,7 @@ pub fn ws() -> WsBuilder {
140141
#[must_use = "RequestBuilder does nothing on its own"]
141142
#[derive(Debug)]
142143
pub struct RequestBuilder {
144+
local_addr: Option<SocketAddr>,
143145
remote_addr: Option<SocketAddr>,
144146
req: Request,
145147
}
@@ -237,6 +239,22 @@ impl RequestBuilder {
237239
self
238240
}
239241

242+
/// Set the local address of this request
243+
///
244+
/// Default is no local address.
245+
///
246+
/// # Example
247+
/// ```
248+
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
249+
///
250+
/// let req = warp::test::request()
251+
/// .local_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
252+
/// ```
253+
pub fn local_addr(mut self, addr: SocketAddr) -> Self {
254+
self.local_addr = Some(addr);
255+
self
256+
}
257+
240258
/// Set the remote address of this request
241259
///
242260
/// Default is no remote address.
@@ -376,7 +394,7 @@ impl RequestBuilder {
376394
// TODO: de-duplicate this and apply_filter()
377395
assert!(!route::is_set(), "nested test filter calls");
378396

379-
let route = Route::new(self.req, self.remote_addr);
397+
let route = Route::new(self.req, self.local_addr, self.remote_addr);
380398
let mut fut = Box::pin(
381399
route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
382400
let res = match result {
@@ -405,7 +423,7 @@ impl RequestBuilder {
405423
{
406424
assert!(!route::is_set(), "nested test filter calls");
407425

408-
let route = Route::new(self.req, self.remote_addr);
426+
let route = Route::new(self.req, self.local_addr, self.remote_addr);
409427
let mut fut = Box::pin(route::set(&route, move || {
410428
f.filter(crate::filter::Internal)
411429
}));

src/tls.rs

+7
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ impl Read for LazyFile {
281281
}
282282

283283
impl Transport for TlsStream {
284+
fn local_addr(&self) -> Option<SocketAddr> {
285+
Some(self.local_addr)
286+
}
287+
284288
fn remote_addr(&self) -> Option<SocketAddr> {
285289
Some(self.remote_addr)
286290
}
@@ -296,15 +300,18 @@ enum State {
296300
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
297301
pub(crate) struct TlsStream {
298302
state: State,
303+
local_addr: SocketAddr,
299304
remote_addr: SocketAddr,
300305
}
301306

302307
impl TlsStream {
303308
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
309+
let local_addr = stream.local_addr();
304310
let remote_addr = stream.remote_addr();
305311
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
306312
TlsStream {
307313
state: State::Handshaking(accept),
314+
local_addr,
308315
remote_addr,
309316
}
310317
}

src/transport.rs

+9
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,15 @@ use hyper::server::conn::AddrStream;
77
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
88

99
pub trait Transport: AsyncRead + AsyncWrite {
10+
fn local_addr(&self) -> Option<SocketAddr>;
1011
fn remote_addr(&self) -> Option<SocketAddr>;
1112
}
1213

1314
impl Transport for AddrStream {
15+
fn local_addr(&self) -> Option<SocketAddr> {
16+
Some(self.local_addr())
17+
}
18+
1419
fn remote_addr(&self) -> Option<SocketAddr> {
1520
Some(self.remote_addr())
1621
}
@@ -47,6 +52,10 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for LiftIo<T> {
4752
}
4853

4954
impl<T: AsyncRead + AsyncWrite + Unpin> Transport for LiftIo<T> {
55+
fn local_addr(&self) -> Option<SocketAddr> {
56+
None
57+
}
58+
5059
fn remote_addr(&self) -> Option<SocketAddr> {
5160
None
5261
}

tests/addr.rs

+21
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@
22

33
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
44

5+
#[tokio::test]
6+
async fn local_addr_missing() {
7+
let extract_local_addr = warp::addr::local();
8+
9+
let req = warp::test::request();
10+
let resp = req.filter(&extract_local_addr).await.unwrap();
11+
assert_eq!(resp, None)
12+
}
13+
14+
#[tokio::test]
15+
async fn local_addr_present() {
16+
let extract_local_addr = warp::addr::local();
17+
18+
let req = warp::test::request().local_addr("1.2.3.4:5678".parse().unwrap());
19+
let resp = req.filter(&extract_local_addr).await.unwrap();
20+
assert_eq!(
21+
resp,
22+
Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), 5678))
23+
)
24+
}
25+
526
#[tokio::test]
627
async fn remote_addr_missing() {
728
let extract_remote_addr = warp::addr::remote();

0 commit comments

Comments
 (0)