Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gs/backend/api/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:

# Must all the get_db_session each time when pass it into a separate function.
# Otherwise, will get transaction is inactive error
setup_database(get_db_session())
async with get_db_session() as session:
await setup_database(session)
yield
14 changes: 7 additions & 7 deletions gs/backend/api/v1/aro/endpoints/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,31 @@ async def get_all_users() -> AllUsersResponse:

:return: all users
"""
users = AROUsersWrapper().get_all()
users = await AROUsersWrapper().get_all()
return AllUsersResponse(data=users)


@aro_user_router.get("/get_user/{userid}", response_model=UserResponse)
def get_user(userid: str) -> UserResponse:
async def get_user(userid: str) -> UserResponse:
"""
Gets a user by ID

:param userid: The unique identifier of the user
:return: the user
"""
user = AROUsersWrapper().get_by_id(UUID(userid))
user = await AROUsersWrapper().get_by_id(UUID(userid))
return UserResponse(data=user)


@aro_user_router.post("/create_user", response_model=UserResponse)
def create_user(payload: UserRequest) -> UserResponse:
async def create_user(payload: UserRequest) -> UserResponse:
"""
Creates a user with the given payload
:param payload: The data used to create a user
:return: returns the user created
"""

user = AROUsersWrapper().create(
user = await AROUsersWrapper().create(
data={
"call_sign": payload.call_sign,
"email": payload.email,
Expand All @@ -54,11 +54,11 @@ def create_user(payload: UserRequest) -> UserResponse:


@aro_user_router.delete("/delete_user/{userid}", response_model=UserResponse)
def delete_user(userid: str) -> UserResponse:
async def delete_user(userid: str) -> UserResponse:
"""
Deletes a user based on the user ID
:param userid: The unique identifier of the user to be deleted
:return: returns the deleted user
"""
deleted_user = AROUsersWrapper().delete_by_id(UUID(userid))
deleted_user = await AROUsersWrapper().delete_by_id(UUID(userid))
return UserResponse(data=deleted_user)
4 changes: 2 additions & 2 deletions gs/backend/api/v1/mcc/endpoints/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def create_command(payload: dict[str, Any]) -> Commands:
:param payload: The data used to create a command
:return: returns the created command object
"""
return create_commands(payload)
return await create_commands(payload)


@commands_router.delete("/delete/{command_id}")
Expand All @@ -31,5 +31,5 @@ async def delete_command(command_id: UUID) -> dict[str, Any]:
:param command_id: The id which is to be deleted.
:return: returns a dict giving confirmation that command with id of command_id has been deleted.
"""
delete_commands_by_id(command_id)
await delete_commands_by_id(command_id)
return {"message": f"Command with id {command_id} deleted successfully"}
2 changes: 1 addition & 1 deletion gs/backend/api/v1/mcc/endpoints/main_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ async def get_main_commands() -> MainCommandsResponse:

:return: list of all commands
"""
items = get_all_main_commands()
items = await get_all_main_commands()
return MainCommandsResponse(data=items)
2 changes: 1 addition & 1 deletion gs/backend/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def getenv(config: str) -> str:

DATABASE_CONNECTION_STRING: Final[
str
] = f"postgresql+psycopg2://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"
] = f"postgresql+asyncpg://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this allow for the db to be asynchronous? i thought it already was asynchronous

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the new one is the async driver for psql connection

29 changes: 15 additions & 14 deletions gs/backend/data/data_wrappers/abstract_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,54 @@ class AbstractWrapper(ABC, Generic[T, PK]):

model: type[T]

def get_all(self) -> list[T]:
async def get_all(self) -> list[T]:
"""
Get all data wrapper for the unspecified model

:return: a list of all model instances
"""
with get_db_session() as session:
return list(session.exec(select(self.model)).all())
async with get_db_session() as session:
result = await session.exec(select(self.model))
return list(result.all())

