11use std:: { sync:: Arc , time:: Duration } ;
22
3- use axum_extra:: {
4- headers:: { authorization:: Bearer , Authorization } ,
5- TypedHeader ,
6- } ;
73use axum:: {
84 extract:: { self , FromRequestParts } ,
95 http:: request:: Parts ,
106 routing:: { get, post} ,
11- Extension , Json , Router ,
7+ Json , Router ,
8+ } ;
9+ use axum_extra:: {
10+ headers:: { authorization:: Bearer , Authorization } ,
11+ TypedHeader ,
1212} ;
1313use jwt_simple:: prelude:: { Claims , MACLike , NoCustomClaims } ;
1414use kabalist_types:: {
1515 GetAccountNameResponse , LoginRequest , LoginResponse , RecoverPasswordRequest ,
1616 RecoverPasswordResponse , RecoveryInfoResponse , RegisterRequest , RegisterResponse ,
1717} ;
18- use sqlx:: PgPool ;
1918use tokio_stream:: StreamExt ;
2019use uuid:: Uuid ;
2120
22- use crate :: { config :: Config , ok_response:: * , ErrResponse , Error , OkResponse , Rsp } ;
21+ use crate :: { ok_response:: * , ErrResponse , Error , KabalistState , OkResponse , Rsp , State } ;
2322
2423#[ derive( Debug ) ]
2524pub ( crate ) struct User {
2625 pub id : Uuid ,
2726}
2827
29- impl < S > FromRequestParts < S > for User
30- where
31- S : Send + Sync ,
32- {
28+ impl FromRequestParts < Arc < KabalistState > > for User {
3329 type Rejection = Error ;
3430
35- async fn from_request_parts ( parts : & mut Parts , state : & S ) -> Result < Self , Self :: Rejection > {
36- let Extension ( config) = Extension :: < Arc < Config > > :: from_request_parts ( parts, state)
37- . await
38- . map_err ( |e| {
39- tracing:: error!( "Could not fetch config extension: {:?}" , e) ;
40- Error :: Internal
41- } ) ?;
42-
31+ async fn from_request_parts (
32+ parts : & mut Parts ,
33+ state : & Arc < KabalistState > ,
34+ ) -> Result < Self , Self :: Rejection > {
4335 let TypedHeader ( Authorization ( bearer) ) =
4436 TypedHeader :: < Authorization < Bearer > > :: from_request_parts ( parts, state)
4537 . await
4638 . map_err ( |_| Error :: MissingAuthorization ) ?;
4739
48- let claims = config
40+ let claims = state
41+ . config
4942 . jwt_secret
5043 . 0
5144 . verify_token :: < NoCustomClaims > ( bearer. token ( ) , None ) ?;
5750 }
5851}
5952
60- pub ( crate ) fn router ( ) -> Router {
53+ pub ( crate ) fn router ( ) -> Router < Arc < KabalistState > > {
6154 Router :: new ( )
6255 . route ( "/login" , post ( login) )
6356 . route ( "/register/{id}" , post ( register) )
@@ -76,29 +69,25 @@ pub(crate) fn router() -> Router {
7669 ) ,
7770 request_body = LoginRequest ,
7871) ]
79- #[ tracing:: instrument( skip( config, db) ) ]
80- async fn login (
81- Extension ( config) : Extension < Arc < Config > > ,
82- Extension ( db) : Extension < PgPool > ,
83- Json ( request) : Json < LoginRequest > ,
84- ) -> Rsp < LoginResponse > {
72+ #[ tracing:: instrument( skip( state) ) ]
73+ async fn login ( state : State , Json ( request) : Json < LoginRequest > ) -> Rsp < LoginResponse > {
8574 let mut rsp = sqlx:: query!(
8675 "SELECT id FROM accounts WHERE name = $1::text::citext AND password = crypt($2, password)" ,
8776 request. username,
8877 request. password. 0 ,
8978 )
90- . fetch ( & db ) ;
79+ . fetch ( & state . 0 . pool ) ;
9180
9281 let id = match rsp. next ( ) . await {
9382 None => return Err ( Error :: UnknownAccount ) ,
9483 Some ( Err ( e) ) => return Err ( e. into ( ) ) ,
9584 Some ( Ok ( id) ) => id. id ,
9685 } ;
9786
98- let mut claims = Claims :: create ( Duration :: from_millis ( config. exp as _ ) . into ( ) ) ;
87+ let mut claims = Claims :: create ( Duration :: from_millis ( state . 0 . config . exp as _ ) . into ( ) ) ;
9988 claims. subject = Some ( id. to_string ( ) ) ;
10089
101- let token = config. jwt_secret . 0 . authenticate ( claims) ?;
90+ let token = state . 0 . config . jwt_secret . 0 . authenticate ( claims) ?;
10291
10392 OkResponse :: ok ( LoginResponse { token } )
10493}
@@ -116,13 +105,13 @@ async fn login(
116105 ) ,
117106 request_body = RegisterRequest ,
118107) ]
119- #[ tracing:: instrument( skip( db ) ) ]
108+ #[ tracing:: instrument( skip( state ) ) ]
120109async fn register (
121- Extension ( db ) : Extension < PgPool > ,
110+ state : State ,
122111 extract:: Path ( id) : extract:: Path < Uuid > ,
123112 Json ( req) : Json < RegisterRequest > ,
124113) -> Rsp < RegisterResponse > {
125- let mut tx = db . begin ( ) . await ?;
114+ let mut tx = state . 0 . pool . begin ( ) . await ?;
126115
127116 let mut is_registered =
128117 sqlx:: query!( "SELECT id FROM registrations WHERE id = $1" , id) . fetch ( & mut * tx) ;
@@ -163,9 +152,9 @@ async fn register(
163152 ( "id" = Uuid , Path , description = "Recovery ID" ) ,
164153 ) ,
165154) ]
166- #[ tracing:: instrument( skip( db ) ) ]
155+ #[ tracing:: instrument( skip( state ) ) ]
167156async fn recovery_info (
168- Extension ( db ) : Extension < PgPool > ,
157+ state : State ,
169158 extract:: Path ( id) : extract:: Path < Uuid > ,
170159) -> Rsp < RecoveryInfoResponse > {
171160 let username = sqlx:: query!(
@@ -175,7 +164,7 @@ async fn recovery_info(
175164 AND password_reset.account = accounts.id"# ,
176165 id
177166 )
178- . fetch_one ( & db )
167+ . fetch_one ( & state . 0 . pool )
179168 . await ?
180169 . name ;
181170
@@ -199,11 +188,11 @@ async fn recovery_info(
199188 request_body = RecoverPasswordRequest
200189) ]
201190async fn recover_password (
202- Extension ( db ) : Extension < PgPool > ,
191+ state : State ,
203192 extract:: Path ( id) : extract:: Path < Uuid > ,
204193 Json ( request) : Json < RecoverPasswordRequest > ,
205194) -> Rsp < RecoverPasswordResponse > {
206- let mut tx = db . begin ( ) . await ?;
195+ let mut tx = state . 0 . pool . begin ( ) . await ?;
207196
208197 let account = sqlx:: query!(
209198 "SELECT password_reset.account FROM password_reset WHERE id = $1" ,
@@ -245,12 +234,12 @@ async fn recover_password(
245234 )
246235) ]
247236async fn get_account_name (
248- Extension ( db ) : Extension < PgPool > ,
237+ state : State ,
249238 _user : User ,
250239 extract:: Path ( id) : extract:: Path < Uuid > ,
251240) -> Rsp < GetAccountNameResponse > {
252241 let name = sqlx:: query!( "SELECT name::text FROM accounts WHERE id = $1" , id)
253- . fetch_one ( & db )
242+ . fetch_one ( & state . 0 . pool )
254243 . await ?
255244 . name ;
256245
0 commit comments