Skip to content

Commit 060ff8c

Browse files
committed
Switch from using Extension to State
1 parent 38447d2 commit 060ff8c

File tree

5 files changed

+157
-158
lines changed

5 files changed

+157
-158
lines changed

api/src/account.rs

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,44 @@
11
use std::{sync::Arc, time::Duration};
22

3-
use axum_extra::{
4-
headers::{authorization::Bearer, Authorization},
5-
TypedHeader,
6-
};
73
use axum::{
84
extract::{self, FromRequestParts},
95
http::request::Parts,
106
routing::{get, post},
11-
Extension, Json, Router,
7+
Json, Router,
8+
};
9+
use axum_extra::{
10+
headers::{authorization::Bearer, Authorization},
11+
TypedHeader,
1212
};
1313
use jwt_simple::prelude::{Claims, MACLike, NoCustomClaims};
1414
use kabalist_types::{
1515
GetAccountNameResponse, LoginRequest, LoginResponse, RecoverPasswordRequest,
1616
RecoverPasswordResponse, RecoveryInfoResponse, RegisterRequest, RegisterResponse,
1717
};
18-
use sqlx::PgPool;
1918
use tokio_stream::StreamExt;
2019
use uuid::Uuid;
2120

22-
use crate::{config::Config, ok_response::*, ErrResponse, Error, OkResponse, Rsp};
21+
use crate::{ok_response::*, ErrResponse, Error, KabalistState, OkResponse, Rsp, State};
2322

2423
#[derive(Debug)]
2524
pub(crate) struct User {
2625
pub id: Uuid,
2726
}
2827

29-
impl<S> FromRequestParts<S> for User
30-
where
31-
S: Send + Sync,
32-
{
28+
impl FromRequestParts<Arc<KabalistState>> for User {
3329
type Rejection = Error;
3430

35-
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
36-
let Extension(config) = Extension::<Arc<Config>>::from_request_parts(parts, state)
37-
.await
38-
.map_err(|e| {
39-
tracing::error!("Could not fetch config extension: {:?}", e);
40-
Error::Internal
41-
})?;
42-
31+
async fn from_request_parts(
32+
parts: &mut Parts,
33+
state: &Arc<KabalistState>,
34+
) -> Result<Self, Self::Rejection> {
4335
let TypedHeader(Authorization(bearer)) =
4436
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
4537
.await
4638
.map_err(|_| Error::MissingAuthorization)?;
4739

48-
let claims = config
40+
let claims = state
41+
.config
4942
.jwt_secret
5043
.0
5144
.verify_token::<NoCustomClaims>(bearer.token(), None)?;
@@ -57,7 +50,7 @@ where
5750
}
5851
}
5952

60-
pub(crate) fn router() -> Router {
53+
pub(crate) fn router() -> Router<Arc<KabalistState>> {
6154
Router::new()
6255
.route("/login", post(login))
6356
.route("/register/{id}", post(register))
@@ -76,29 +69,25 @@ pub(crate) fn router() -> Router {
7669
),
7770
request_body = LoginRequest,
7871
)]
79-
#[tracing::instrument(skip(config, db))]
80-
async fn login(
81-
Extension(config): Extension<Arc<Config>>,
82-
Extension(db): Extension<PgPool>,
83-
Json(request): Json<LoginRequest>,
84-
) -> Rsp<LoginResponse> {
72+
#[tracing::instrument(skip(state))]
73+
async fn login(state: State, Json(request): Json<LoginRequest>) -> Rsp<LoginResponse> {
8574
let mut rsp = sqlx::query!(
8675
"SELECT id FROM accounts WHERE name = $1::text::citext AND password = crypt($2, password)",
8776
request.username,
8877
request.password.0,
8978
)
90-
.fetch(&db);
79+
.fetch(&state.0.pool);
9180

9281
let id = match rsp.next().await {
9382
None => return Err(Error::UnknownAccount),
9483
Some(Err(e)) => return Err(e.into()),
9584
Some(Ok(id)) => id.id,
9685
};
9786

98-
let mut claims = Claims::create(Duration::from_millis(config.exp as _).into());
87+
let mut claims = Claims::create(Duration::from_millis(state.0.config.exp as _).into());
9988
claims.subject = Some(id.to_string());
10089

101-
let token = config.jwt_secret.0.authenticate(claims)?;
90+
let token = state.0.config.jwt_secret.0.authenticate(claims)?;
10291

10392
OkResponse::ok(LoginResponse { token })
10493
}
@@ -116,13 +105,13 @@ async fn login(
116105
),
117106
request_body = RegisterRequest,
118107
)]
119-
#[tracing::instrument(skip(db))]
108+
#[tracing::instrument(skip(state))]
120109
async fn register(
121-
Extension(db): Extension<PgPool>,
110+
state: State,
122111
extract::Path(id): extract::Path<Uuid>,
123112
Json(req): Json<RegisterRequest>,
124113
) -> Rsp<RegisterResponse> {
125-
let mut tx = db.begin().await?;
114+
let mut tx = state.0.pool.begin().await?;
126115

127116
let mut is_registered =
128117
sqlx::query!("SELECT id FROM registrations WHERE id = $1", id).fetch(&mut *tx);
@@ -163,9 +152,9 @@ async fn register(
163152
("id" = Uuid, Path, description = "Recovery ID"),
164153
),
165154
)]
166-
#[tracing::instrument(skip(db))]
155+
#[tracing::instrument(skip(state))]
167156
async fn recovery_info(
168-
Extension(db): Extension<PgPool>,
157+
state: State,
169158
extract::Path(id): extract::Path<Uuid>,
170159
) -> Rsp<RecoveryInfoResponse> {
171160
let username = sqlx::query!(
@@ -175,7 +164,7 @@ async fn recovery_info(
175164
AND password_reset.account = accounts.id"#,
176165
id
177166
)
178-
.fetch_one(&db)
167+
.fetch_one(&state.0.pool)
179168
.await?
180169
.name;
181170

@@ -199,11 +188,11 @@ async fn recovery_info(
199188
request_body = RecoverPasswordRequest
200189
)]
201190
async fn recover_password(
202-
Extension(db): Extension<PgPool>,
191+
state: State,
203192
extract::Path(id): extract::Path<Uuid>,
204193
Json(request): Json<RecoverPasswordRequest>,
205194
) -> Rsp<RecoverPasswordResponse> {
206-
let mut tx = db.begin().await?;
195+
let mut tx = state.0.pool.begin().await?;
207196

208197
let account = sqlx::query!(
209198
"SELECT password_reset.account FROM password_reset WHERE id = $1",
@@ -245,12 +234,12 @@ async fn recover_password(
245234
)
246235
)]
247236
async fn get_account_name(
248-
Extension(db): Extension<PgPool>,
237+
state: State,
249238
_user: User,
250239
extract::Path(id): extract::Path<Uuid>,
251240
) -> Rsp<GetAccountNameResponse> {
252241
let name = sqlx::query!("SELECT name::text FROM accounts WHERE id = $1", id)
253-
.fetch_one(&db)
242+
.fetch_one(&state.0.pool)
254243
.await?
255244
.name;
256245

0 commit comments

Comments
 (0)