Skip to content
This repository was archived by the owner on Sep 6, 2024. It is now read-only.

Fix: sqlalchemy errors break session #75

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/db_adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
list_remove_item_safe,
refresh_list_items,
)
from .media import series_get
from .misc import refresh
from .user import user_create_list, user_get, user_get_list_safe, user_get_safe

__all__ = [
Expand All @@ -21,4 +23,6 @@
"refresh_list_items",
"get_list_item",
"list_remove_item_safe",
"refresh",
"series_get",
]
79 changes: 46 additions & 33 deletions src/db_adapters/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,39 +32,47 @@ async def list_put_item(

:raises ValueError: If the item is already present in the list.
"""
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None:
raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.")

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item
async with session:
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
if await session.get(UserListItem, (user_list.id, tvdb_id, kind)) is not None:
raise ValueError(f"Item {tvdb_id} is already in list {user_list.id}.")

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item


async def list_get_item(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind
) -> UserListItem | None:
"""Get an item from a user list."""
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
async with session:
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))


async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> None:
async def list_remove_item(session: AsyncSession, user_list: UserList, item: UserListItem) -> UserList:
"""Remove an item from a user list."""
await session.delete(item)
await session.commit()
await session.refresh(user_list, ["items"])
async with session:
item = await session.merge(item)
user_list = await session.merge(user_list)
await session.delete(item)
await session.commit()
await session.refresh(user_list, ["items"])
return user_list


async def list_remove_item_safe(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind
) -> None:
) -> UserList:
"""Removes an item from a user list if it exists."""
if item := await list_get_item(session, user_list, tvdb_id, kind):
await list_remove_item(session, user_list, item)
async with session:
if item := await list_get_item(session, user_list, tvdb_id, kind):
return await list_remove_item(session, user_list, item)
return user_list


@overload
Expand All @@ -90,23 +98,27 @@ async def list_put_item_safe(
session: AsyncSession, user_list: UserList, tvdb_id: int, kind: UserListItemKind, series_id: int | None = None
) -> UserListItem:
"""Add an item to a user list, or return the existing item if it is already present."""
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
item = await list_get_item(session, user_list, tvdb_id, kind)
if item:
async with session:
if series_id:
await ensure_media(session, tvdb_id, kind, series_id=series_id)
else:
await ensure_media(session, tvdb_id, kind)
item = await list_get_item(session, user_list, tvdb_id, kind)
if item:
return item

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item

item = UserListItem(list_id=user_list.id, tvdb_id=tvdb_id, kind=kind)
session.add(item)
await session.commit()
return item


async def refresh_list_items(session: AsyncSession, user_list: UserList) -> None:
async def refresh_list_items(session: AsyncSession, user_list: UserList) -> UserList:
"""Refresh the items in a user list."""
await session.refresh(user_list, ["items"])
async with session:
user_list = await session.merge(user_list)
await session.refresh(user_list, ["items"])
return user_list


async def get_list_item(
Expand All @@ -116,4 +128,5 @@ async def get_list_item(
kind: UserListItemKind,
) -> UserListItem | None:
"""Get a user list."""
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
async with session:
return await session.get(UserListItem, (user_list.id, tvdb_id, kind))
43 changes: 25 additions & 18 deletions src/db_adapters/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@

async def ensure_media(session: AsyncSession, tvdb_id: int, kind: UserListItemKind, **kwargs: Any) -> None:
"""Ensure that a tvdb media item is present in its respective table."""
match kind:
case UserListItemKind.MOVIE:
cls = Movie
case UserListItemKind.SERIES:
cls = Series
case UserListItemKind.EPISODE:
cls = Episode
media = await session.get(cls, tvdb_id)
if media is None:
media = cls(tvdb_id=tvdb_id, **kwargs)
session.add(media)
await session.commit()

if isinstance(media, Episode):
await session.refresh(media, ["series"])
if not media.series:
series = Series(tvdb_id=kwargs["series_id"])
session.add(series)
async with session:
match kind:
case UserListItemKind.MOVIE:
cls = Movie
case UserListItemKind.SERIES:
cls = Series
case UserListItemKind.EPISODE:
cls = Episode
media = await session.get(cls, tvdb_id)
if media is None:
media = cls(tvdb_id=tvdb_id, **kwargs)
session.add(media)
await session.commit()

if isinstance(media, Episode):
await session.refresh(media, ["series"])
if not media.series:
series = Series(tvdb_id=kwargs["series_id"])
session.add(series)
await session.commit()


async def series_get(session: AsyncSession, tvdb_id: int) -> Series | None:
"""Get a series from the database."""
async with session:
return await session.get(Series, tvdb_id)
9 changes: 9 additions & 0 deletions src/db_adapters/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sqlalchemy.ext.asyncio import AsyncSession


async def refresh[T](session: AsyncSession, item: T, fields: list[str]) -> T:
"""Refresh a media item with the specified fields."""
async with session:
item = await session.merge(item)
await session.refresh(item, fields)
return item
51 changes: 28 additions & 23 deletions src/db_adapters/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,33 @@

async def user_get(session: AsyncSession, discord_id: int) -> User | None:
"""Get a user by their Discord ID."""
return await session.get(User, discord_id)
async with session:
return await session.get(User, discord_id)


async def user_get_safe(session: AsyncSession, discord_id: int) -> User:
"""Get a user by their Discord ID, creating them if they don't exist."""
user = await user_get(session, discord_id)
if user is None:
user = User(discord_id=discord_id)
session.add(user)
await session.commit()
async with session:
user = await user_get(session, discord_id)
if user is None:
user = User(discord_id=discord_id)
session.add(user)
await session.commit()

return user


async def user_get_list(session: AsyncSession, user: User, name: str) -> UserList | None:
"""Get a user's list by name."""
# use where clause on user.id and name
user_list = await session.execute(
select(UserList)
.where(
UserList.user_id == user.discord_id,
async with session:
user_list = await session.execute(
select(UserList)
.where(
UserList.user_id == user.discord_id,
)
.where(UserList.name == name)
)
.where(UserList.name == name)
)
return user_list.scalars().first()


Expand All @@ -39,14 +42,15 @@ async def user_create_list(session: AsyncSession, user: User, name: str, item_ki

:raises ValueError: If a list with the same name already exists for the user.
"""
if await user_get_list(session, user, name) is not None:
raise ValueError(f"List with name {name} already exists for user {user.discord_id}.")
user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind)
session.add(user_list)
await session.commit()
await session.refresh(user, ["lists"])
async with session:
if await user_get_list(session, user, name) is not None:
raise ValueError(f"List with name {name} already exists for user {user.discord_id}.")
user_list = UserList(user_id=user.discord_id, name=name, item_kind=item_kind)
session.add(user_list)
await session.commit()
await session.refresh(user, ["lists"])

return user_list
return user_list


async def user_get_list_safe(
Expand All @@ -57,8 +61,9 @@ async def user_get_list_safe(
:param kind: The kind of list to create if it doesn't exist.
:return: The user list.
"""
user_list = await user_get_list(session, user, name)
if user_list is None:
user_list = await user_create_list(session, user, name, kind)
async with session:
user_list = await user_get_list(session, user, name)
if user_list is None:
user_list = await user_create_list(session, user, name, kind)

return user_list
return user_list
2 changes: 1 addition & 1 deletion src/exts/tvdb_info/ui/episode_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def set_watched(self, state: bool) -> None:
)
if item is None:
raise ValueError("Episode is not marked as watched, can't re-mark as unwatched.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(
Expand Down
8 changes: 4 additions & 4 deletions src/exts/tvdb_info/ui/movie_series_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def set_favorite(self, state: bool) -> None:
item = await get_list_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind)
if item is None:
raise ValueError("Media is not marked as favorite, can't re-mark as favorite.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(self.bot.db_session, self.favorite_list, self.media_data.id, self._db_item_kind)
Expand All @@ -85,7 +85,7 @@ async def set_watched(self, state: bool) -> None:
item = await get_list_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind)
if item is None:
raise ValueError("Media is not marked as watched, can't re-mark as unwatched.")
await list_remove_item(self.bot.db_session, self.watched_list, item)
self.watched_list = await list_remove_item(self.bot.db_session, self.watched_list, item)
else:
try:
await list_put_item(self.bot.db_session, self.watched_list, self.media_data.id, self._db_item_kind)
Expand Down Expand Up @@ -213,14 +213,14 @@ async def set_watched(self, state: bool) -> None:
if not episode.id:
raise ValueError("Episode has no ID")

await list_remove_item_safe(
self.watched_list = await list_remove_item_safe(
self.bot.db_session,
self.watched_list,
episode.id,
UserListItemKind.EPISODE,
)

await refresh_list_items(self.bot.db_session, self.watched_list)
self.watched_list = await refresh_list_items(self.bot.db_session, self.watched_list)
else:
for episode in self.media_data.episodes:
if not episode.id:
Expand Down
Loading
Loading