11import logging
22
33from fastapi import HTTPException
4- from sqlmodel import Session , select
4+ from sqlalchemy .ext .asyncio import AsyncSession
5+ from sqlmodel import select
56
67from module .models import ResponseModel
78from module .models .user import User , UserLogin , UserUpdate
1112
1213
1314class UserDatabase :
14- def __init__ (self , session : Session ):
15+ def __init__ (self , session ):
1516 self .session = session
1617
17- def get_user (self , username ):
18+ async def get_user (self , username ):
1819 statement = select (User ).where (User .username == username )
19- result = self .session .exec (statement ).first ()
20- if not result :
20+ if isinstance (self .session , AsyncSession ):
21+ result = await self .session .execute (statement )
22+ user = result .scalar_one_or_none ()
23+ else :
24+ user = self .session .exec (statement ).first ()
25+ if not user :
2126 raise HTTPException (status_code = 404 , detail = "User not found" )
22- return result
27+ return user
2328
24- def auth_user (self , user : User ):
29+ async def auth_user (self , user : User ):
2530 statement = select (User ).where (User .username == user .username )
26- result = self .session .exec (statement ).first ()
31+ if isinstance (self .session , AsyncSession ):
32+ result = await self .session .execute (statement )
33+ db_user = result .scalar_one_or_none ()
34+ else :
35+ db_user = self .session .exec (statement ).first ()
2736 if not user .password :
2837 return ResponseModel (
2938 status_code = 401 , status = False , msg_en = "Incorrect password format" , msg_zh = "密码格式不正确"
3039 )
31- if not result :
40+ if not db_user :
3241 return ResponseModel (
3342 status_code = 401 , status = False , msg_en = "User not found" , msg_zh = "用户不存在"
3443 )
35- if not verify_password (user .password , result .password ):
44+ if not verify_password (user .password , db_user .password ):
3645 return ResponseModel (
3746 status_code = 401 ,
3847 status = False ,
@@ -43,36 +52,59 @@ def auth_user(self, user: User):
4352 status_code = 200 , status = True , msg_en = "Login successfully" , msg_zh = "登录成功"
4453 )
4554
46- def update_user (self , username , update_user : UserUpdate ):
47- # Update username and password
55+ async def update_user (self , username , update_user : UserUpdate ):
4856 statement = select (User ).where (User .username == username )
49- result = self .session .exec (statement ).first ()
50- if not result :
57+ if isinstance (self .session , AsyncSession ):
58+ result = await self .session .execute (statement )
59+ db_user = result .scalar_one_or_none ()
60+ else :
61+ db_user = self .session .exec (statement ).first ()
62+ if not db_user :
5163 raise HTTPException (status_code = 404 , detail = "User not found" )
5264 if update_user .username :
53- result .username = update_user .username
65+ db_user .username = update_user .username
5466 if update_user .password :
55- result .password = get_password_hash (update_user .password )
56- self .session .add (result )
57- self .session .commit ()
58- return result
67+ db_user .password = get_password_hash (update_user .password )
68+ self .session .add (db_user )
69+ if isinstance (self .session , AsyncSession ):
70+ await self .session .commit ()
71+ else :
72+ self .session .commit ()
73+ return db_user
74+
75+ async def add_default_user (self ):
76+ statement = select (User )
77+ if isinstance (self .session , AsyncSession ):
78+ result = await self .session .execute (statement )
79+ users = list (result .scalars ().all ())
80+ else :
81+ try :
82+ users = self .session .exec (statement ).all ()
83+ except Exception :
84+ self .merge_old_user ()
85+ users = self .session .exec (statement ).all ()
86+ if len (users ) != 0 :
87+ return
88+ user = User (username = "admin" , password = get_password_hash ("adminadmin" ))
89+ self .session .add (user )
90+ if isinstance (self .session , AsyncSession ):
91+ await self .session .commit ()
92+ else :
93+ self .session .commit ()
5994
6095 def merge_old_user (self ):
61- # get old data
96+ # Legacy migration - sync only
6297 statement = """
6398 SELECT * FROM user
6499 """
65100 result = self .session .exec (statement ).first ()
66101 if not result :
67102 return
68- # add new data
69103 user = User (username = result .username , password = result .password )
70- # Drop old table
71104 statement = """
72105 DROP TABLE user
73106 """
74107 self .session .exec (statement )
75- # Create new table
76108 statement = """
77109 CREATE TABLE user (
78110 id INTEGER NOT NULL PRIMARY KEY,
@@ -83,18 +115,3 @@ def merge_old_user(self):
83115 self .session .exec (statement )
84116 self .session .add (user )
85117 self .session .commit ()
86-
87- def add_default_user (self ):
88- # Check if user exists
89- statement = select (User )
90- try :
91- result = self .session .exec (statement ).all ()
92- except Exception :
93- self .merge_old_user ()
94- result = self .session .exec (statement ).all ()
95- if len (result ) != 0 :
96- return
97- # Add default user
98- user = User (username = "admin" , password = get_password_hash ("adminadmin" ))
99- self .session .add (user )
100- self .session .commit ()
0 commit comments