Skip to content

Commit b1018d4

Browse files
EstrellaXDclaudehappy-otter
committed
fix(test): use sync sessions in database tests to match production code
The database classes (BangumiDatabase, TorrentDatabase, RSSDatabase) use synchronous Session, but tests were incorrectly using AsyncSession with await calls, causing AttributeError on coroutine objects. Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude <noreply@anthropic.com> Co-Authored-By: Happy <yesreply@happy.engineering>
1 parent 2a3a534 commit b1018d4

1 file changed

Lines changed: 24 additions & 35 deletions

File tree

backend/src/test/test_database.py

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,24 @@
11
import pytest
2-
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
3-
from sqlalchemy.orm import sessionmaker
4-
from sqlmodel import SQLModel
2+
from sqlmodel import Session, SQLModel, create_engine
53

64
from module.database.bangumi import BangumiDatabase
75
from module.database.rss import RSSDatabase
86
from module.database.torrent import TorrentDatabase
97
from module.models import Bangumi, RSSItem, Torrent
108

11-
# sqlite async mock engine
12-
engine = create_async_engine(
13-
"sqlite+aiosqlite://",
14-
echo=False,
15-
)
16-
async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
9+
# sqlite sync engine for testing
10+
engine = create_engine("sqlite://", echo=False)
1711

1812

1913
@pytest.fixture
20-
async def db_session():
21-
async with engine.begin() as conn:
22-
await conn.run_sync(SQLModel.metadata.create_all)
23-
async with async_session_factory() as session:
14+
def db_session():
15+
SQLModel.metadata.create_all(engine)
16+
with Session(engine) as session:
2417
yield session
25-
async with engine.begin() as conn:
26-
await conn.run_sync(SQLModel.metadata.drop_all)
18+
SQLModel.metadata.drop_all(engine)
2719

2820

29-
@pytest.mark.asyncio
30-
async def test_bangumi_database(db_session):
21+
def test_bangumi_database(db_session):
3122
test_data = Bangumi(
3223
official_title="无职转生,到了异世界就拿出真本事",
3324
year="2021",
@@ -51,57 +42,55 @@ async def test_bangumi_database(db_session):
5142
db = BangumiDatabase(db_session)
5243

5344
# insert
54-
await db.add(test_data)
55-
result = await db.search_id(1)
45+
db.add(test_data)
46+
result = db.search_id(1)
5647
assert result.official_title == test_data.official_title
5748

5849
# update
5950
test_data.official_title = "无职转生,到了异世界就拿出真本事II"
60-
await db.update(test_data)
61-
result = await db.search_id(1)
51+
db.update(test_data)
52+
result = db.search_id(1)
6253
assert result.official_title == test_data.official_title
6354

6455
# search poster
65-
poster = await db.match_poster("无职转生,到了异世界就拿出真本事II (2021)")
56+
poster = db.match_poster("无职转生,到了异世界就拿出真本事II (2021)")
6657
assert poster == "/test/test.jpg"
6758

6859
# match torrent
69-
result = await db.match_torrent(
60+
result = db.match_torrent(
7061
"[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
7162
)
7263
assert result.official_title == "无职转生,到了异世界就拿出真本事II"
7364

7465
# delete
75-
await db.delete_one(1)
76-
result = await db.search_id(1)
66+
db.delete_one(1)
67+
result = db.search_id(1)
7768
assert result is None
7869

7970

80-
@pytest.mark.asyncio
81-
async def test_torrent_database(db_session):
71+
def test_torrent_database(db_session):
8272
test_data = Torrent(
8373
name="[Sub Group]test S02 01 [720p].mkv",
8474
url="https://test.com/test.mkv",
8575
)
8676
db = TorrentDatabase(db_session)
8777

8878
# insert
89-
await db.add(test_data)
90-
result = await db.search(1)
79+
db.add(test_data)
80+
result = db.search(1)
9181
assert result.name == test_data.name
9282

9383
# update
9484
test_data.downloaded = True
95-
await db.update(test_data)
96-
result = await db.search(1)
85+
db.update(test_data)
86+
result = db.search(1)
9787
assert result.downloaded == True
9888

9989

100-
@pytest.mark.asyncio
101-
async def test_rss_database(db_session):
90+
def test_rss_database(db_session):
10291
rss_url = "https://test.com/test.xml"
10392
db = RSSDatabase(db_session)
10493

105-
await db.add(RSSItem(url=rss_url, name="Test RSS"))
106-
result = await db.search_id(1)
94+
db.add(RSSItem(url=rss_url, name="Test RSS"))
95+
result = db.search_id(1)
10796
assert result.url == rss_url

0 commit comments

Comments
 (0)