def get_by_id(self, obj_id: PK) -> T:
async def get_by_id(self, obj_id: PK) -> T:
"""
Retrieve data wrapper for the unspecified model

:param obj_id: PK of the model instance to be retrieved
:return: the retrieved instance
"""
with get_db_session() as session:
obj = session.get(self.model, obj_id)
async with get_db_session() as session:
obj = await session.get(self.model, obj_id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
return obj

def create(self, data: dict[str, Any]) -> T:
async def create(self, data: dict[str, Any]) -> T:
"""
Post data wrapper for the unspecified model

:param data: the JSON object of the model instance to be created
:return: the newly created instance
"""
with get_db_session() as session:
async with get_db_session() as session:
obj = self.model(**data)
session.add(obj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mean since we're going nuts and awaiting everything, does this need await as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, though looking at it, we could just add session: AsyncSession = Depends(get_db_session) into the function args of the abstract wrapper

session.commit()
session.refresh(obj)
await session.commit()
await session.refresh(obj)
return obj

def delete_by_id(self, obj_id: PK) -> T:
async def delete_by_id(self, obj_id: PK) -> T:
"""
Delete data wrapper for the unspecified model

:param obj_id: PK of the model instance to be deleted
:return: the deleted instance
"""
with get_db_session() as session:
obj = session.get(self.model, obj_id)
async with get_db_session() as session:
obj = await session.get(self.model, obj_id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
session.delete(obj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this change to return obj, not make a obj copy

session.commit()
await session.commit()
return obj
25 changes: 13 additions & 12 deletions gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
from gs.backend.data.tables.transactional_tables import ARORequest


def get_all_requests() -> list[ARORequest]:
async def get_all_requests() -> list[ARORequest]:
"""
Get all the requests from aro
"""
with get_db_session() as session:
requests = list(session.exec(select(ARORequest)).all())
async with get_db_session() as session:
result = await session.exec(select(ARORequest))
requests = list(result.all())
return requests


def add_request(
async def add_request(
aro_id: UUID,
long: Decimal,
lat: Decimal,
Expand All @@ -38,7 +39,7 @@ def add_request(
:param taken_date: datetime object representing the date that this picture was trasmitted
:param status: the status of the request, can only be from the requets in ARORequestStatus
"""
with get_db_session() as session:
async with get_db_session() as session:
request = ARORequest(
aro_id=aro_id,
latitude=lat,
Expand All @@ -51,23 +52,23 @@ def add_request(
)

session.add(request)
session.commit()
session.refresh(request)
await session.commit()
await session.refresh(request)
return request


def delete_request_by_id(request_id: str) -> list[ARORequest]:
async def delete_request_by_id(request_id: str) -> list[ARORequest]:
"""
Delete a request based on id

:param request_id: unique identifier of the request
"""
with get_db_session() as session:
request = session.get(ARORequest, request_id)
async with get_db_session() as session:
request = await session.get(ARORequest, request_id)
if request:
session.delete(request)
session.commit()
await session.commit()
else:
raise ValueError("Request not found, ID does not exist")

return get_all_requests()
return await get_all_requests()
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from gs.backend.data.tables.aro_user_tables import AROUserAuthToken


def get_all_auth_tokens() -> list[AROUserAuthToken]:
async def get_all_auth_tokens() -> list[AROUserAuthToken]:
"""
Get all the auth tokens
"""
with get_db_session() as session:
auth_tokens = list(session.exec(select(AROUserAuthToken)).all())
async with get_db_session() as session:
result = await session.exec(select(AROUserAuthToken))
auth_tokens = list(result.all())
return auth_tokens


def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken:
async def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken:
"""
Add auth token to the db

Expand All @@ -26,29 +27,29 @@ def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type:
:param expiry: the date in which this token expires
:param auth_type: the type of auth token this is, can only be from AROAuthToken
"""
with get_db_session() as session:
async with get_db_session() as session:
auth_token = AROUserAuthToken(token=token, user_data_id=user_data_id, expiry=expiry, auth_type=auth_type)

session.add(auth_token)
session.commit()
session.refresh(auth_token)
await session.commit()
await session.refresh(auth_token)
return auth_token


def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]:
async def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]:
"""
Delete the auth token based on the token id

:param token_id: the unique identifier for a particular auth token
"""

with get_db_session() as session:
auth_token = session.get(AROUserAuthToken, token_id)
async with get_db_session() as session:
auth_token = await session.get(AROUserAuthToken, token_id)

if auth_token:
session.delete(auth_token)
session.commit()
await session.commit()
else:
print("Token does not exist")

return get_all_auth_tokens()
return await get_all_auth_tokens()
39 changes: 21 additions & 18 deletions gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@


# selects all objects of type AROUser from db and returns them in list
def get_all_users() -> list[AROUsers]:
async def get_all_users() -> list[AROUsers]:
"""
Gets all user
"""
with get_db_session() as session:
users = list(session.exec(select(AROUsers)).all())
async with get_db_session() as session:
result = await session.exec(select(AROUsers))
users = list(result.all())
return users


# adds user to database of type AROUser then fetches the user from database
# so that the user now has an assigned ID
def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers:
async def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers:
"""
Add a new user to the AROUser table in database

Expand All @@ -28,9 +29,10 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number:
:param l_name: last name of user
:param phone_numer: phone number of user
"""
with get_db_session() as session:
async with get_db_session() as session:
# check if the user already exists with email as it is unique
existing_user = session.exec(select(AROUsers).where(AROUsers.email == email)).first()
result = await session.exec(select(AROUsers).where(AROUsers.email == email))
existing_user = result.first()

if existing_user:
raise ValueError("User already exsits based on email")
Expand All @@ -40,13 +42,13 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number:
)

session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user


# updates user into database of type AROUser then fetches the user from database
def update_user_by_id(
async def update_user_by_id(
userid: UUID, call_sign: str, email: str, f_name: str, l_name: str, phone_number: str
) -> AROUsers:
"""
Expand All @@ -59,9 +61,10 @@ def update_user_by_id(
:param l_name: last name of user
:param phone_numer: phone number of user
"""
with get_db_session() as session:
async with get_db_session() as session:
# check if the user already exists with email as it is unique
user = session.exec(select(AROUsers).where(AROUsers.id == userid)).first()
result = await session.exec(select(AROUsers).where(AROUsers.id == userid))
user = result.first()

if not user:
raise ValueError("User does not exist based on user ID")
Expand All @@ -73,25 +76,25 @@ def update_user_by_id(
user.phone_number = phone_number

session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user


# deletes the user with given id and returns the remaining users
def delete_user_by_id(userid: UUID) -> list[AROUsers]:
async def delete_user_by_id(userid: UUID) -> list[AROUsers]:
"""
Use the user.id to delete a user from table

:param userid: identifier unique to the user
"""
with get_db_session() as session:
user = session.get(AROUsers, userid)
async with get_db_session() as session:
user = await session.get(AROUsers, userid)

if user:
session.delete(user)
session.commit()
await session.commit()
else:
raise ValueError("User ID does not exist")

return get_all_users()
return await get_all_users()
Loading
Loading