1
1
import jwt
2
- from typing import Annotated
2
+ from typing import Dict , Annotated , cast , Any
3
3
from passlib .hash import pbkdf2_sha256
4
4
from jwt .exceptions import InvalidTokenError
5
+
5
6
from datetime import datetime , timedelta , timezone
6
7
7
8
from fastapi import Depends , HTTPException , status
8
9
from fastapi .security import OAuth2PasswordBearer
10
+ from sqlalchemy .sql import ColumnElement
9
11
from sqlalchemy .orm import Session
10
12
from sqlalchemy import select
11
13
20
22
oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "/users/token" )
21
23
22
24
23
- def create_access_token (data : dict , expires_delta : timedelta | None = None ):
25
+ def create_access_token (
26
+ data : Dict [str , str | datetime ], expires_delta : timedelta | None = None
27
+ ) -> str :
24
28
to_encode = data .copy ()
25
29
if expires_delta :
26
30
expire = datetime .now (timezone .utc ) + expires_delta
27
31
else :
28
32
expire = datetime .now (timezone .utc ) + timedelta (minutes = 15 )
29
33
to_encode .update ({"exp" : expire })
30
- encoded_jwt = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
34
+ encoded_jwt = str ( jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM ) )
31
35
return encoded_jwt
32
36
33
37
34
38
def hash_password (password : str ) -> str :
35
- return pbkdf2_sha256 .hash (password )
39
+ return str ( pbkdf2_sha256 .hash (password ) )
36
40
37
41
38
42
async def get_current_user (
39
43
* ,
40
44
session : Session = Depends (get_session ),
41
45
token : Annotated [str , Depends (oauth2_scheme )],
42
- ):
46
+ ) -> Any :
43
47
credentials_exception = HTTPException (
44
48
status_code = status .HTTP_401_UNAUTHORIZED ,
45
49
detail = "Could not validate credentials" ,
@@ -54,16 +58,18 @@ async def get_current_user(
54
58
except InvalidTokenError :
55
59
raise credentials_exception
56
60
57
- stmt = select (User ).where (User .username == token_data .username )
58
- db_user = session .exec (stmt ).first ()[0 ]
61
+ stmt = select (User ).where (
62
+ cast ("ColumnElement[bool]" , User .username == token_data .username )
63
+ )
64
+ db_user = session .execute (stmt ).one ()
59
65
if not db_user :
60
66
raise credentials_exception
61
67
return db_user
62
68
63
69
64
70
async def get_current_active_user (
65
71
current_user : Annotated [User , Depends (get_current_user )],
66
- ):
72
+ ) -> User :
67
73
# if current_user.disabled:
68
74
# raise HTTPException(status_code=400, detail="Inactive user")
69
75
return current_user
0 commit comments