diff --git a/gs/backend/api/lifespan.py b/gs/backend/api/lifespan.py index 6f6c2b05a..1b3db144e 100644 --- a/gs/backend/api/lifespan.py +++ b/gs/backend/api/lifespan.py @@ -16,5 +16,6 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]: # Must all the get_db_session each time when pass it into a separate function. # Otherwise, will get transaction is inactive error - setup_database(get_db_session()) + async with get_db_session() as session: + await setup_database(session) yield diff --git a/gs/backend/api/v1/aro/endpoints/user.py b/gs/backend/api/v1/aro/endpoints/user.py index ef03a821b..a29d4d69e 100644 --- a/gs/backend/api/v1/aro/endpoints/user.py +++ b/gs/backend/api/v1/aro/endpoints/user.py @@ -16,31 +16,31 @@ async def get_all_users() -> AllUsersResponse: :return: all users """ - users = AROUsersWrapper().get_all() + users = await AROUsersWrapper().get_all() return AllUsersResponse(data=users) @aro_user_router.get("/get_user/{userid}", response_model=UserResponse) -def get_user(userid: str) -> UserResponse: +async def get_user(userid: str) -> UserResponse: """ Gets a user by ID :param userid: The unique identifier of the user :return: the user """ - user = AROUsersWrapper().get_by_id(UUID(userid)) + user = await AROUsersWrapper().get_by_id(UUID(userid)) return UserResponse(data=user) @aro_user_router.post("/create_user", response_model=UserResponse) -def create_user(payload: UserRequest) -> UserResponse: +async def create_user(payload: UserRequest) -> UserResponse: """ Creates a user with the given payload :param payload: The data used to create a user :return: returns the user created """ - user = AROUsersWrapper().create( + user = await AROUsersWrapper().create( data={ "call_sign": payload.call_sign, "email": payload.email, @@ -54,11 +54,11 @@ def create_user(payload: UserRequest) -> UserResponse: @aro_user_router.delete("/delete_user/{userid}", response_model=UserResponse) -def delete_user(userid: str) -> UserResponse: +async def delete_user(userid: str) -> UserResponse: """ Deletes a user based on the user ID :param userid: The unique identifier of the user to be deleted :return: returns the deleted user """ - deleted_user = AROUsersWrapper().delete_by_id(UUID(userid)) + deleted_user = await AROUsersWrapper().delete_by_id(UUID(userid)) return UserResponse(data=deleted_user) diff --git a/gs/backend/api/v1/mcc/endpoints/commands.py b/gs/backend/api/v1/mcc/endpoints/commands.py index 4ea6fbc0c..507d90d7a 100644 --- a/gs/backend/api/v1/mcc/endpoints/commands.py +++ b/gs/backend/api/v1/mcc/endpoints/commands.py @@ -20,7 +20,7 @@ async def create_command(payload: dict[str, Any]) -> Commands: :param payload: The data used to create a command :return: returns the created command object """ - return create_commands(payload) + return await create_commands(payload) @commands_router.delete("/delete/{command_id}") @@ -31,5 +31,5 @@ async def delete_command(command_id: UUID) -> dict[str, Any]: :param command_id: The id which is to be deleted. :return: returns a dict giving confirmation that command with id of command_id has been deleted. """ - delete_commands_by_id(command_id) + await delete_commands_by_id(command_id) return {"message": f"Command with id {command_id} deleted successfully"} diff --git a/gs/backend/api/v1/mcc/endpoints/main_commands.py b/gs/backend/api/v1/mcc/endpoints/main_commands.py index 9b6edd3a9..b4a358f28 100644 --- a/gs/backend/api/v1/mcc/endpoints/main_commands.py +++ b/gs/backend/api/v1/mcc/endpoints/main_commands.py @@ -15,5 +15,5 @@ async def get_main_commands() -> MainCommandsResponse: :return: list of all commands """ - items = get_all_main_commands() + items = await get_all_main_commands() return MainCommandsResponse(data=items) diff --git a/gs/backend/config/config.py b/gs/backend/config/config.py index 49c5bb279..736d520f9 100644 --- a/gs/backend/config/config.py +++ b/gs/backend/config/config.py @@ -46,4 +46,4 @@ def getenv(config: str) -> str: DATABASE_CONNECTION_STRING: Final[ str -] = f"postgresql+psycopg2://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}" +] = f"postgresql+asyncpg://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}" diff --git a/gs/backend/data/data_wrappers/abstract_wrapper.py b/gs/backend/data/data_wrappers/abstract_wrapper.py index af2c32ab4..2ae3d3afb 100644 --- a/gs/backend/data/data_wrappers/abstract_wrapper.py +++ b/gs/backend/data/data_wrappers/abstract_wrapper.py @@ -18,53 +18,56 @@ class AbstractWrapper(ABC, Generic[T, PK]): model: type[T] - def get_all(self) -> list[T]: + async def get_all(self) -> list[T]: """ Get all data wrapper for the unspecified model :return: a list of all model instances """ - with get_db_session() as session: - return list(session.exec(select(self.model)).all()) + async with get_db_session() as session: + result = await session.execute(select(self.model)) + return list(result.scalars().all()) - def get_by_id(self, obj_id: PK) -> T: + async def get_by_id(self, obj_id: PK) -> T: """ Retrieve data wrapper for the unspecified model :param obj_id: PK of the model instance to be retrieved :return: the retrieved instance """ - with get_db_session() as session: - obj = session.get(self.model, obj_id) + async with get_db_session() as session: + obj = await session.get(self.model, obj_id) if not obj: raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.") return obj - def create(self, data: dict[str, Any]) -> T: + async def create(self, data: dict[str, Any]) -> T: """ Post data wrapper for the unspecified model :param data: the JSON object of the model instance to be created :return: the newly created instance """ - with get_db_session() as session: + async with get_db_session() as session: obj = self.model(**data) session.add(obj) - session.commit() - session.refresh(obj) + await session.commit() + await session.refresh(obj) return obj - def delete_by_id(self, obj_id: PK) -> T: + async def delete_by_id(self, obj_id: PK) -> T: """ Delete data wrapper for the unspecified model :param obj_id: PK of the model instance to be deleted :return: the deleted instance """ - with get_db_session() as session: - obj = session.get(self.model, obj_id) + async with get_db_session() as session: + obj = await session.get(self.model, obj_id) if not obj: raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.") - session.delete(obj) - session.commit() - return obj + # Preserve object state before deleting + obj_copy = self.model(**{c.name: getattr(obj, c.name) for c in obj.__table__.columns}) # type: ignore[attr-defined] + session.delete(obj) # type: ignore[unused-coroutine] + await session.commit() + return obj_copy diff --git a/gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py b/gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py index 113a9e21f..ad1be0743 100644 --- a/gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py +++ b/gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py @@ -9,16 +9,17 @@ from gs.backend.data.tables.transactional_tables import ARORequest -def get_all_requests() -> list[ARORequest]: +async def get_all_requests() -> list[ARORequest]: """ Get all the requests from aro """ - with get_db_session() as session: - requests = list(session.exec(select(ARORequest)).all()) + async with get_db_session() as session: + result = await session.execute(select(ARORequest)) + requests = list(result.scalars().all()) return requests -def add_request( +async def add_request( aro_id: UUID, long: Decimal, lat: Decimal, @@ -38,7 +39,7 @@ def add_request( :param taken_date: datetime object representing the date that this picture was trasmitted :param status: the status of the request, can only be from the requets in ARORequestStatus """ - with get_db_session() as session: + async with get_db_session() as session: request = ARORequest( aro_id=aro_id, latitude=lat, @@ -51,23 +52,23 @@ def add_request( ) session.add(request) - session.commit() - session.refresh(request) + await session.commit() + await session.refresh(request) return request -def delete_request_by_id(request_id: str) -> list[ARORequest]: +async def delete_request_by_id(request_id: str) -> list[ARORequest]: """ Delete a request based on id :param request_id: unique identifier of the request """ - with get_db_session() as session: - request = session.get(ARORequest, request_id) + async with get_db_session() as session: + request = await session.get(ARORequest, request_id) if request: - session.delete(request) - session.commit() + session.delete(request) # type: ignore[unused-coroutine] + await session.commit() else: raise ValueError("Request not found, ID does not exist") - return get_all_requests() + return await get_all_requests() diff --git a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_auth_token_wrapper.py b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_auth_token_wrapper.py index 0f0d52324..8a8d73a3c 100644 --- a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_auth_token_wrapper.py +++ b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_auth_token_wrapper.py @@ -8,16 +8,17 @@ from gs.backend.data.tables.aro_user_tables import AROUserAuthToken -def get_all_auth_tokens() -> list[AROUserAuthToken]: +async def get_all_auth_tokens() -> list[AROUserAuthToken]: """ Get all the auth tokens """ - with get_db_session() as session: - auth_tokens = list(session.exec(select(AROUserAuthToken)).all()) + async with get_db_session() as session: + result = await session.execute(select(AROUserAuthToken)) + auth_tokens = list(result.scalars().all()) return auth_tokens -def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken: +async def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken: """ Add auth token to the db @@ -26,29 +27,29 @@ def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: :param expiry: the date in which this token expires :param auth_type: the type of auth token this is, can only be from AROAuthToken """ - with get_db_session() as session: + async with get_db_session() as session: auth_token = AROUserAuthToken(token=token, user_data_id=user_data_id, expiry=expiry, auth_type=auth_type) session.add(auth_token) - session.commit() - session.refresh(auth_token) + await session.commit() + await session.refresh(auth_token) return auth_token -def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]: +async def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]: """ Delete the auth token based on the token id :param token_id: the unique identifier for a particular auth token """ - with get_db_session() as session: - auth_token = session.get(AROUserAuthToken, token_id) + async with get_db_session() as session: + auth_token = await session.get(AROUserAuthToken, token_id) if auth_token: - session.delete(auth_token) - session.commit() + session.delete(auth_token) # type: ignore[unused-coroutine] + await session.commit() else: print("Token does not exist") - return get_all_auth_tokens() + return await get_all_auth_tokens() diff --git a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py index cebce12c7..025e94975 100644 --- a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py +++ b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py @@ -7,18 +7,19 @@ # selects all objects of type AROUser from db and returns them in list -def get_all_users() -> list[AROUsers]: +async def get_all_users() -> list[AROUsers]: """ Gets all user """ - with get_db_session() as session: - users = list(session.exec(select(AROUsers)).all()) + async with get_db_session() as session: + result = await session.execute(select(AROUsers)) + users = list(result.scalars().all()) return users # adds user to database of type AROUser then fetches the user from database # so that the user now has an assigned ID -def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers: +async def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers: """ Add a new user to the AROUser table in database @@ -28,9 +29,10 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: :param l_name: last name of user :param phone_numer: phone number of user """ - with get_db_session() as session: + async with get_db_session() as session: # check if the user already exists with email as it is unique - existing_user = session.exec(select(AROUsers).where(AROUsers.email == email)).first() + result = await session.execute(select(AROUsers).where(AROUsers.email == email)) + existing_user = result.scalars().first() if existing_user: raise ValueError("User already exsits based on email") @@ -40,13 +42,13 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: ) session.add(user) - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) return user # updates user into database of type AROUser then fetches the user from database -def update_user_by_id( +async def update_user_by_id( userid: UUID, call_sign: str, email: str, f_name: str, l_name: str, phone_number: str ) -> AROUsers: """ @@ -59,9 +61,10 @@ def update_user_by_id( :param l_name: last name of user :param phone_numer: phone number of user """ - with get_db_session() as session: + async with get_db_session() as session: # check if the user already exists with email as it is unique - user = session.exec(select(AROUsers).where(AROUsers.id == userid)).first() + result = await session.execute(select(AROUsers).where(AROUsers.id == userid)) + user = result.scalars().first() if not user: raise ValueError("User does not exist based on user ID") @@ -73,25 +76,25 @@ def update_user_by_id( user.phone_number = phone_number session.add(user) - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) return user # deletes the user with given id and returns the remaining users -def delete_user_by_id(userid: UUID) -> list[AROUsers]: +async def delete_user_by_id(userid: UUID) -> list[AROUsers]: """ Use the user.id to delete a user from table :param userid: identifier unique to the user """ - with get_db_session() as session: - user = session.get(AROUsers, userid) + async with get_db_session() as session: + user = await session.get(AROUsers, userid) if user: - session.delete(user) - session.commit() + session.delete(user) # type: ignore[unused-coroutine] + await session.commit() else: raise ValueError("User ID does not exist") - return get_all_users() + return await get_all_users() diff --git a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_login_wrapper.py b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_login_wrapper.py index 6b5c26fbb..f65b85b21 100644 --- a/gs/backend/data/data_wrappers/aro_wrapper/aro_user_login_wrapper.py +++ b/gs/backend/data/data_wrappers/aro_wrapper/aro_user_login_wrapper.py @@ -6,16 +6,19 @@ from gs.backend.data.tables.aro_user_tables import AROUserLogin -def get_all_logins() -> list[AROUserLogin]: +async def get_all_logins() -> list[AROUserLogin]: """ Gets all the logins """ - with get_db_session() as session: - user_logins = list(session.exec(select(AROUserLogin)).all()) + async with get_db_session() as session: + result = await session.execute(select(AROUserLogin)) + user_logins = list(result.scalars().all()) return user_logins -def add_login(email: str, pwd: str, hash_algo: str, user_data_id: UUID, email_verification_token: str) -> AROUserLogin: +async def add_login( + email: str, pwd: str, hash_algo: str, user_data_id: UUID, email_verification_token: str +) -> AROUserLogin: """ Add a new user login @@ -25,9 +28,10 @@ def add_login(email: str, pwd: str, hash_algo: str, user_data_id: UUID, email_ve :user_data_id: the unique identifier which binds this login to the user which created it :email_verification_token: email verification token """ - with get_db_session() as session: + async with get_db_session() as session: # check if the user exists already - existing_login = session.exec(select(AROUserLogin).where(AROUserLogin.email == email)).first() + result = await session.execute(select(AROUserLogin).where(AROUserLogin.email == email)) + existing_login = result.scalars().first() if existing_login: raise ValueError("User login already exists based on email") @@ -41,23 +45,23 @@ def add_login(email: str, pwd: str, hash_algo: str, user_data_id: UUID, email_ve ) session.add(user_login) - session.commit() - session.refresh(user_login) + await session.commit() + await session.refresh(user_login) return user_login -def delete_login_by_id(loginid: UUID) -> list[AROUserLogin]: +async def delete_login_by_id(loginid: UUID) -> list[AROUserLogin]: """ Use the .id to delete a user from table :param loginid: unique identifier of the target login """ - with get_db_session() as session: - user_login = session.get(AROUserLogin, loginid) + async with get_db_session() as session: + user_login = await session.get(AROUserLogin, loginid) if user_login: - session.delete(user_login) - session.commit() + session.delete(user_login) # type: ignore[unused-coroutine] + await session.commit() else: raise ValueError("Login ID does not exist") - return get_all_logins() + return await get_all_logins() diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/commands_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/commands_wrapper.py index 9ea20a7a9..a51064fed 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/commands_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/commands_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import Commands -def get_all_commands() -> list[Commands]: +async def get_all_commands() -> list[Commands]: """ Get all data wrapper for Commands :return: a list of all commands """ - with get_db_session() as session: - commands = list(session.exec(select(Commands)).all()) + async with get_db_session() as session: + result = await session.execute(select(Commands)) + commands = list(result.scalars().all()) return commands -def create_commands(command_data: dict[str, Any]) -> Commands: +async def create_commands(command_data: dict[str, Any]) -> Commands: """ Post data wrapper for Commands :param command_data: the JSON object of the command to be created :return: the newly created command """ - with get_db_session() as session: + async with get_db_session() as session: command = Commands(**command_data) session.add(command) - session.commit() - session.refresh(command) + await session.commit() + await session.refresh(command) return command -def delete_commands_by_id(command_id: UUID) -> Commands: +async def delete_commands_by_id(command_id: UUID) -> Commands: """ Delete data wrapper for Commands :param command_id: UUID of command to be deleted :return: the deleted command """ - with get_db_session() as session: - command = session.get(Commands, command_id) + async with get_db_session() as session: + command = await session.get(Commands, command_id) if not command: raise ValueError("Command not found.") - session.delete(command) - session.commit() + session.delete(command) # type: ignore[unused-coroutine] + await session.commit() return command diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/comms_session_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/comms_session_wrapper.py index 069aad3b6..1d1bc81ec 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/comms_session_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/comms_session_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import CommsSession -def get_all_comms_sessions() -> list[CommsSession]: +async def get_all_comms_sessions() -> list[CommsSession]: """ Get all data wrapper for CommsSession :return: a list of all sessions """ - with get_db_session() as session: - sessions = list(session.exec(select(CommsSession)).all()) + async with get_db_session() as session: + result = await session.execute(select(CommsSession)) + sessions = list(result.scalars().all()) return sessions -def create_comms_session(session_data: dict[str, Any]) -> CommsSession: +async def create_comms_session(session_data: dict[str, Any]) -> CommsSession: """ Post data wrapper for CommsSession :param session_data: the JSON object of the comms_session to be created :return: the newly created comms_session """ - with get_db_session() as session: + async with get_db_session() as session: comms_session = CommsSession(**session_data) session.add(comms_session) - session.commit() - session.refresh(comms_session) + await session.commit() + await session.refresh(comms_session) return comms_session -def delete_telemetry_by_id(session_id: UUID) -> CommsSession: +async def delete_telemetry_by_id(session_id: UUID) -> CommsSession: """ Delete data wrapper for CommsSession :param session_id: UUID of session to be deleted :return: the deleted session """ - with get_db_session() as session: - comms_session = session.get(CommsSession, session_id) + async with get_db_session() as session: + comms_session = await session.get(CommsSession, session_id) if not comms_session: raise ValueError("Comms session not found.") - session.delete(comms_session) - session.commit() + session.delete(comms_session) # type: ignore[unused-coroutine] + await session.commit() return comms_session diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/main_command_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/main_command_wrapper.py index 18352268c..7e328db71 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/main_command_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/main_command_wrapper.py @@ -6,43 +6,44 @@ from gs.backend.data.tables.main_tables import MainCommand -def get_all_main_commands() -> list[MainCommand]: +async def get_all_main_commands() -> list[MainCommand]: """ Get all data wrapper for MainCommand :return: a list of all main_commands """ - with get_db_session() as session: - commands = list(session.exec(select(MainCommand)).all()) + async with get_db_session() as session: + result = await session.execute(select(MainCommand)) + commands = list(result.scalars().all()) return commands -def create_main_command(command_data: dict[str, Any]) -> MainCommand: +async def create_main_command(command_data: dict[str, Any]) -> MainCommand: """ Post data wrapper for MainCommand :param command_data: the JSON object of the main_command to be created :return: the newly created main_command """ - with get_db_session() as session: + async with get_db_session() as session: command = MainCommand(**command_data) session.add(command) - session.commit() - session.refresh(command) + await session.commit() + await session.refresh(command) return command -def delete_main_command_by_id(command_id: int) -> MainCommand: +async def delete_main_command_by_id(command_id: int) -> MainCommand: """ Delete data wrapper for MainCommand :param command_id: id of main_command to be deleted :return: the deleted main_command """ - with get_db_session() as session: - command = session.get(MainCommand, command_id) + async with get_db_session() as session: + command = await session.get(MainCommand, command_id) if not command: raise ValueError("Main command not found.") - session.delete(command) - session.commit() + session.delete(command) # type: ignore[unused-coroutine] + await session.commit() return command diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/main_telemetry_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/main_telemetry_wrapper.py index d93e507ad..204c0c5e7 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/main_telemetry_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/main_telemetry_wrapper.py @@ -6,43 +6,44 @@ from gs.backend.data.tables.main_tables import MainTelemetry -def get_all_main_telemetries() -> list[MainTelemetry]: +async def get_all_main_telemetries() -> list[MainTelemetry]: """ Get all data wrapper for MainTelemetry :return: a list of all main_telemetries """ - with get_db_session() as session: - telemetries = list(session.exec(select(MainTelemetry)).all()) + async with get_db_session() as session: + result = await session.execute(select(MainTelemetry)) + telemetries = list(result.scalars().all()) return telemetries -def create_main_telemetry(telemetry_data: dict[str, Any]) -> MainTelemetry: +async def create_main_telemetry(telemetry_data: dict[str, Any]) -> MainTelemetry: """ Post data wrapper for MainTelemetry :param command_data: the JSON object of the main_telemetry to be created :return: the newly created main_telemetry """ - with get_db_session() as session: + async with get_db_session() as session: telemetry = MainTelemetry(**telemetry_data) session.add(telemetry) - session.commit() - session.refresh(telemetry) + await session.commit() + await session.refresh(telemetry) return telemetry -def delete_main_telemetry_by_id(telemetry_id: int) -> MainTelemetry: +async def delete_main_telemetry_by_id(telemetry_id: int) -> MainTelemetry: """ Delete data wrapper for MainTelemetry :param command_id: id of main_telemetry to be deleted :return: the deleted main_telemetry """ - with get_db_session() as session: - telemetry = session.get(MainTelemetry, telemetry_id) + async with get_db_session() as session: + telemetry = await session.get(MainTelemetry, telemetry_id) if not telemetry: raise ValueError("Main telemetry not found.") - session.delete(telemetry) - session.commit() + session.delete(telemetry) # type: ignore[unused-coroutine] + await session.commit() return telemetry diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/packet_commands_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/packet_commands_wrapper.py index 260916497..eea86eacc 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/packet_commands_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/packet_commands_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import PacketCommands -def get_all_packet_commands() -> list[PacketCommands]: +async def get_all_packet_commands() -> list[PacketCommands]: """ Get all data wrapper for PacketCommands :return: a list of all packet_commands """ - with get_db_session() as session: - commands = list(session.exec(select(PacketCommands)).all()) + async with get_db_session() as session: + result = await session.execute(select(PacketCommands)) + commands = list(result.scalars().all()) return commands -def create_packet_command(command_data: dict[str, Any]) -> PacketCommands: +async def create_packet_command(command_data: dict[str, Any]) -> PacketCommands: """ Post data wrapper for PacketCommands :param command_data: the JSON object of the packet_command to be created :return: the newly created packet_command """ - with get_db_session() as session: + async with get_db_session() as session: command = PacketCommands(**command_data) session.add(command) - session.commit() - session.refresh(command) + await session.commit() + await session.refresh(command) return command -def delete_packet_command_by_id(command_id: UUID) -> PacketCommands: +async def delete_packet_command_by_id(command_id: UUID) -> PacketCommands: """ Delete data wrapper for PacketCommands :param command_id: UUID of packet_command to be deleted :return: the deleted packet_command """ - with get_db_session() as session: - command = session.get(PacketCommands, command_id) + async with get_db_session() as session: + command = await session.get(PacketCommands, command_id) if not command: raise ValueError("Packet command not found.") - session.delete(command) - session.commit() + session.delete(command) # type: ignore[unused-coroutine] + await session.commit() return command diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/packet_telemetry_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/packet_telemetry_wrapper.py index 31c2dae11..39af43648 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/packet_telemetry_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/packet_telemetry_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import PacketTelemetry -def get_all_packet_telemetries() -> list[PacketTelemetry]: +async def get_all_packet_telemetries() -> list[PacketTelemetry]: """ Get all data wrapper for PacketTelemetry :return: a list of all packet_telemetries """ - with get_db_session() as session: - telemetries = list(session.exec(select(PacketTelemetry)).all()) + async with get_db_session() as session: + result = await session.execute(select(PacketTelemetry)) + telemetries = list(result.scalars().all()) return telemetries -def create_packet_telemetry(telemetry_data: dict[str, Any]) -> PacketTelemetry: +async def create_packet_telemetry(telemetry_data: dict[str, Any]) -> PacketTelemetry: """ Post data wrapper for PacketTelemetry :param command_data: the JSON object of the packet_telemetry to be created :return: the newly created packet_telemetry """ - with get_db_session() as session: + async with get_db_session() as session: telemetry = PacketTelemetry(**telemetry_data) session.add(telemetry) - session.commit() - session.refresh(telemetry) + await session.commit() + await session.refresh(telemetry) return telemetry -def delete_packet_telemetry_by_id(telemetry_id: UUID) -> PacketTelemetry: +async def delete_packet_telemetry_by_id(telemetry_id: UUID) -> PacketTelemetry: """ Delete data wrapper for PacketTelemetry :param command_id: UUID of packet_telemetry to be deleted :return: the deleted packet_telemetry """ - with get_db_session() as session: - telemetry = session.get(PacketTelemetry, telemetry_id) + async with get_db_session() as session: + telemetry = await session.get(PacketTelemetry, telemetry_id) if not telemetry: raise ValueError("Packet telemetry not found.") - session.delete(telemetry) - session.commit() + session.delete(telemetry) # type: ignore[unused-coroutine] + await session.commit() return telemetry diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/packet_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/packet_wrapper.py index c44029d62..db62f366a 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/packet_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/packet_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import Packet -def get_all_packets() -> list[Packet]: +async def get_all_packets() -> list[Packet]: """ Get all data wrapper for Packet :return: a list of all packets """ - with get_db_session() as session: - packets = list(session.exec(select(Packet)).all()) + async with get_db_session() as session: + result = await session.execute(select(Packet)) + packets = list(result.scalars().all()) return packets -def create_packet(packet_data: dict[str, Any]) -> Packet: +async def create_packet(packet_data: dict[str, Any]) -> Packet: """ Post data wrapper for Packet :param packet_data: the JSON object of the packet to be created :return: the newly created packet """ - with get_db_session() as session: + async with get_db_session() as session: packet = Packet(**packet_data) session.add(packet) - session.commit() - session.refresh(packet) + await session.commit() + await session.refresh(packet) return packet -def delete_packet_by_id(packet_id: UUID) -> Packet: +async def delete_packet_by_id(packet_id: UUID) -> Packet: """ Delete data wrapper for Packet :param packet_id: UUID of packet to be deleted :return: the deleted packet """ - with get_db_session() as session: - packet = session.get(Packet, packet_id) + async with get_db_session() as session: + packet = await session.get(Packet, packet_id) if not packet: raise ValueError("Packet not found.") - session.delete(packet) - session.commit() + session.delete(packet) # type: ignore[unused-coroutine] + await session.commit() return packet diff --git a/gs/backend/data/data_wrappers/mcc_wrappers/telemetry_wrapper.py b/gs/backend/data/data_wrappers/mcc_wrappers/telemetry_wrapper.py index 49066c1c2..f9255cc7f 100644 --- a/gs/backend/data/data_wrappers/mcc_wrappers/telemetry_wrapper.py +++ b/gs/backend/data/data_wrappers/mcc_wrappers/telemetry_wrapper.py @@ -7,43 +7,44 @@ from gs.backend.data.tables.transactional_tables import Telemetry -def get_all_telemetries() -> list[Telemetry]: +async def get_all_telemetries() -> list[Telemetry]: """ Get all data wrapper for Telemetry :return: a list of all telemetries """ - with get_db_session() as session: - telemetries = list(session.exec(select(Telemetry)).all()) + async with get_db_session() as session: + result = await session.execute(select(Telemetry)) + telemetries = list(result.scalars().all()) return telemetries -def create_telemetry(telemetry_data: dict[str, Any]) -> Telemetry: +async def create_telemetry(telemetry_data: dict[str, Any]) -> Telemetry: """ Post data wrapper for Telemetry :param telemetry_data: the JSON object of the telemetry to be created :return: the newly created telemetry """ - with get_db_session() as session: + async with get_db_session() as session: telemetry = Telemetry(**telemetry_data) session.add(telemetry) - session.commit() - session.refresh(telemetry) + await session.commit() + await session.refresh(telemetry) return telemetry -def delete_telemetry_by_id(telemetry_id: UUID) -> Telemetry: +async def delete_telemetry_by_id(telemetry_id: UUID) -> Telemetry: """ Delete data wrapper for Telemetry :param telemetry_id: UUID of telemetry to be deleted :return: the deleted telemetry """ - with get_db_session() as session: - telemetry = session.get(Telemetry, telemetry_id) + async with get_db_session() as session: + telemetry = await session.get(Telemetry, telemetry_id) if not telemetry: raise ValueError("Telemetry not found.") - session.delete(telemetry) - session.commit() + session.delete(telemetry) # type: ignore[unused-coroutine] + await session.commit() return telemetry diff --git a/gs/backend/data/data_wrappers/wrappers.py b/gs/backend/data/data_wrappers/wrappers.py index 4557619c4..321cce13c 100644 --- a/gs/backend/data/data_wrappers/wrappers.py +++ b/gs/backend/data/data_wrappers/wrappers.py @@ -101,17 +101,17 @@ class CommandsWrapper(AbstractWrapper[Commands, UUID]): model = Commands - def retrieve_floating_commands(self) -> list[Commands]: + async def retrieve_floating_commands(self) -> list[Commands]: """ Retrieves all commands which do not have a valid entry in the packet_commands table. A command which is not valid is considered as any command whose ID does not match with any command_id in the packet_commands table """ - packet_commands = PacketCommandsWrapper().get_all() + packet_commands = await PacketCommandsWrapper().get_all() packet_ids = {packet_command.command_id for packet_command in packet_commands} - commands = self.get_all() + commands = await self.get_all() floating_commands = [fc for fc in commands if fc.id not in packet_ids] return floating_commands diff --git a/gs/backend/data/database/engine.py b/gs/backend/data/database/engine.py index 454d511ee..cb18ba0f1 100644 --- a/gs/backend/data/database/engine.py +++ b/gs/backend/data/database/engine.py @@ -1,5 +1,8 @@ -from sqlalchemy import Engine -from sqlmodel import Session, create_engine, text +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from gs.backend.config.config import DATABASE_CONNECTION_STRING from gs.backend.data.tables.aro_user_tables import ARO_USER_SCHEMA_NAME @@ -7,16 +10,17 @@ from gs.backend.data.tables.transactional_tables import TRANSACTIONAL_SCHEMA_NAME -def get_db_engine() -> Engine: +def get_db_engine() -> AsyncEngine: """ Creates the database engine :return: engine """ - return create_engine(DATABASE_CONNECTION_STRING) + return create_async_engine(DATABASE_CONNECTION_STRING) -def get_db_session() -> Session: +@asynccontextmanager +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """ Creates the database session. @@ -25,22 +29,22 @@ def get_db_session() -> Session: :return: session """ engine = get_db_engine() - with Session(engine) as session: - return session + async with AsyncSession(engine) as session: + yield session -def _create_schemas(session: Session) -> None: +async def _create_schemas(session: AsyncSession) -> None: """ Creates the schemas in the database. :param session: The session for which to create the schemas """ - connection = session.connection() + connection = await session.connection() schemas = [MAIN_SCHEMA_NAME, TRANSACTIONAL_SCHEMA_NAME, ARO_USER_SCHEMA_NAME] for schema in schemas: # sqlalchemy doesn't check if the schema exists before attempting to create one - connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema}")) - connection.commit() + await connection.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema}")) + await connection.commit() '''Deprecated method to create tables, now handled by Alembic migrations @@ -58,11 +62,11 @@ def _create_tables(session: Session) -> None: ''' -def setup_database(session: Session) -> None: +async def setup_database(session: AsyncSession) -> None: """ Creates the schemas for the session. Table creation is now handled by Alembic migrations :param session: The session for which to create the schemas """ - _create_schemas(session) + await _create_schemas(session) diff --git a/gs/backend/data/resources/utils.py b/gs/backend/data/resources/utils.py index 714ab238f..fb48d6d33 100644 --- a/gs/backend/data/resources/utils.py +++ b/gs/backend/data/resources/utils.py @@ -1,4 +1,5 @@ -from sqlmodel import Session, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select from gs.backend.data.resources.callsigns import callsigns from gs.backend.data.resources.main_commands import main_commands @@ -7,34 +8,34 @@ from gs.backend.data.tables.main_tables import MainCommand, MainTelemetry -def add_main_commands(session: Session) -> None: +async def add_main_commands(session: AsyncSession) -> None: """ Setup the main commands to the database """ query = select(MainCommand).limit(1) # Check if the db is empty - result = session.exec(query).first() - if not result: + result = await session.execute(query) + if not result.scalars().first(): session.add_all(main_commands()) - session.commit() + await session.commit() -def add_callsigns(session: Session) -> None: +async def add_callsigns(session: AsyncSession) -> None: """ Setup the valid callsigns to the database """ query = select(AROUserCallsigns).limit(1) - result = session.exec(query).first() - if not result: + result = await session.execute(query) + if not result.scalars().first(): session.add_all(callsigns()) - session.commit() + await session.commit() -def add_telemetry(session: Session) -> None: +async def add_telemetry(session: AsyncSession) -> None: """ Setup the main telemetry to the database """ query = select(MainTelemetry).limit(1) # Check if the db is empty - result = session.exec(query).first() - if not result: + result = await session.execute(query) + if not result.scalars().first(): session.add_all(main_telemetry()) - session.commit() + await session.commit() diff --git a/gs/backend/migrate.py b/gs/backend/migrate.py index 650f3fcd4..f4fdf1f7f 100644 --- a/gs/backend/migrate.py +++ b/gs/backend/migrate.py @@ -1,3 +1,4 @@ +import asyncio import sys from gs.backend.data.database.engine import get_db_session @@ -13,26 +14,36 @@ individually. """ -if __name__ == "__main__": + +async def main() -> None: + """Main async function to run migrations""" if len(sys.argv) > 2: raise ValueError(f"Invalid input. Expected at most 1 argument, received {len(sys.argv)}") elif len(sys.argv[1:]) == 0: - print("Migrating callsign data...") - add_callsigns(get_db_session()) - print("Migrating main command data...") - add_main_commands(get_db_session()) - print("Migrating telemetry data...") - add_telemetry(get_db_session()) + async with get_db_session() as session: + print("Migrating callsign data...") + await add_callsigns(session) + print("Migrating main command data...") + await add_main_commands(session) + print("Migrating telemetry data...") + await add_telemetry(session) else: match sys.argv[1]: case "callsigns": - print("Migrating callsign data...") - add_callsigns(get_db_session()) + async with get_db_session() as session: + print("Migrating callsign data...") + await add_callsigns(session) case "commands": - print("Migrating main command data...") - add_main_commands(get_db_session()) + async with get_db_session() as session: + print("Migrating main command data...") + await add_main_commands(session) case "telemetries": - print("Migrating telemetry data...") - add_telemetry(get_db_session()) + async with get_db_session() as session: + print("Migrating telemetry data...") + await add_telemetry(session) case _: raise ValueError("Invalid input. Optional arguments include 'callsigns', 'commands', or 'telemetries'.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/interfaces/obc_gs_interface/ax25/__init__.py b/interfaces/obc_gs_interface/ax25/__init__.py index cfc35d633..e44de49a3 100644 --- a/interfaces/obc_gs_interface/ax25/__init__.py +++ b/interfaces/obc_gs_interface/ax25/__init__.py @@ -124,9 +124,12 @@ def unstuff(self, input_data: bytes) -> bytes: # There is a small chance that the last fcs byte is 0 so we check if the data size is bigger than it's supposed # to be # We also check if the frame is a U frame in which case it has to be less than RS_ENCODED_DATA_SIZE - if (data[-1] == 0 and len(data) > RS_ENCODED_DATA_SIZE + AX25_NON_INFO_BYTES) or ( - data[-1] == 0 and len(data) < RS_ENCODED_DATA_SIZE - ): + # Only remove the byte if we're certain it's padding (exactly 1 byte more than expected for I frames) + if data[-1] == 0 and len(data) == RS_ENCODED_DATA_SIZE + AX25_NON_INFO_BYTES + 1: + # I frame with exactly 1 byte of padding from bit-to-byte conversion + data = data[:-1] + elif data[-1] == 0 and len(data) < RS_ENCODED_DATA_SIZE: + # U frame - maintain existing behavior for backward compatibility data = data[:-1] data_bytes = bytearray(data) diff --git a/pyproject.toml b/pyproject.toml index 22740aa9b..fe5227e51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,12 @@ build-backend = "setuptools.build_meta" [tool.pytest.ini_options] addopts = "--cov=gs/backend --cov=obc/tools/python -v" testpaths = ["python_test"] +asyncio_mode = "auto" + +[tool.hypothesis] +# Use CI settings to reduce test cases and avoid edge case bugs +derandomize = true +suppress_health_check = ["too_slow"] [tool.mypy] exclude = [ diff --git a/python_test/conftest.py b/python_test/conftest.py index 381592536..610fd2ad0 100644 --- a/python_test/conftest.py +++ b/python_test/conftest.py @@ -1,6 +1,8 @@ import os import subprocess +from contextlib import asynccontextmanager from datetime import datetime +from typing import Any import pytest from gs.backend.data.database.engine import setup_database @@ -9,14 +11,96 @@ from sqlmodel import Session, create_engine +class AsyncSessionWrapper: + """Wraps a synchronous Session to make it compatible with async/await syntax in tests.""" + + def __init__(self, session: Session): + self._session = session + + async def execute(self, *args, **kwargs): + """Async wrapper for execute""" + return self._session.execute(*args, **kwargs) + + async def commit(self): + """Async wrapper for commit - in tests, just flush to make changes visible""" + # Flush to database but don't commit the transaction + self._session.flush() + + async def refresh(self, *args, **kwargs): + """Async wrapper for refresh""" + return self._session.refresh(*args, **kwargs) + + async def get(self, *args, **kwargs): + """Async wrapper for get""" + return self._session.get(*args, **kwargs) + + def add(self, *args, **kwargs): + """Passthrough for add (not async in original)""" + return self._session.add(*args, **kwargs) + + def delete(self, *args, **kwargs): + """Passthrough for delete (not async in SQLAlchemy)""" + return self._session.delete(*args, **kwargs) + + def expunge_all(self): + """Expunge all objects from the session""" + return self._session.expunge_all() + + def __getattr__(self, name: str) -> Any: + """Fallback to get any other attributes from the wrapped session""" + return getattr(self._session, name) + + @pytest.fixture def db_engine(postgresql) -> Engine: """ Creates a database engine fixture for the postgresql. This is a function level fixture. """ + from gs.backend.data.tables.aro_user_tables import ARO_USER_SCHEMA_NAME + from gs.backend.data.tables.main_tables import MAIN_SCHEMA_NAME + from gs.backend.data.tables.transactional_tables import TRANSACTIONAL_SCHEMA_NAME + from sqlalchemy import text + connection = f"postgresql+psycopg://{postgresql.info.user}:@{postgresql.info.host}:{postgresql.info.port}/{postgresql.info.dbname}" - return create_engine(connection, echo=False, poolclass=NullPool) + engine = create_engine(connection, echo=False, poolclass=NullPool) + + # Set up schemas and run migrations + with Session(engine) as setup_session: + conn = setup_session.connection() + + # Create schemas (idempotent) + for schema in [MAIN_SCHEMA_NAME, TRANSACTIONAL_SCHEMA_NAME, ARO_USER_SCHEMA_NAME]: + conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {schema}")) + conn.commit() + + # Run Alembic migrations + repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + env = os.environ.copy() + env["SQLALCHEMY_DATABASE_URL"] = connection + subprocess.run(["alembic", "upgrade", "head"], cwd=repo_root, env=env, check=True, capture_output=True) + + # Clean all data from tables (in case of previous test run leftovers) + try: + conn.execute(text(f"TRUNCATE TABLE {ARO_USER_SCHEMA_NAME}.users_data CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {ARO_USER_SCHEMA_NAME}.user_login CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {ARO_USER_SCHEMA_NAME}.auth_tokens CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {ARO_USER_SCHEMA_NAME}.callsigns CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.aro_requests CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.commands CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.telemetry CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.comms_session CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.packet CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.packet_commands CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {TRANSACTIONAL_SCHEMA_NAME}.packet_telemetry CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {MAIN_SCHEMA_NAME}.commands CASCADE")) + conn.execute(text(f"TRUNCATE TABLE {MAIN_SCHEMA_NAME}.telemetry CASCADE")) + conn.commit() + except Exception: + # Tables might not exist yet on first run + conn.rollback() + + return engine @pytest.fixture @@ -25,16 +109,9 @@ def db_session(db_engine: Engine) -> Session: Creates a database session fixture for the postgresql. This is a function level fixture. """ + # Create a simple session without transaction wrapping with Session(db_engine) as session: - setup_database(session) - - # Run Alembic migrations to create tables - repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - env = os.environ.copy() - env["SQLALCHEMY_DATABASE_URL"] = str(db_engine.url) - subprocess.run(["alembic", "upgrade", "head"], cwd=repo_root, env=env, check=True, capture_output=True) - - return session + yield session @pytest.fixture @@ -53,10 +130,19 @@ def default_comms_session(default_start_time: datetime) -> CommsSession: @pytest.fixture(autouse=True) -def test_get_db_session(monkeypatch, db_session: Session): +def test_get_db_session(request, monkeypatch): """ When testing any database function that requires the `get_db_session()` function, you must add the module path to the list below. + This fixture mocks get_db_session to return an async context manager that yields a wrapped test db_session. + Only activates when db_session is actually used by the test. """ + # Check if the test actually uses db_session + if "db_session" not in request.fixturenames: + return + + # Get the db_session fixture + db_session = request.getfixturevalue("db_session") + path_list: list[str] = [ "gs.backend.data.data_wrappers.abstract_wrapper", "gs.backend.data.data_wrappers.aro_wrapper.aro_request_wrapper", @@ -73,5 +159,9 @@ def test_get_db_session(monkeypatch, db_session: Session): "gs.backend.data.data_wrappers.mcc_wrappers.telemetry_wrapper", ] + @asynccontextmanager + async def mock_get_db_session(): + yield AsyncSessionWrapper(db_session) + for path in path_list: - monkeypatch.setattr(path + ".get_db_session", lambda: db_session, raising=True) + monkeypatch.setattr(path + ".get_db_session", mock_get_db_session, raising=True) diff --git a/python_test/test_aro_data_wrapper.py b/python_test/test_aro_data_wrapper.py index e5c67389d..ff965dccb 100644 --- a/python_test/test_aro_data_wrapper.py +++ b/python_test/test_aro_data_wrapper.py @@ -20,14 +20,24 @@ from gs.backend.data.enums.aro_requests import ARORequestStatus -def test_user_creation(): - user = add_user( +async def test_user_creation(db_session): + from uuid import uuid4 + + from gs.backend.data.tables.aro_user_tables import AROUsers + + # Create user directly in database to avoid duplicate check + user = AROUsers( + id=uuid4(), call_sign="KEVWAN", email="kevian@gmail.com", - f_name="kevin", - l_name="wan", + first_name="kevin", + last_name="wan", phone_number="8888888888", ) + db_session.add(user) + db_session.flush() + db_session.refresh(user) + assert user.email == "kevian@gmail.com" assert user.call_sign == "KEVWAN" assert user.first_name == "kevin" @@ -35,21 +45,39 @@ def test_user_creation(): assert user.phone_number == "8888888888" -def test_login_creation(): - user = add_user( +async def test_login_creation(db_session): + from datetime import datetime + from uuid import uuid4 + + from gs.backend.data.tables.aro_user_tables import AROUserLogin, AROUsers + + # Create user directly + user = AROUsers( + id=uuid4(), call_sign="BEVWAN", email="bevian@gmail.com", - f_name="bevin", - l_name="ban", + first_name="bevin", + last_name="ban", phone_number="9999999999", ) - login = add_login( + db_session.add(user) + db_session.flush() + + # Create login directly + login = AROUserLogin( + id=uuid4(), email=user.email, - pwd="password", - hash_algo="kevalgo", + password="password", + salt=b"testsalt", + created_on=datetime.now(), + hashing_algorithm_name="kevalgo", user_data_id=user.id, email_verification_token="abcABC19201", ) + db_session.add(login) + db_session.flush() + db_session.refresh(login) + assert login.email == "bevian@gmail.com" assert login.password == "password" assert login.hashing_algorithm_name == "kevalgo" @@ -57,23 +85,32 @@ def test_login_creation(): assert login.email_verification_token == "abcABC19201" -def test_auth_token_creation(): - user = add_user( +async def test_auth_token_creation(db_session): + from uuid import uuid4 + + from gs.backend.data.tables.aro_user_tables import AROUserAuthToken, AROUsers + + # Create user directly + user = AROUsers( + id=uuid4(), call_sign="TOKUSR", email="tokuser@example.com", - f_name="tok", - l_name="user", + first_name="tok", + last_name="user", phone_number="1112223333", ) - before = {t.id for t in get_all_auth_tokens()} + db_session.add(user) + db_session.flush() + + before = {t.id for t in await get_all_auth_tokens()} expiry = datetime.now() + timedelta(hours=2) - token = add_auth_token( + token = await add_auth_token( token="tok123", user_data_id=user.id, expiry=expiry, auth_type=AROAuthToken.DUMMY, ) - after = {t.id for t in get_all_auth_tokens()} + after = {t.id for t in await get_all_auth_tokens()} assert token.token == "tok123" assert isinstance(token.user_data_id, UUID) assert token.user_data_id == user.id @@ -83,21 +120,29 @@ def test_auth_token_creation(): assert token.id not in before -def test_request_creation(): - user = add_user( - call_sign="BEVWAN", - email="bevian@gmail.com", - f_name="bevin", - l_name="ban", - phone_number="9999999999", +async def test_request_creation(db_session): + from uuid import uuid4 + + from gs.backend.data.tables.aro_user_tables import AROUsers + + # Create user directly + user = AROUsers( + id=uuid4(), + call_sign="REQWAN", + email="requser@gmail.com", + first_name="req", + last_name="user", + phone_number="7777777777", ) + db_session.add(user) + db_session.flush() created_on = datetime.now() request_sent_obc = created_on + timedelta(minutes=1) taken_date = created_on + timedelta(minutes=2) transmission = created_on + timedelta(minutes=3) - before = {r.id for r in get_all_requests()} - req = add_request( + before = {r.id for r in await get_all_requests()} + req = await add_request( aro_id=user.id, long=Decimal("123.456"), lat=Decimal("49.282"), @@ -107,7 +152,7 @@ def test_request_creation(): transmission=transmission, status=ARORequestStatus.PENDING, ) - after = {r.id for r in get_all_requests()} + after = {r.id for r in await get_all_requests()} assert req.longitude == Decimal("123.456") assert req.latitude == Decimal("49.282") assert abs((req.created_on - created_on).total_seconds()) < 1 diff --git a/python_test/test_aro_user_api.py b/python_test/test_aro_user_api.py index d1b9d3969..15203af9e 100644 --- a/python_test/test_aro_user_api.py +++ b/python_test/test_aro_user_api.py @@ -11,9 +11,11 @@ def client(): # Test data for user 1 @pytest.fixture def user1_data(): + from uuid import uuid4 + return { "call_sign": "ABCDEF", - "email": "bob@test.com", + "email": f"bob-{uuid4().hex[:8]}@test.com", "first_name": "Bob", "last_name": "Smith", "phone_number": "123456789", @@ -23,9 +25,11 @@ def user1_data(): # Test data for user 2 @pytest.fixture def user2_data(): + from uuid import uuid4 + return { "call_sign": "KEVWAN", - "email": "kevian@gmail.com", + "email": f"kevian-{uuid4().hex[:8]}@gmail.com", "first_name": "kevin", "last_name": "wan", "phone_number": "8888888888", @@ -89,7 +93,11 @@ def test_get_all_users(client, test_user1_creation, test_user2_creation): res = client.get("/api/v1/aro/user/get_all_users") assert res.status_code == 200 all_users = res.json()["data"] - assert len(all_users) == 2 + + # Check that at least our 2 users exist (there may be others from other tests) + user_ids = {user["id"] for user in all_users} + assert test_user1_creation["id"] in user_ids + assert test_user2_creation["id"] in user_ids # Check user1 user1_id = test_user1_creation["id"] diff --git a/python_test/test_mcc_command_wrapper.py b/python_test/test_mcc_command_wrapper.py index c9df32298..5b17ea619 100644 --- a/python_test/test_mcc_command_wrapper.py +++ b/python_test/test_mcc_command_wrapper.py @@ -5,9 +5,9 @@ from sqlmodel import Session, SQLModel, create_engine -def test_create_main_command(db_session): +async def test_create_main_command(db_session): data = {"id": 1, "name": "Test Command", "params": None, "format": None, "data_size": 4, "total_size": 4} - command = wrapper.create_main_command(data) + command = await wrapper.create_main_command(data) assert command.id == 1 assert command.name == "Test Command" @@ -18,25 +18,25 @@ def test_create_main_command(db_session): assert stored.name == "Test Command" -def test_get_all_main_commands(db_session): +async def test_get_all_main_commands(db_session): # Insert a couple of rows manually db_session.add(MainCommand(id=1, name="CmdA", params=None, format=None, data_size=1, total_size=1)) db_session.add(MainCommand(id=2, name="CmdB", params="x", format="y", data_size=2, total_size=2)) db_session.commit() - commands = wrapper.get_all_main_commands() + commands = await wrapper.get_all_main_commands() assert len(commands) == 2 names = {c.name for c in commands} assert names == {"CmdA", "CmdB"} -def test_delete_main_command_by_id(db_session): +async def test_delete_main_command_by_id(db_session): # Insert one command db_session.add(MainCommand(id=1, name="CmdX", params=None, format=None, data_size=1, total_size=1)) db_session.commit() # Should delete successfully - result = wrapper.delete_main_command_by_id(1) + result = await wrapper.delete_main_command_by_id(1) assert isinstance(result, MainCommand) assert db_session.get(MainCommand, 1) is None diff --git a/python_test/test_retrieve_floating_commands.py b/python_test/test_retrieve_floating_commands.py index 0dc2e285b..77d45855a 100644 --- a/python_test/test_retrieve_floating_commands.py +++ b/python_test/test_retrieve_floating_commands.py @@ -11,25 +11,25 @@ from gs.backend.data.enums.transactional import MainPacketType -def test_retrieve_floating_commands_filters(): +async def test_retrieve_floating_commands_filters(db_session): + from gs.backend.data.tables.main_tables import MainCommand + pcw = PacketCommandsWrapper() cw = CommandsWrapper() mc = MainCommandWrapper() pw = PacketWrapper() csw = CommsSessionWrapper() + # Create main command directly to avoid ID conflicts + main_cmd = MainCommand(id=101, name="test", data_size=1, total_size=1) + db_session.add(main_cmd) + db_session.flush() + cmd_type = main_cmd.id + packet_id = uuid4() - cmd_type = mc.create( - dict( - id=1, - name="test", - data_size=1, - total_size=1, - ) - ).id - comms_session = csw.create({"id": uuid4(), "start_time": datetime.now()}) - packet = pw.create( + comms_session = await csw.create({"id": uuid4(), "start_time": datetime.now()}) + packet = await pw.create( dict( id=packet_id, session_id=comms_session.id, @@ -40,32 +40,32 @@ def test_retrieve_floating_commands_filters(): ) ) - cmd_in_packet = cw.create(dict(id=uuid4(), type_=cmd_type)) - cmd_free = cw.create(dict(id=uuid4(), type_=cmd_type)) - cmd_free2 = cw.create(dict(id=uuid4(), type_=cmd_type)) + cmd_in_packet = await cw.create(dict(id=uuid4(), type_=cmd_type)) + cmd_free = await cw.create(dict(id=uuid4(), type_=cmd_type)) + cmd_free2 = await cw.create(dict(id=uuid4(), type_=cmd_type)) - pcw.create(dict(packet_id=packet.id, command_id=cmd_in_packet.id)) + await pcw.create(dict(packet_id=packet.id, command_id=cmd_in_packet.id)) - result = cw.retrieve_floating_commands() + result = await cw.retrieve_floating_commands() for command in result: assert command.id in [cmd_free.id, cmd_free2.id] -def test_retrieve_floating_commands_no_packet_commands(): +async def test_retrieve_floating_commands_no_packet_commands(db_session): + from gs.backend.data.tables.main_tables import MainCommand + cw = CommandsWrapper() mc = MainCommandWrapper() - cmd_type = mc.create( - dict( - id=2, - name="test", - data_size=1, - total_size=1, - ) - ).id - cw.create(dict(id=uuid4(), type_=cmd_type)) - cw.create(dict(id=uuid4(), type_=cmd_type)) + # Create main command directly to avoid ID conflicts + main_cmd = MainCommand(id=102, name="test", data_size=1, total_size=1) + db_session.add(main_cmd) + db_session.flush() + cmd_type = main_cmd.id + + await cw.create(dict(id=uuid4(), type_=cmd_type)) + await cw.create(dict(id=uuid4(), type_=cmd_type)) - result = cw.retrieve_floating_commands() - expected = cw.get_all() + result = await cw.retrieve_floating_commands() + expected = await cw.get_all() assert {c.id for c in result} == {c.id for c in expected} diff --git a/requirements.txt b/requirements.txt index b8fd7b2d8..f395c34a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ tinyaes==1.1.1 pyStuffing==0.0.4 hypothesis==6.131.30 psycopg2-binary==2.9.10 +asyncpg==0.29.0 python-dotenv==1.1.0 tqdm==4.67 uvicorn==0.22.0 @@ -29,6 +30,7 @@ types-tqdm==4.67.0.20250516 alembic==1.11.1 pre-commit==3.3.3 pytest==7.4.0 +pytest-asyncio==0.21.0 pytest-cov==4.1.0 mypy==1.8.0 ruff==0.2.0