Skip to content

Commit 09ffbf3

Browse files
its asyncing time
1 parent e3a8c43 commit 09ffbf3

File tree

6 files changed

+70
-57
lines changed

6 files changed

+70
-57
lines changed

gs/backend/config/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ def getenv(config: str) -> str:
4646

4747
DATABASE_CONNECTION_STRING: Final[
4848
str
49-
] = f"postgresql+psycopg2://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"
49+
] = f"postgresql+asyncpg://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"

gs/backend/data/data_wrappers/abstract_wrapper.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,53 +18,54 @@ class AbstractWrapper(ABC, Generic[T, PK]):
1818

1919
model: type[T]
2020

21-
def get_all(self) -> list[T]:
21+
async def get_all(self) -> list[T]:
2222
"""
2323
Get all data wrapper for the unspecified model
2424
2525
:return: a list of all model instances
2626
"""
27-
with get_db_session() as session:
28-
return list(session.exec(select(self.model)).all())
27+
async with get_db_session() as session:
28+
result = await session.exec(select(self.model))
29+
return list(result.all())
2930

30-
def get_by_id(self, obj_id: PK) -> T:
31+
async def get_by_id(self, obj_id: PK) -> T:
3132
"""
3233
Retrieve data wrapper for the unspecified model
3334
3435
:param obj_id: PK of the model instance to be retrieved
3536
:return: the retrieved instance
3637
"""
37-
with get_db_session() as session:
38-
obj = session.get(self.model, obj_id)
38+
async with get_db_session() as session:
39+
obj = await session.get(self.model, obj_id)
3940
if not obj:
4041
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
4142
return obj
4243

43-
def create(self, data: dict[str, Any]) -> T:
44+
async def create(self, data: dict[str, Any]) -> T:
4445
"""
4546
Post data wrapper for the unspecified model
4647
4748
:param data: the JSON object of the model instance to be created
4849
:return: the newly created instance
4950
"""
50-
with get_db_session() as session:
51+
async with get_db_session() as session:
5152
obj = self.model(**data)
5253
session.add(obj)
53-
session.commit()
54-
session.refresh(obj)
54+
await session.commit()
55+
await session.refresh(obj)
5556
return obj
5657

57-
def delete_by_id(self, obj_id: PK) -> T:
58+
async def delete_by_id(self, obj_id: PK) -> T:
5859
"""
5960
Delete data wrapper for the unspecified model
6061
6162
:param obj_id: PK of the model instance to be deleted
6263
:return: the deleted instance
6364
"""
64-
with get_db_session() as session:
65-
obj = session.get(self.model, obj_id)
65+
async with get_db_session() as session:
66+
obj = await session.get(self.model, obj_id)
6667
if not obj:
6768
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
6869
session.delete(obj)
69-
session.commit()
70+
await session.commit()
7071
return obj

gs/backend/data/data_wrappers/wrappers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,17 @@ class CommandsWrapper(AbstractWrapper[Commands, UUID]):
101101

102102
model = Commands
103103

104-
def retrieve_floating_commands(self) -> list[Commands]:
104+
async def retrieve_floating_commands(self) -> list[Commands]:
105105
"""
106106
Retrieves all commands which do not have a valid entry in
107107
the packet_commands table.
108108
A command which is not valid is considered as any command whose ID
109109
does not match with any command_id in the packet_commands table
110110
"""
111-
packet_commands = PacketCommandsWrapper().get_all()
111+
packet_commands = await PacketCommandsWrapper().get_all()
112112
packet_ids = {packet_command.command_id for packet_command in packet_commands}
113113

114-
commands = self.get_all()
114+
commands = await self.get_all()
115115
floating_commands = [fc for fc in commands if fc.id not in packet_ids]
116116

117117
return floating_commands

gs/backend/data/database/engine.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
from sqlalchemy import Engine
2-
from sqlmodel import Session, create_engine, text
1+
from sqlalchemy import text
2+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
33

44
from gs.backend.config.config import DATABASE_CONNECTION_STRING
55
from gs.backend.data.tables.aro_user_tables import ARO_USER_SCHEMA_NAME
66
from gs.backend.data.tables.main_tables import MAIN_SCHEMA_NAME
77
from gs.backend.data.tables.transactional_tables import TRANSACTIONAL_SCHEMA_NAME
88

99

10-
def get_db_engine() -> Engine:
10+
def get_db_engine() -> AsyncEngine:
1111
"""
1212
Creates the database engine
1313
1414
:return: engine
1515
"""
16-
return create_engine(DATABASE_CONNECTION_STRING)
16+
return create_async_engine(DATABASE_CONNECTION_STRING)
1717

1818

19-
def get_db_session() -> Session:
19+
async def get_db_session() -> AsyncSession:
2020
"""
2121
Creates the database session.
2222
@@ -25,22 +25,22 @@ def get_db_session() -> Session:
2525
:return: session
2626
"""
2727
engine = get_db_engine()
28-
with Session(engine) as session:
29-
return session
28+
async with AsyncSession(engine) as session:
29+
yield session
3030

