Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion gs/backend/api/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:

# Must all the get_db_session each time when pass it into a separate function.
# Otherwise, will get transaction is inactive error
setup_database(get_db_session())
async with get_db_session() as session:
await setup_database(session)
yield
14 changes: 7 additions & 7 deletions gs/backend/api/v1/aro/endpoints/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,31 @@ async def get_all_users() -> AllUsersResponse:

:return: all users
"""
users = AROUsersWrapper().get_all()
users = await AROUsersWrapper().get_all()
return AllUsersResponse(data=users)


@aro_user_router.get("/get_user/{userid}", response_model=UserResponse)
def get_user(userid: str) -> UserResponse:
async def get_user(userid: str) -> UserResponse:
"""
Gets a user by ID

:param userid: The unique identifier of the user
:return: the user
"""
user = AROUsersWrapper().get_by_id(UUID(userid))
user = await AROUsersWrapper().get_by_id(UUID(userid))
return UserResponse(data=user)


@aro_user_router.post("/create_user", response_model=UserResponse)
def create_user(payload: UserRequest) -> UserResponse:
async def create_user(payload: UserRequest) -> UserResponse:
"""
Creates a user with the given payload
:param payload: The data used to create a user
:return: returns the user created
"""

user = AROUsersWrapper().create(
user = await AROUsersWrapper().create(
data={
"call_sign": payload.call_sign,
"email": payload.email,
Expand All @@ -54,11 +54,11 @@ def create_user(payload: UserRequest) -> UserResponse:


@aro_user_router.delete("/delete_user/{userid}", response_model=UserResponse)
def delete_user(userid: str) -> UserResponse:
async def delete_user(userid: str) -> UserResponse:
"""
Deletes a user based on the user ID
:param userid: The unique identifier of the user to be deleted
:return: returns the deleted user
"""
deleted_user = AROUsersWrapper().delete_by_id(UUID(userid))
deleted_user = await AROUsersWrapper().delete_by_id(UUID(userid))
return UserResponse(data=deleted_user)
4 changes: 2 additions & 2 deletions gs/backend/api/v1/mcc/endpoints/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def create_command(payload: dict[str, Any]) -> Commands:
:param payload: The data used to create a command
:return: returns the created command object
"""
return create_commands(payload)
return await create_commands(payload)


@commands_router.delete("/delete/{command_id}")
Expand All @@ -31,5 +31,5 @@ async def delete_command(command_id: UUID) -> dict[str, Any]:
:param command_id: The id which is to be deleted.
:return: returns a dict giving confirmation that command with id of command_id has been deleted.
"""
delete_commands_by_id(command_id)
await delete_commands_by_id(command_id)
return {"message": f"Command with id {command_id} deleted successfully"}
2 changes: 1 addition & 1 deletion gs/backend/api/v1/mcc/endpoints/main_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ async def get_main_commands() -> MainCommandsResponse:

:return: list of all commands
"""
items = get_all_main_commands()
items = await get_all_main_commands()
return MainCommandsResponse(data=items)
2 changes: 1 addition & 1 deletion gs/backend/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def getenv(config: str) -> str:

DATABASE_CONNECTION_STRING: Final[
str
] = f"postgresql+psycopg2://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"
] = f"postgresql+asyncpg://{GS_DATABASE_USER}:{GS_DATABASE_PASSWORD}@{GS_DATABASE_LOCATION}:{GS_DATABASE_PORT}/{GS_DATABASE_NAME}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this allow for the db to be asynchronous? i thought it already was asynchronous

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the new one is the async driver for psql connection

35 changes: 19 additions & 16 deletions gs/backend/data/data_wrappers/abstract_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,56 @@ class AbstractWrapper(ABC, Generic[T, PK]):

model: type[T]

def get_all(self) -> list[T]:
async def get_all(self) -> list[T]:
"""
Get all data wrapper for the unspecified model

:return: a list of all model instances
"""
with get_db_session() as session:
return list(session.exec(select(self.model)).all())
async with get_db_session() as session:
result = await session.execute(select(self.model))
return list(result.scalars().all())

