Skip to content

Commit 4c9edb4

Browse files
authored
Add middleware::{from_fn_with_state, from_fn_with_state_arc} (#1342)
1 parent 3f92f7d commit 4c9edb4

File tree

4 files changed

+73
-114
lines changed

4 files changed

+73
-114
lines changed

axum/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
without any routes will now result in a panic. Previously, this just did
1414
nothing. [#1327]
1515

16+
## Middleware
17+
18+
- **added**: Add `middleware::from_fn_with_state` and
19+
`middleware::from_fn_with_state_arc` to enable running extractors that require
20+
state ([#1342])
21+
1622
[#1327]: https://github.com/tokio-rs/axum/pull/1327
23+
[#1342]: https://github.com/tokio-rs/axum/pull/1342
1724

1825
# 0.6.0-rc.1 (23. August, 2022)
1926

axum/src/docs/middleware.md

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -390,45 +390,12 @@ middleware you don't have to worry about any of this.
390390

391391
# Accessing state in middleware
392392

393-
Handlers can access state using the [`State`] extractor but this isn't available
394-
to middleware. Instead you have to pass the state directly to middleware using
395-
either closure captures (for [`axum::middleware::from_fn`]) or regular struct
396-
fields (if you're implementing a [`tower::Layer`])
393+
How to make state available to middleware depends on how the middleware is
394+
written.
397395

398396
## Accessing state in `axum::middleware::from_fn`
399397

400-
```rust
401-
use axum::{
402-
Router,
403-
routing::get,
404-
middleware::{self, Next},
405-
response::Response,
406-
extract::State,
407-
http::Request,
408-
};
409-
410-
#[derive(Clone)]
411-
struct AppState {}
412-
413-
async fn my_middleware<B>(
414-
state: AppState,
415-
req: Request<B>,
416-
next: Next<B>,
417-
) -> Response {
418-
next.run(req).await
419-
}
420-
421-
async fn handler(_: State<AppState>) {}
422-
423-
let state = AppState {};
424-
425-
let app = Router::with_state(state.clone())
426-
.route("/", get(handler))
427-
.layer(middleware::from_fn(move |req, next| {
428-
my_middleware(state.clone(), req, next)
429-
}));
430-
# let _: Router<_> = app;
431-
```
398+
Use [`axum::middleware::from_fn_with_state`](crate::middleware::from_fn_with_state).
432399

433400
## Accessing state in custom `tower::Layer`s
434401

@@ -482,7 +449,10 @@ where
482449
}
483450

484451
fn call(&mut self, req: Request<B>) -> Self::Future {
485-
// do something with `self.state`
452+
// Do something with `self.state`.
453+
//
454+
// See `axum::RequestExt` for how to run extractors directly from
455+
// a `Request`.
486456

487457
self.inner.call(req)
488458
}

axum/src/middleware/from_fn.rs

Lines changed: 56 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::{
99
future::Future,
1010
marker::PhantomData,
1111
pin::Pin,
12+
sync::Arc,
1213
task::{Context, Poll},
1314
};
1415
use tower::{util::BoxCloneService, ServiceBuilder};
@@ -90,82 +91,57 @@ use tower_service::Service;
9091
/// # let app: Router = app;
9192
/// ```
9293
///
93-
/// # Passing state
94-
///
95-
/// State can be passed to the function like so:
96-
///
97-
/// ```rust
98-
/// use axum::{
99-
/// Router,
100-
/// http::{Request, StatusCode},
101-
/// routing::get,
102-
/// response::{IntoResponse, Response},
103-
/// middleware::{self, Next}
104-
/// };
105-
///
106-
/// #[derive(Clone)]
107-
/// struct State { /* ... */ }
108-
///
109-
/// async fn my_middleware<B>(
110-
/// req: Request<B>,
111-
/// next: Next<B>,
112-
/// state: State,
113-
/// ) -> Response {
114-
/// // ...
115-
/// # ().into_response()
116-
/// }
117-
///
118-
/// let state = State { /* ... */ };
94+
/// [extractors]: crate::extract::FromRequest
95+
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
96+
from_fn_with_state((), f)
97+
}
98+
99+
/// Create a middleware from an async function with the given state.
119100
///
120-
/// let app = Router::new()
121-
/// .route("/", get(|| async { /* ... */ }))
122-
/// .route_layer(middleware::from_fn(move |req, next| {
123-
/// my_middleware(req, next, state.clone())
124-
/// }));
125-
/// # let app: Router = app;
126-
/// ```
101+
/// See [`State`](crate::extract::State) for more details about accessing state.
127102
///
128-
/// Or via extensions:
103+
/// # Example
129104
///
130105
/// ```rust
131106
/// use axum::{
132107
/// Router,
133-
/// extract::Extension,
134108
/// http::{Request, StatusCode},
135109
/// routing::get,
136110
/// response::{IntoResponse, Response},
137111
/// middleware::{self, Next},
112+
/// extract::State,
138113
/// };
139-
/// use tower::ServiceBuilder;
140114
///
141115
/// #[derive(Clone)]
142-
/// struct State { /* ... */ }
116+
/// struct AppState { /* ... */ }
143117
///
144118
/// async fn my_middleware<B>(
145-
/// Extension(state): Extension<State>,
119+
/// State(state): State<AppState>,
146120
/// req: Request<B>,
147121
/// next: Next<B>,
148122
/// ) -> Response {
149123
/// // ...
150124
/// # ().into_response()
151125
/// }
152126
///
153-
/// let state = State { /* ... */ };
127+
/// let state = AppState { /* ... */ };
154128
///
155-
/// let app = Router::new()
129+
/// let app = Router::with_state(state.clone())
156130
/// .route("/", get(|| async { /* ... */ }))
157-
/// .layer(
158-
/// ServiceBuilder::new()
159-
/// .layer(Extension(state))
160-
/// .layer(middleware::from_fn(my_middleware)),
161-
/// );
162-
/// # let app: Router = app;
131+
/// .route_layer(middleware::from_fn_with_state(state, my_middleware));
132+
/// # let app: Router<_> = app;
163133
/// ```
134+
pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
135+
from_fn_with_state_arc(Arc::new(state), f)
136+
}
137+
138+
/// Create a middleware from an async function with the given [`Arc`]'ed state.
164139
///
165-
/// [extractors]: crate::extract::FromRequest
166-
pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
140+
/// See [`State`](crate::extract::State) for more details about accessing state.
141+
pub fn from_fn_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> FromFnLayer<F, S, T> {
167142
FromFnLayer {
168143
f,
144+
state,
169145
_extractor: PhantomData,
170146
}
171147
}
@@ -175,98 +151,99 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
175151
/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
176152
///
177153
/// Created with [`from_fn`]. See that function for more details.
178-
pub struct FromFnLayer<F, T> {
154+
pub struct FromFnLayer<F, S, T> {
179155
f: F,
156+
state: Arc<S>,
180157
_extractor: PhantomData<fn() -> T>,
181158
}
182159

183-
impl<F, T> Clone for FromFnLayer<F, T>
160+
impl<F, S, T> Clone for FromFnLayer<F, S, T>
184161
where
185162
F: Clone,
186163
{
187164
fn clone(&self) -> Self {
188165
Self {
189166
f: self.f.clone(),
167+
state: Arc::clone(&self.state),
190168
_extractor: self._extractor,
191169
}
192170
}
193171
}
194172

195-
impl<F, T> Copy for FromFnLayer<F, T> where F: Copy {}
196-
197-
impl<S, F, T> Layer<S> for FromFnLayer<F, T>
173+
impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
198174
where
199175
F: Clone,
200176
{
201-
type Service = FromFn<F, S, T>;
177+
type Service = FromFn<F, S, I, T>;
202178

203-
fn layer(&self, inner: S) -> Self::Service {
179+
fn layer(&self, inner: I) -> Self::Service {
204180
FromFn {
205181
f: self.f.clone(),
182+
state: Arc::clone(&self.state),
206183
inner,
207184
_extractor: PhantomData,
208185
}
209186
}
210187
}
211188

212-
impl<F, T> fmt::Debug for FromFnLayer<F, T> {
189+
impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
190+
where
191+
S: fmt::Debug,
192+
{
213193
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214194
f.debug_struct("FromFnLayer")
215195
// Write out the type name, without quoting it as `&type_name::<F>()` would
216196
.field("f", &format_args!("{}", type_name::<F>()))
197+
.field("state", &self.state)
217198
.finish()
218199
}
219200
}
220201

221202
/// A middleware created from an async function.
222203
///
223204
/// Created with [`from_fn`]. See that function for more details.
224-
pub struct FromFn<F, S, T> {
205+
pub struct FromFn<F, S, I, T> {
225206
f: F,
226-
inner: S,
207+
inner: I,
208+
state: Arc<S>,
227209
_extractor: PhantomData<fn() -> T>,
228210
}
229211

230-
impl<F, S, T> Clone for FromFn<F, S, T>
212+
impl<F, S, I, T> Clone for FromFn<F, S, I, T>
231213
where
232214
F: Clone,
233-
S: Clone,
215+
I: Clone,
234216
{
235217
fn clone(&self) -> Self {
236218
Self {
237219
f: self.f.clone(),
238220
inner: self.inner.clone(),
221+
state: Arc::clone(&self.state),
239222
_extractor: self._extractor,
240223
}
241224
}
242225
}
243226

244-
impl<F, S, T> Copy for FromFn<F, S, T>
245-
where
246-
F: Copy,
247-
S: Copy,
248-
{
249-
}
250-
251227
macro_rules! impl_service {
252228
(
253229
[$($ty:ident),*], $last:ident
254230
) => {
255231
#[allow(non_snake_case, unused_mut)]
256-
impl<F, Fut, Out, S, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, ($($ty,)* $last,)>
232+
impl<F, Fut, Out, S, I, B, $($ty,)* $last> Service<Request<B>> for FromFn<F, S, I, ($($ty,)* $last,)>
257233
where
258234
F: FnMut($($ty,)* $last, Next<B>) -> Fut + Clone + Send + 'static,
259-
$( $ty: FromRequestParts<()> + Send, )*
260-
$last: FromRequest<(), B> + Send,
235+
$( $ty: FromRequestParts<S> + Send, )*
236+
$last: FromRequest<S, B> + Send,
261237
Fut: Future<Output = Out> + Send + 'static,
262238
Out: IntoResponse + 'static,
263-
S: Service<Request<B>, Error = Infallible>
239+
I: Service<Request<B>, Error = Infallible>
264240
+ Clone
265241
+ Send
266242
+ 'static,
267-
S::Response: IntoResponse,
268-
S::Future: Send + 'static,
243+
I::Response: IntoResponse,
244+
I::Future: Send + 'static,
269245
B: Send + 'static,
246+
S: Send + Sync + 'static,
270247
{
271248
type Response = Response;
272249
type Error = Infallible;
@@ -281,20 +258,21 @@ macro_rules! impl_service {
281258
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
282259

283260
let mut f = self.f.clone();
261+
let state = Arc::clone(&self.state);
284262

285263
let future = Box::pin(async move {
286264
let (mut parts, body) = req.into_parts();
287265

288266
$(
289-
let $ty = match $ty::from_request_parts(&mut parts, &()).await {
267+
let $ty = match $ty::from_request_parts(&mut parts, &state).await {
290268
Ok(value) => value,
291269
Err(rejection) => return rejection.into_response(),
292270
};
293271
)*
294272

295273
let req = Request::from_parts(parts, body);
296274

297-
let $last = match $last::from_request(req, &()).await {
275+
let $last = match $last::from_request(req, &state).await {
298276
Ok(value) => value,
299277
Err(rejection) => return rejection.into_response(),
300278
};
@@ -342,14 +320,16 @@ impl_service!(
342320
T16
343321
);
344322

345-
impl<F, S, T> fmt::Debug for FromFn<F, S, T>
323+
impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
346324
where
347325
S: fmt::Debug,
326+
I: fmt::Debug,
348327
{
349328
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350329
f.debug_struct("FromFnLayer")
351330
.field("f", &format_args!("{}", type_name::<F>()))
352331
.field("inner", &self.inner)
332+
.field("state", &self.state)
353333
.finish()
354334
}
355335
}

axum/src/middleware/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ mod from_extractor;
66
mod from_fn;
77

88
pub use self::from_extractor::{from_extractor, FromExtractor, FromExtractorLayer};
9-
pub use self::from_fn::{from_fn, FromFn, FromFnLayer, Next};
9+
pub use self::from_fn::{
10+
from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next,
11+
};
1012
pub use crate::extension::AddExtension;
1113

1214
pub mod future {

0 commit comments

Comments
 (0)