3131

32-
def _create_schemas(session: Session) -> None:
32+
async def _create_schemas(session: AsyncSession) -> None:
3333
"""
3434
Creates the schemas in the database.
3535
3636
:param session: The session for which to create the schemas
3737
"""
38-
connection = session.connection()
38+
connection = await session.connection()
3939
schemas = [MAIN_SCHEMA_NAME, TRANSACTIONAL_SCHEMA_NAME, ARO_USER_SCHEMA_NAME]
4040
for schema in schemas:
4141
# sqlalchemy doesn't check if the schema exists before attempting to create one
42-
connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema}"))
43-
connection.commit()
42+
await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema}"))
43+
await connection.commit()
4444

4545

4646
'''Deprecated method to create tables, now handled by Alembic migrations
@@ -58,11 +58,11 @@ def _create_tables(session: Session) -> None:
5858
'''
5959

6060

61-
def setup_database(session: Session) -> None:
61+
async def setup_database(session: AsyncSession) -> None:
6262
"""
6363
Creates the schemas for the session.
6464
Table creation is now handled by Alembic migrations
6565
6666
:param session: The session for which to create the schemas
6767
"""
68-
_create_schemas(session)
68+
await _create_schemas(session)

gs/backend/data/resources/utils.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from sqlmodel import Session, select
1+
from sqlalchemy.ext.asyncio import AsyncSession
2+
from sqlmodel import select
23

34
from gs.backend.data.resources.callsigns import callsigns
45
from gs.backend.data.resources.main_commands import main_commands
@@ -7,34 +8,34 @@
78
from gs.backend.data.tables.main_tables import MainCommand, MainTelemetry
89

910

10-
def add_main_commands(session: Session) -> None:
11+
async def add_main_commands(session: AsyncSession) -> None:
1112
"""
1213
Setup the main commands to the database
1314
"""
1415
query = select(MainCommand).limit(1) # Check if the db is empty
15-
result = session.exec(query).first()
16-
if not result:
16+
result = await session.exec(query)
17+
if not result.first():
1718
session.add_all(main_commands())
18-
session.commit()
19+
await session.commit()
1920

2021

21-
def add_callsigns(session: Session) -> None:
22+
async def add_callsigns(session: AsyncSession) -> None:
2223
"""
2324
Setup the valid callsigns to the database
2425
"""
2526
query = select(AROUserCallsigns).limit(1)
26-
result = session.exec(query).first()
27-
if not result:
27+
result = await session.exec(query)
28+
if not result.first():
2829
session.add_all(callsigns())
29-
session.commit()
30+
await session.commit()
3031

3132

32-
def add_telemetry(session: Session) -> None:
33+
async def add_telemetry(session: AsyncSession) -> None:
3334
"""
3435
Setup the main telemetry to the database
3536
"""
3637
query = select(MainTelemetry).limit(1) # Check if the db is empty
37-
result = session.exec(query).first()
38-
if not result:
38+
result = await session.exec(query)
39+
if not result.first():
3940
session.add_all(main_telemetry())
40-
session.commit()
41+
await session.commit()

gs/backend/migrate.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import sys
23

34
from gs.backend.data.database.engine import get_db_session
@@ -13,26 +14,36 @@
1314
individually.
1415
"""
1516

16-
if __name__ == "__main__":
17+
18+
async def main() -> None:
19+
"""Main async function to run migrations"""
1720
if len(sys.argv) > 2:
1821
raise ValueError(f"Invalid input. Expected at most 1 argument, received {len(sys.argv)}")
1922
elif len(sys.argv[1:]) == 0:
20-
print("Migrating callsign data...")
21-
add_callsigns(get_db_session())
22-
print("Migrating main command data...")
23-
add_main_commands(get_db_session())
24-
print("Migrating telemetry data...")
25-
add_telemetry(get_db_session())
23+
async with get_db_session() as session:
24+
print("Migrating callsign data...")
25+
await add_callsigns(session)
26+
print("Migrating main command data...")
27+
await add_main_commands(session)
28+
print("Migrating telemetry data...")
29+
await add_telemetry(session)
2630
else:
2731
match sys.argv[1]:
2832
case "callsigns":
29-
print("Migrating callsign data...")
30-
add_callsigns(get_db_session())
33+
async with get_db_session() as session:
34+
print("Migrating callsign data...")
35+
await add_callsigns(session)
3136
case "commands":
32-
print("Migrating main command data...")
33-
add_main_commands(get_db_session())
37+
async with get_db_session() as session:
38+
print("Migrating main command data...")
39+
await add_main_commands(session)
3440
case "telemetries":
35-
print("Migrating telemetry data...")
36-
add_telemetry(get_db_session())
41+
async with get_db_session() as session:
42+
print("Migrating telemetry data...")
43+
await add_telemetry(session)
3744
case _:
3845
raise ValueError("Invalid input. Optional arguments include 'callsigns', 'commands', or 'telemetries'.")
46+
47+
48+
if __name__ == "__main__":
49+
asyncio.run(main())

0 commit comments

Comments
 (0)