@@ -9,6 +9,7 @@ use std::{
99 future:: Future ,
1010 marker:: PhantomData ,
1111 pin:: Pin ,
12+ sync:: Arc ,
1213 task:: { Context , Poll } ,
1314} ;
1415use tower:: { util:: BoxCloneService , ServiceBuilder } ;
@@ -90,82 +91,57 @@ use tower_service::Service;
9091/// # let app: Router = app;
9192/// ```
9293///
93- /// # Passing state
94- ///
95- /// State can be passed to the function like so:
96- ///
97- /// ```rust
98- /// use axum::{
99- /// Router,
100- /// http::{Request, StatusCode},
101- /// routing::get,
102- /// response::{IntoResponse, Response},
103- /// middleware::{self, Next}
104- /// };
105- ///
106- /// #[derive(Clone)]
107- /// struct State { /* ... */ }
108- ///
109- /// async fn my_middleware<B>(
110- /// req: Request<B>,
111- /// next: Next<B>,
112- /// state: State,
113- /// ) -> Response {
114- /// // ...
115- /// # ().into_response()
116- /// }
117- ///
118- /// let state = State { /* ... */ };
94+ /// [extractors]: crate::extract::FromRequest
95+ pub fn from_fn < F , T > ( f : F ) -> FromFnLayer < F , ( ) , T > {
96+ from_fn_with_state ( ( ) , f)
97+ }
98+
99+ /// Create a middleware from an async function with the given state.
119100///
120- /// let app = Router::new()
121- /// .route("/", get(|| async { /* ... */ }))
122- /// .route_layer(middleware::from_fn(move |req, next| {
123- /// my_middleware(req, next, state.clone())
124- /// }));
125- /// # let app: Router = app;
126- /// ```
101+ /// See [`State`](crate::extract::State) for more details about accessing state.
127102///
128- /// Or via extensions:
103+ /// # Example
129104///
130105/// ```rust
131106/// use axum::{
132107/// Router,
133- /// extract::Extension,
134108/// http::{Request, StatusCode},
135109/// routing::get,
136110/// response::{IntoResponse, Response},
137111/// middleware::{self, Next},
112+ /// extract::State,
138113/// };
139- /// use tower::ServiceBuilder;
140114///
141115/// #[derive(Clone)]
142- /// struct State { /* ... */ }
116+ /// struct AppState { /* ... */ }
143117///
144118/// async fn my_middleware<B>(
145- /// Extension (state): Extension< State>,
119+ /// State (state): State<AppState >,
146120/// req: Request<B>,
147121/// next: Next<B>,
148122/// ) -> Response {
149123/// // ...
150124/// # ().into_response()
151125/// }
152126///
153- /// let state = State { /* ... */ };
127+ /// let state = AppState { /* ... */ };
154128///
155- /// let app = Router::new( )
129+ /// let app = Router::with_state(state.clone() )
156130/// .route("/", get(|| async { /* ... */ }))
157- /// .layer(
158- /// ServiceBuilder::new()
159- /// .layer(Extension(state))
160- /// .layer(middleware::from_fn(my_middleware)),
161- /// );
162- /// # let app: Router = app;
131+ /// .route_layer(middleware::from_fn_with_state(state, my_middleware));
132+ /// # let app: Router<_> = app;
163133/// ```
134+ pub fn from_fn_with_state < F , S , T > ( state : S , f : F ) -> FromFnLayer < F , S , T > {
135+ from_fn_with_state_arc ( Arc :: new ( state) , f)
136+ }
137+
138+ /// Create a middleware from an async function with the given [`Arc`]'ed state.
164139///
165- /// [extractors]: crate::extract::FromRequest
166- pub fn from_fn < F , T > ( f : F ) -> FromFnLayer < F , T > {
140+ /// See [`State`]( crate::extract::State) for more details about accessing state.
141+ pub fn from_fn_with_state_arc < F , S , T > ( state : Arc < S > , f : F ) -> FromFnLayer < F , S , T > {
167142 FromFnLayer {
168143 f,
144+ state,
169145 _extractor : PhantomData ,
170146 }
171147}
@@ -175,98 +151,99 @@ pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, T> {
175151/// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s.
176152///
177153/// Created with [`from_fn`]. See that function for more details.
178- pub struct FromFnLayer < F , T > {
154+ pub struct FromFnLayer < F , S , T > {
179155 f : F ,
156+ state : Arc < S > ,
180157 _extractor : PhantomData < fn ( ) -> T > ,
181158}
182159
183- impl < F , T > Clone for FromFnLayer < F , T >
160+ impl < F , S , T > Clone for FromFnLayer < F , S , T >
184161where
185162 F : Clone ,
186163{
187164 fn clone ( & self ) -> Self {
188165 Self {
189166 f : self . f . clone ( ) ,
167+ state : Arc :: clone ( & self . state ) ,
190168 _extractor : self . _extractor ,
191169 }
192170 }
193171}
194172
195- impl < F , T > Copy for FromFnLayer < F , T > where F : Copy { }
196-
197- impl < S , F , T > Layer < S > for FromFnLayer < F , T >
173+ impl < S , I , F , T > Layer < I > for FromFnLayer < F , S , T >
198174where
199175 F : Clone ,
200176{
201- type Service = FromFn < F , S , T > ;
177+ type Service = FromFn < F , S , I , T > ;
202178
203- fn layer ( & self , inner : S ) -> Self :: Service {
179+ fn layer ( & self , inner : I ) -> Self :: Service {
204180 FromFn {
205181 f : self . f . clone ( ) ,
182+ state : Arc :: clone ( & self . state ) ,
206183 inner,
207184 _extractor : PhantomData ,
208185 }
209186 }
210187}
211188
212- impl < F , T > fmt:: Debug for FromFnLayer < F , T > {
189+ impl < F , S , T > fmt:: Debug for FromFnLayer < F , S , T >
190+ where
191+ S : fmt:: Debug ,
192+ {
213193 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
214194 f. debug_struct ( "FromFnLayer" )
215195 // Write out the type name, without quoting it as `&type_name::<F>()` would
216196 . field ( "f" , & format_args ! ( "{}" , type_name:: <F >( ) ) )
197+ . field ( "state" , & self . state )
217198 . finish ( )
218199 }
219200}
220201
221202/// A middleware created from an async function.
222203///
223204/// Created with [`from_fn`]. See that function for more details.
224- pub struct FromFn < F , S , T > {
205+ pub struct FromFn < F , S , I , T > {
225206 f : F ,
226- inner : S ,
207+ inner : I ,
208+ state : Arc < S > ,
227209 _extractor : PhantomData < fn ( ) -> T > ,
228210}
229211
230- impl < F , S , T > Clone for FromFn < F , S , T >
212+ impl < F , S , I , T > Clone for FromFn < F , S , I , T >
231213where
232214 F : Clone ,
233- S : Clone ,
215+ I : Clone ,
234216{
235217 fn clone ( & self ) -> Self {
236218 Self {
237219 f : self . f . clone ( ) ,
238220 inner : self . inner . clone ( ) ,
221+ state : Arc :: clone ( & self . state ) ,
239222 _extractor : self . _extractor ,
240223 }
241224 }
242225}
243226
244- impl < F , S , T > Copy for FromFn < F , S , T >
245- where
246- F : Copy ,
247- S : Copy ,
248- {
249- }
250-
251227macro_rules! impl_service {
252228 (
253229 [ $( $ty: ident) ,* ] , $last: ident
254230 ) => {
255231 #[ allow( non_snake_case, unused_mut) ]
256- impl <F , Fut , Out , S , B , $( $ty, ) * $last> Service <Request <B >> for FromFn <F , S , ( $( $ty, ) * $last, ) >
232+ impl <F , Fut , Out , S , I , B , $( $ty, ) * $last> Service <Request <B >> for FromFn <F , S , I , ( $( $ty, ) * $last, ) >
257233 where
258234 F : FnMut ( $( $ty, ) * $last, Next <B >) -> Fut + Clone + Send + ' static ,
259- $( $ty: FromRequestParts <( ) > + Send , ) *
260- $last: FromRequest <( ) , B > + Send ,
235+ $( $ty: FromRequestParts <S > + Send , ) *
236+ $last: FromRequest <S , B > + Send ,
261237 Fut : Future <Output = Out > + Send + ' static ,
262238 Out : IntoResponse + ' static ,
263- S : Service <Request <B >, Error = Infallible >
239+ I : Service <Request <B >, Error = Infallible >
264240 + Clone
265241 + Send
266242 + ' static ,
267- S :: Response : IntoResponse ,
268- S :: Future : Send + ' static ,
243+ I :: Response : IntoResponse ,
244+ I :: Future : Send + ' static ,
269245 B : Send + ' static ,
246+ S : Send + Sync + ' static ,
270247 {
271248 type Response = Response ;
272249 type Error = Infallible ;
@@ -281,20 +258,21 @@ macro_rules! impl_service {
281258 let ready_inner = std:: mem:: replace( & mut self . inner, not_ready_inner) ;
282259
283260 let mut f = self . f. clone( ) ;
261+ let state = Arc :: clone( & self . state) ;
284262
285263 let future = Box :: pin( async move {
286264 let ( mut parts, body) = req. into_parts( ) ;
287265
288266 $(
289- let $ty = match $ty:: from_request_parts( & mut parts, & ( ) ) . await {
267+ let $ty = match $ty:: from_request_parts( & mut parts, & state ) . await {
290268 Ok ( value) => value,
291269 Err ( rejection) => return rejection. into_response( ) ,
292270 } ;
293271 ) *
294272
295273 let req = Request :: from_parts( parts, body) ;
296274
297- let $last = match $last:: from_request( req, & ( ) ) . await {
275+ let $last = match $last:: from_request( req, & state ) . await {
298276 Ok ( value) => value,
299277 Err ( rejection) => return rejection. into_response( ) ,
300278 } ;
@@ -342,14 +320,16 @@ impl_service!(
342320 T16
343321) ;
344322
345- impl < F , S , T > fmt:: Debug for FromFn < F , S , T >
323+ impl < F , S , I , T > fmt:: Debug for FromFn < F , S , I , T >
346324where
347325 S : fmt:: Debug ,
326+ I : fmt:: Debug ,
348327{
349328 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
350329 f. debug_struct ( "FromFnLayer" )
351330 . field ( "f" , & format_args ! ( "{}" , type_name:: <F >( ) ) )
352331 . field ( "inner" , & self . inner )
332+ . field ( "state" , & self . state )
353333 . finish ( )
354334 }
355335}
0 commit comments