def get_by_id(self, obj_id: PK) -> T:
async def get_by_id(self, obj_id: PK) -> T:
"""
Retrieve data wrapper for the unspecified model

:param obj_id: PK of the model instance to be retrieved
:return: the retrieved instance
"""
with get_db_session() as session:
obj = session.get(self.model, obj_id)
async with get_db_session() as session:
obj = await session.get(self.model, obj_id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
return obj

def create(self, data: dict[str, Any]) -> T:
async def create(self, data: dict[str, Any]) -> T:
"""
Post data wrapper for the unspecified model

:param data: the JSON object of the model instance to be created
:return: the newly created instance
"""
with get_db_session() as session:
async with get_db_session() as session:
obj = self.model(**data)
session.add(obj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mean since we're going nuts and awaiting everything, does this need await as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, though looking at it, we could just add session: AsyncSession = Depends(get_db_session) into the function args of the abstract wrapper

session.commit()
session.refresh(obj)
await session.commit()
await session.refresh(obj)
return obj

def delete_by_id(self, obj_id: PK) -> T:
async def delete_by_id(self, obj_id: PK) -> T:
"""
Delete data wrapper for the unspecified model

:param obj_id: PK of the model instance to be deleted
:return: the deleted instance
"""
with get_db_session() as session:
obj = session.get(self.model, obj_id)
async with get_db_session() as session:
obj = await session.get(self.model, obj_id)
if not obj:
raise ValueError(f"{self.model.__name__} with ID {obj_id} not found.")
session.delete(obj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert this change to return obj, not make a obj copy

session.commit()
return obj
# Preserve object state before deleting
obj_copy = self.model(**{c.name: getattr(obj, c.name) for c in obj.__table__.columns}) # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not use session.delete(obj)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah cursor bugged out here thinking that db deletion deletes the object from the memory, though it doesnt. pls return obj, not obj_copy

session.delete(obj) # type: ignore[unused-coroutine]
await session.commit()
return obj_copy
27 changes: 14 additions & 13 deletions gs/backend/data/data_wrappers/aro_wrapper/aro_request_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
from gs.backend.data.tables.transactional_tables import ARORequest


def get_all_requests() -> list[ARORequest]:
async def get_all_requests() -> list[ARORequest]:
"""
Get all the requests from aro
"""
with get_db_session() as session:
requests = list(session.exec(select(ARORequest)).all())
async with get_db_session() as session:
result = await session.execute(select(ARORequest))
requests = list(result.scalars().all())
return requests


def add_request(
async def add_request(
aro_id: UUID,
long: Decimal,
lat: Decimal,
Expand All @@ -38,7 +39,7 @@ def add_request(
:param taken_date: datetime object representing the date that this picture was trasmitted
:param status: the status of the request, can only be from the requets in ARORequestStatus
"""
with get_db_session() as session:
async with get_db_session() as session:
request = ARORequest(
aro_id=aro_id,
latitude=lat,
Expand All @@ -51,23 +52,23 @@ def add_request(
)

session.add(request)
session.commit()
session.refresh(request)
await session.commit()
await session.refresh(request)
return request


def delete_request_by_id(request_id: str) -> list[ARORequest]:
async def delete_request_by_id(request_id: str) -> list[ARORequest]:
"""
Delete a request based on id

:param request_id: unique identifier of the request
"""
with get_db_session() as session:
request = session.get(ARORequest, request_id)
async with get_db_session() as session:
request = await session.get(ARORequest, request_id)
if request:
session.delete(request)
session.commit()
session.delete(request) # type: ignore[unused-coroutine]
await session.commit()
else:
raise ValueError("Request not found, ID does not exist")

return get_all_requests()
return await get_all_requests()
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from gs.backend.data.tables.aro_user_tables import AROUserAuthToken


def get_all_auth_tokens() -> list[AROUserAuthToken]:
async def get_all_auth_tokens() -> list[AROUserAuthToken]:
"""
Get all the auth tokens
"""
with get_db_session() as session:
auth_tokens = list(session.exec(select(AROUserAuthToken)).all())
async with get_db_session() as session:
result = await session.execute(select(AROUserAuthToken))
auth_tokens = list(result.scalars().all())
return auth_tokens


def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken:
async def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type: AROEnums) -> AROUserAuthToken:
"""
Add auth token to the db

Expand All @@ -26,29 +27,29 @@ def add_auth_token(token: str, user_data_id: UUID, expiry: datetime, auth_type:
:param expiry: the date in which this token expires
:param auth_type: the type of auth token this is, can only be from AROAuthToken
"""
with get_db_session() as session:
async with get_db_session() as session:
auth_token = AROUserAuthToken(token=token, user_data_id=user_data_id, expiry=expiry, auth_type=auth_type)

session.add(auth_token)
session.commit()
session.refresh(auth_token)
await session.commit()
await session.refresh(auth_token)
return auth_token


def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]:
async def delete_auth_token_by_id(token_id: UUID) -> list[AROUserAuthToken]:
"""
Delete the auth token based on the token id

:param token_id: the unique identifier for a particular auth token
"""

with get_db_session() as session:
auth_token = session.get(AROUserAuthToken, token_id)
async with get_db_session() as session:
auth_token = await session.get(AROUserAuthToken, token_id)

if auth_token:
session.delete(auth_token)
session.commit()
session.delete(auth_token) # type: ignore[unused-coroutine]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea, this seems more logical. also the await seems inconsistent. sometimes session uses await yet other times it does not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in terms of db connection, awaits should only be used on IO or expensive processes, so usually commit and refresh session methods

await session.commit()
else:
print("Token does not exist")

return get_all_auth_tokens()
return await get_all_auth_tokens()
41 changes: 22 additions & 19 deletions gs/backend/data/data_wrappers/aro_wrapper/aro_user_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@


# selects all objects of type AROUser from db and returns them in list
def get_all_users() -> list[AROUsers]:
async def get_all_users() -> list[AROUsers]:
"""
Gets all user
"""
with get_db_session() as session:
users = list(session.exec(select(AROUsers)).all())
async with get_db_session() as session:
result = await session.execute(select(AROUsers))
users = list(result.scalars().all())
return users


# adds user to database of type AROUser then fetches the user from database
# so that the user now has an assigned ID
def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers:
async def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number: str) -> AROUsers:
"""
Add a new user to the AROUser table in database

Expand All @@ -28,9 +29,10 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number:
:param l_name: last name of user
:param phone_numer: phone number of user
"""
with get_db_session() as session:
async with get_db_session() as session:
# check if the user already exists with email as it is unique
existing_user = session.exec(select(AROUsers).where(AROUsers.email == email)).first()
result = await session.execute(select(AROUsers).where(AROUsers.email == email))
existing_user = result.scalars().first()

if existing_user:
raise ValueError("User already exsits based on email")
Expand All @@ -40,13 +42,13 @@ def add_user(call_sign: str, email: str, f_name: str, l_name: str, phone_number:
)

session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user


# updates user into database of type AROUser then fetches the user from database
def update_user_by_id(
async def update_user_by_id(
userid: UUID, call_sign: str, email: str, f_name: str, l_name: str, phone_number: str
) -> AROUsers:
"""
Expand All @@ -59,9 +61,10 @@ def update_user_by_id(
:param l_name: last name of user
:param phone_numer: phone number of user
"""
with get_db_session() as session:
async with get_db_session() as session:
# check if the user already exists with email as it is unique
user = session.exec(select(AROUsers).where(AROUsers.id == userid)).first()
result = await session.execute(select(AROUsers).where(AROUsers.id == userid))
user = result.scalars().first()

if not user:
raise ValueError("User does not exist based on user ID")
Expand All @@ -73,25 +76,25 @@ def update_user_by_id(
user.phone_number = phone_number

session.add(user)
session.commit()
session.refresh(user)
await session.commit()
await session.refresh(user)
return user


# deletes the user with given id and returns the remaining users
def delete_user_by_id(userid: UUID) -> list[AROUsers]:
async def delete_user_by_id(userid: UUID) -> list[AROUsers]:
"""
Use the user.id to delete a user from table

:param userid: identifier unique to the user
"""
with get_db_session() as session:
user = session.get(AROUsers, userid)
async with get_db_session() as session:
user = await session.get(AROUsers, userid)

if user:
session.delete(user)
session.commit()
session.delete(user) # type: ignore[unused-coroutine]
await session.commit()
else:
raise ValueError("User ID does not exist")

return get_all_users()
return await get_all_users()
Loading
Loading