Skip to content

Commit f876162

Browse files
authored
Update the JWT for easier scaling (#1011)
* Update the JWT for easier scaling * Fix comment
1 parent 362a559 commit f876162

File tree

2 files changed

+50
-25
lines changed

2 files changed

+50
-25
lines changed

backend/common/security/jwt.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,30 @@ async def get_current_user(db: AsyncSession, pk: int) -> User:
217217
return user
218218

219219

220+
async def get_jwt_user(user_id: int) -> GetUserInfoWithRelationDetail:
221+
"""
222+
获取 JWT 用户
223+
224+
:param user_id:
225+
:return:
226+
"""
227+
cache_user = await redis_client.get(f'{settings.JWT_USER_REDIS_PREFIX}:{user_id}')
228+
if not cache_user:
229+
async with async_db_session() as db:
230+
current_user = await get_current_user(db, user_id)
231+
user = GetUserInfoWithRelationDetail.model_validate(current_user)
232+
await redis_client.setex(
233+
f'{settings.JWT_USER_REDIS_PREFIX}:{user_id}',
234+
settings.TOKEN_EXPIRE_SECONDS,
235+
user.model_dump_json(),
236+
)
237+
else:
238+
# TODO: 在恰当的时机,应替换为使用 model_validate_json
239+
# https://docs.pydantic.dev/latest/concepts/json/#partial-json-parsing
240+
user = GetUserInfoWithRelationDetail.model_validate(from_json(cache_user, allow_partial=True))
241+
return user
242+
243+
220244
def superuser_verify(request: Request, _token: str = DependsJwtAuth) -> bool:
221245
"""
222246
验证当前用户超级管理员权限
@@ -247,21 +271,7 @@ async def jwt_authentication(token: str) -> GetUserInfoWithRelationDetail:
247271
if token != redis_token:
248272
raise errors.TokenError(msg='Token 已失效')
249273

250-
cache_user = await redis_client.get(f'{settings.JWT_USER_REDIS_PREFIX}:{user_id}')
251-
if not cache_user:
252-
async with async_db_session() as db:
253-
current_user = await get_current_user(db, user_id)
254-
user = GetUserInfoWithRelationDetail.model_validate(current_user)
255-
await redis_client.setex(
256-
f'{settings.JWT_USER_REDIS_PREFIX}:{user_id}',
257-
settings.TOKEN_EXPIRE_SECONDS,
258-
user.model_dump_json(),
259-
)
260-
else:
261-
# TODO: 在恰当的时机,应替换为使用 model_validate_json
262-
# https://docs.pydantic.dev/latest/concepts/json/#partial-json-parsing
263-
user = GetUserInfoWithRelationDetail.model_validate(from_json(cache_user, allow_partial=True))
264-
return user
274+
return await get_jwt_user(user_id)
265275

266276

267277
# 超级管理员鉴权依赖注入

backend/middleware/jwt_auth_middleware.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from fastapi import Request, Response
44
from fastapi.security.utils import get_authorization_scheme_param
5-
from starlette.authentication import AuthCredentials, AuthenticationBackend, AuthenticationError
5+
from starlette.authentication import AuthCredentials, AuthenticationBackend
6+
from starlette.authentication import AuthenticationError as StarletteAuthenticationError
67
from starlette.requests import HTTPConnection
78

89
from backend.app.admin.schema.user import GetUserInfoWithRelationDetail
@@ -13,7 +14,7 @@
1314
from backend.utils.serializers import MsgSpecJSONResponse
1415

1516

16-
class _AuthenticationError(AuthenticationError):
17+
class AuthenticationError(StarletteAuthenticationError):
1718
"""重写内部认证错误类"""
1819

1920
def __init__(
@@ -40,7 +41,7 @@ class JwtAuthMiddleware(AuthenticationBackend):
4041
"""JWT 认证中间件"""
4142

4243
@staticmethod
43-
def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> Response:
44+
def auth_exception_handler(conn: HTTPConnection, exc: AuthenticationError) -> Response:
4445
"""
4546
覆盖内部认证错误处理
4647
@@ -50,15 +51,16 @@ def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> R
5051
"""
5152
return MsgSpecJSONResponse(content={'code': exc.code, 'msg': exc.msg, 'data': None}, status_code=exc.code)
5253

53-
async def authenticate(self, request: Request) -> tuple[AuthCredentials, GetUserInfoWithRelationDetail] | None:
54+
@staticmethod
55+
def extract_token(request: Request) -> str | None:
5456
"""
55-
认证请求
57+
从请求中提取 Bearer Token
5658
5759
:param request: FastAPI 请求对象
5860
:return:
5961
"""
60-
token = request.headers.get('Authorization')
61-
if not token:
62+
authorization = request.headers.get('Authorization')
63+
if not authorization:
6264
return None
6365

6466
path = request.url.path
@@ -68,17 +70,30 @@ async def authenticate(self, request: Request) -> tuple[AuthCredentials, GetUser
6870
if pattern.match(path):
6971
return None
7072

71-
scheme, token = get_authorization_scheme_param(token)
73+
scheme, token = get_authorization_scheme_param(authorization)
7274
if scheme.lower() != 'bearer':
7375
return None
7476

77+
return token
78+
79+
async def authenticate(self, request: Request) -> tuple[AuthCredentials, GetUserInfoWithRelationDetail] | None:
80+
"""
81+
认证请求
82+
83+
:param request: FastAPI 请求对象
84+
:return:
85+
"""
86+
token = self.extract_token(request)
87+
if token is None:
88+
return None
89+
7590
try:
7691
user = await jwt_authentication(token)
7792
except TokenError as exc:
78-
raise _AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
93+
raise AuthenticationError(code=exc.code, msg=exc.detail, headers=exc.headers)
7994
except Exception as e:
8095
log.exception(f'JWT 授权异常:{e}')
81-
raise _AuthenticationError(code=getattr(e, 'code', 500), msg=getattr(e, 'msg', 'Internal Server Error'))
96+
raise AuthenticationError(code=getattr(e, 'code', 500), msg=getattr(e, 'msg', 'Internal Server Error'))
8297

8398
# 请注意,此返回使用非标准模式,所以在认证通过时,将丢失某些标准特性
8499
# 标准返回模式请查看:https://www.starlette.io/authentication/

0 commit comments

Comments
 (0)