1- from module .database .combine import Database
1+ import pytest
2+ from sqlalchemy .ext .asyncio import AsyncSession , create_async_engine
3+ from sqlalchemy .orm import sessionmaker
4+ from sqlmodel import SQLModel
5+
6+ from module .database .bangumi import BangumiDatabase
7+ from module .database .rss import RSSDatabase
8+ from module .database .torrent import TorrentDatabase
29from module .models import Bangumi , RSSItem , Torrent
3- from sqlmodel import SQLModel , create_engine
4- from sqlmodel .pool import StaticPool
510
6- # sqlite mock engine
7- engine = create_engine (
8- "sqlite://" , connect_args = {"check_same_thread" : False }, poolclass = StaticPool
11+ # sqlite async mock engine
12+ engine = create_async_engine (
13+ "sqlite+aiosqlite://" ,
14+ echo = False ,
915)
16+ async_session_factory = sessionmaker (engine , class_ = AsyncSession , expire_on_commit = False )
17+
18+
19+ @pytest .fixture
20+ async def db_session ():
21+ async with engine .begin () as conn :
22+ await conn .run_sync (SQLModel .metadata .create_all )
23+ async with async_session_factory () as session :
24+ yield session
25+ async with engine .begin () as conn :
26+ await conn .run_sync (SQLModel .metadata .drop_all )
1027
1128
12- def test_bangumi_database ():
29+ @pytest .mark .asyncio
30+ async def test_bangumi_database (db_session ):
1331 test_data = Bangumi (
1432 official_title = "无职转生,到了异世界就拿出真本事" ,
1533 year = "2021" ,
@@ -30,49 +48,60 @@ def test_bangumi_database():
3048 save_path = "downloads/无职转生,到了异世界就拿出真本事/Season 1" ,
3149 deleted = False ,
3250 )
33- with Database (engine ) as db :
34- db .create_table ()
35- # insert
36- db .bangumi .add (test_data )
37- assert db .bangumi .search_id (1 ) == test_data
51+ db = BangumiDatabase (db_session )
3852
39- # update
40- test_data . official_title = "无职转生,到了异世界就拿出真本事II"
41- db .bangumi . update ( test_data )
42- assert db . bangumi . search_id ( 1 ) == test_data
53+ # insert
54+ await db . add ( test_data )
55+ result = await db .search_id ( 1 )
56+ assert result . official_title == test_data . official_title
4357
44- # search poster
45- assert db .bangumi .match_poster ("无职转生,到了异世界就拿出真本事II (2021)" ) == "/test/test.jpg"
58+ # update
59+ test_data .official_title = "无职转生,到了异世界就拿出真本事II"
60+ await db .update (test_data )
61+ result = await db .search_id (1 )
62+ assert result .official_title == test_data .official_title
4663
47- # match torrent
48- result = db .bangumi .match_torrent (
49- "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
50- )
51- assert result .official_title == "无职转生,到了异世界就拿出真本事II"
64+ # search poster
65+ poster = await db .match_poster ("无职转生,到了异世界就拿出真本事II (2021)" )
66+ assert poster == "/test/test.jpg"
5267
53- # delete
54- db .bangumi .delete_one (1 )
55- assert db .bangumi .search_id (1 ) is None
68+ # match torrent
69+ result = await db .match_torrent (
70+ "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]"
71+ )
72+ assert result .official_title == "无职转生,到了异世界就拿出真本事II"
73+
74+ # delete
75+ await db .delete_one (1 )
76+ result = await db .search_id (1 )
77+ assert result is None
5678
5779
58- def test_torrent_database ():
80+ @pytest .mark .asyncio
81+ async def test_torrent_database (db_session ):
5982 test_data = Torrent (
6083 name = "[Sub Group]test S02 01 [720p].mkv" ,
6184 url = "https://test.com/test.mkv" ,
6285 )
63- with Database (engine ) as db :
64- # insert
65- db .torrent .add (test_data )
66- assert db .torrent .search (1 ) == test_data
86+ db = TorrentDatabase (db_session )
87+
88+ # insert
89+ await db .add (test_data )
90+ result = await db .search (1 )
91+ assert result .name == test_data .name
6792
68- # update
69- test_data .downloaded = True
70- db .torrent .update (test_data )
71- assert db .torrent .search (1 ) == test_data
93+ # update
94+ test_data .downloaded = True
95+ await db .update (test_data )
96+ result = await db .search (1 )
97+ assert result .downloaded == True
7298
7399
74- def test_rss_database ():
100+ @pytest .mark .asyncio
101+ async def test_rss_database (db_session ):
75102 rss_url = "https://test.com/test.xml"
103+ db = RSSDatabase (db_session )
76104
77- with Database (engine ) as db :
78- db .rss .add (RSSItem (url = rss_url ))
105+ await db .add (RSSItem (url = rss_url , name = "Test RSS" ))
106+ result = await db .search_id (1 )
107+ assert result .url == rss_url
0 commit comments