1- import time
2- from pathlib import Path
1+ import asyncio
2+ import signal
3+ import sys
34
45import aiohttp
5- import alembic .command
6- import alembic .config
76import fastapi
87import sqlmodel
98from fastapi import Depends , HTTPException
109from fastapi .middleware .cors import CORSMiddleware
1110from loguru import logger
12- from oasst_inference_server import auth , client_handler , deps , models , worker_handler
11+ from oasst_inference_server import auth , client_handler , database , deps , models , worker_handler
1312from oasst_inference_server .schemas import chat as chat_schema
1413from oasst_inference_server .schemas import worker as worker_schema
1514from oasst_inference_server .settings import settings
@@ -67,19 +66,23 @@ def get_root_token(token: str = Depends(get_bearer_token)) -> str:
6766 )
6867
6968
69+ def terminate_server (signum , frame ):
70+ logger .info (f"Signal { signum } . Terminating server..." )
71+ sys .exit (0 )
72+
73+
7074@app .on_event ("startup" )
71- def alembic_upgrade ():
75+ async def alembic_upgrade ():
76+ signal .signal (signal .SIGINT , terminate_server )
7277 if not settings .update_alembic :
7378 logger .info ("Skipping alembic upgrade on startup (update_alembic is False)" )
7479 return
7580 logger .info ("Attempting to upgrade alembic on startup" )
7681 retry = 0
7782 while True :
7883 try :
79- alembic_ini_path = Path (__file__ ).parent / "alembic.ini"
80- alembic_cfg = alembic .config .Config (str (alembic_ini_path ))
81- alembic_cfg .set_main_option ("sqlalchemy.url" , settings .database_uri )
82- alembic .command .upgrade (alembic_cfg , "head" )
84+ async with database .make_engine ().begin () as conn :
85+ await conn .run_sync (database .alembic_upgrade )
8386 logger .info ("Successfully upgraded alembic on startup" )
8487 break
8588 except Exception :
@@ -90,28 +93,26 @@ def alembic_upgrade():
9093
9194 timeout = settings .alembic_retry_timeout * 2 ** retry
9295 logger .warning (f"Retrying alembic upgrade in { timeout } seconds" )
93- time .sleep (timeout )
96+ await asyncio .sleep (timeout )
97+ signal .signal (signal .SIGINT , signal .SIG_DFL )
9498
9599
96100@app .on_event ("startup" )
97- def maybe_add_debug_api_keys ():
101+ async def maybe_add_debug_api_keys ():
98102 if not settings .debug_api_keys :
99103 logger .info ("No debug API keys configured, skipping" )
100104 return
101105 try :
102106 logger .info ("Adding debug API keys" )
103- with deps .manual_create_session () as session :
107+ async with deps .manual_create_session () as session :
104108 for api_key in settings .debug_api_keys :
105109 logger .info (f"Checking if debug API key { api_key } exists" )
106110 if (
107- session .exec (
108- sqlmodel .select (models .DbWorker ).where (models .DbWorker .api_key == api_key )
109- ).one_or_none ()
110- is None
111- ):
111+ await session .exec (sqlmodel .select (models .DbWorker ).where (models .DbWorker .api_key == api_key ))
112+ ).one_or_none () is None :
112113 logger .info (f"Adding debug API key { api_key } " )
113114 session .add (models .DbWorker (api_key = api_key , name = "Debug API Key" ))
114- session .commit ()
115+ await session .commit ()
115116 else :
116117 logger .info (f"Debug API key { api_key } already exists" )
117118 except Exception :
@@ -129,7 +130,7 @@ async def login_discord():
129130@app .get ("/auth/callback/discord" , response_model = protocol .Token )
130131async def callback_discord (
131132 code : str ,
132- db : sqlmodel . Session = Depends (deps .create_session ),
133+ db : database . AsyncSession = Depends (deps .create_session ),
133134):
134135 redirect_uri = f"{ settings .api_root } /auth/callback/discord"
135136
@@ -166,15 +167,15 @@ async def callback_discord(
166167 raise HTTPException (status_code = 400 , detail = "Invalid user info response from Discord" )
167168
168169 # Try to find a user in our DB linked to the Discord user
169- user : models .DbUser = query_user_by_provider_id (db , discord_id = discord_id )
170+ user : models .DbUser = await query_user_by_provider_id (db , discord_id = discord_id )
170171
171172 # Create if no user exists
172173 if not user :
173174 user = models .DbUser (provider = "discord" , provider_account_id = discord_id , display_name = discord_username )
174175
175176 db .add (user )
176- db .commit ()
177- db .refresh (user )
177+ await db .commit ()
178+ await db .refresh (user )
178179
179180 # Discord account is authenticated and linked to a user; create JWT
180181 access_token = auth .create_access_token ({"user_id" : user .id })
@@ -188,7 +189,7 @@ async def list_chats(
188189) -> chat_schema .ListChatsResponse :
189190 """Lists all chats."""
190191 logger .info ("Listing all chats." )
191- chats = ucr .get_chats ()
192+ chats = await ucr .get_chats ()
192193 chats_list = [chat .to_list_read () for chat in chats ]
193194 return chat_schema .ListChatsResponse (chats = chats_list )
194195
@@ -200,7 +201,7 @@ async def create_chat(
200201) -> chat_schema .ChatListRead :
201202 """Allows a client to create a new chat."""
202203 logger .info (f"Received { request = } " )
203- chat = ucr .create_chat ()
204+ chat = await ucr .create_chat ()
204205 return chat .to_list_read ()
205206
206207
@@ -210,7 +211,7 @@ async def get_chat(
210211 ucr : UserChatRepository = Depends (deps .create_user_chat_repository ),
211212) -> chat_schema .ChatRead :
212213 """Allows a client to get the current state of a chat."""
213- chat = ucr .get_chat_by_id (id )
214+ chat = await ucr .get_chat_by_id (id )
214215 return chat .to_read ()
215216
216217
@@ -225,45 +226,45 @@ async def get_chat(
225226
226227
227228@app .put ("/worker" )
228- def create_worker (
229+ async def create_worker (
229230 request : worker_schema .CreateWorkerRequest ,
230231 root_token : str = Depends (get_root_token ),
231- session : sqlmodel . Session = Depends (deps .create_session ),
232- ):
232+ session : database . AsyncSession = Depends (deps .create_session ),
233+ ) -> worker_schema . WorkerRead :
233234 """Allows a client to register a worker."""
234235 worker = models .DbWorker (name = request .name )
235236 session .add (worker )
236- session .commit ()
237- session .refresh (worker )
238- return worker
237+ await session .commit ()
238+ await session .refresh (worker )
239+ return worker_schema . WorkerRead . from_orm ( worker )
239240
240241
241242@app .get ("/worker" )
242- def list_workers (
243+ async def list_workers (
243244 root_token : str = Depends (get_root_token ),
244- session : sqlmodel . Session = Depends (deps .create_session ),
245- ):
245+ session : database . AsyncSession = Depends (deps .create_session ),
246+ ) -> list [ worker_schema . WorkerRead ] :
246247 """Lists all workers."""
247- workers = session .exec (sqlmodel .select (models .DbWorker )).all ()
248- return list ( workers )
248+ workers = ( await session .exec (sqlmodel .select (models .DbWorker ) )).all ()
249+ return [ worker_schema . WorkerRead . from_orm ( worker ) for worker in workers ]
249250
250251
251252@app .delete ("/worker/{worker_id}" )
252- def delete_worker (
253+ async def delete_worker (
253254 worker_id : str ,
254255 root_token : str = Depends (get_root_token ),
255- session : sqlmodel . Session = Depends (deps .create_session ),
256+ session : database . AsyncSession = Depends (deps .create_session ),
256257):
257258 """Deletes a worker."""
258- worker = session .get (models .DbWorker , worker_id )
259+ worker = await session .get (models .DbWorker , worker_id )
259260 session .delete (worker )
260- session .commit ()
261+ await session .commit ()
261262 return fastapi .Response (status_code = 200 )
262263
263264
264- def query_user_by_provider_id (db : sqlmodel . Session , discord_id : str | None = None ) -> models .DbUser | None :
265+ async def query_user_by_provider_id (db : database . AsyncSession , discord_id : str | None = None ) -> models .DbUser | None :
265266 """Returns the user associated with a given provider ID if any."""
266- user_qry = db . query (models .DbUser )
267+ user_qry = sqlmodel . select (models .DbUser )
267268
268269 if discord_id :
269270 user_qry = user_qry .filter (models .DbUser .provider == "discord" ).filter (
@@ -273,12 +274,12 @@ def query_user_by_provider_id(db: sqlmodel.Session, discord_id: str | None = Non
273274 else :
274275 return None
275276
276- user : models .DbUser = user_qry .first ()
277+ user : models .DbUser = ( await db . exec ( user_qry )) .first ()
277278 return user
278279
279280
280281@app .get ("/auth/login/debug" )
281- async def login_debug (username : str , db : sqlmodel . Session = Depends (deps .create_session )):
282+ async def login_debug (username : str , db : database . AsyncSession = Depends (deps .create_session )):
282283 """Login using a debug username, which the system will accept unconditionally."""
283284
284285 if not settings .allow_debug_auth :
@@ -288,14 +289,16 @@ async def login_debug(username: str, db: sqlmodel.Session = Depends(deps.create_
288289 raise HTTPException (status_code = 400 , detail = "Username is required" )
289290
290291 # Try to find the user
291- user : models .DbUser = db .exec (sqlmodel .select (models .DbUser ).where (models .DbUser .id == username )).one_or_none ()
292+ user : models .DbUser = (
293+ await db .exec (sqlmodel .select (models .DbUser ).where (models .DbUser .id == username ))
294+ ).one_or_none ()
292295
293296 if user is None :
294297 logger .info (f"Creating new debug user { username = } " )
295298 user = models .DbUser (id = username , display_name = username , provider = "debug" , provider_account_id = username )
296299 db .add (user )
297- db .commit ()
298- db .refresh (user )
300+ await db .commit ()
301+ await db .refresh (user )
299302
300303 # Discord account is authenticated and linked to a user; create JWT
301304 access_token = auth .create_access_token ({"user_id" : user .id })
0 commit comments