diff --git a/backend/requirements-dev.txt b/backend/requirements-dev.txt index a65cdcbcf..94e89927f 100644 --- a/backend/requirements-dev.txt +++ b/backend/requirements-dev.txt @@ -2,4 +2,5 @@ ruff black pre-commit -pytest \ No newline at end of file +pytest +pytest-asyncio \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt index 02ed21578..d95595fdc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,9 +7,7 @@ fastapi==0.97.0 h11==0.14.0 idna==3.4 pydantic~=1.10 -PySocks==1.7.1 -qbittorrent-api==2023.9.53 -requests==2.31.0 +httpx[http2,socks]==0.25.0 six==1.16.0 sniffio==1.3.0 soupsieve==2.4.1 diff --git a/backend/src/main.py b/backend/src/main.py index da13ef57b..83cba21f4 100644 --- a/backend/src/main.py +++ b/backend/src/main.py @@ -57,7 +57,9 @@ def html(request: Request, path: str): else: context = {"request": request} return templates.TemplateResponse("index.html", context) + else: + @app.get("/", status_code=302, tags=["html"]) def index(): return RedirectResponse("/docs") diff --git a/backend/src/module/api/bangumi.py b/backend/src/module/api/bangumi.py index 3ddd62b03..b5058788e 100644 --- a/backend/src/module/api/bangumi.py +++ b/backend/src/module/api/bangumi.py @@ -116,6 +116,7 @@ async def refresh_poster(): resp = manager.refresh_poster() return u_response(resp) + @router.get( path="/refresh/poster/{bangumi_id}", response_model=APIResponse, diff --git a/backend/src/module/api/rss.py b/backend/src/module/api/rss.py index d2de16fe3..23537a899 100644 --- a/backend/src/module/api/rss.py +++ b/backend/src/module/api/rss.py @@ -4,28 +4,30 @@ from module.downloader import DownloadClient from module.manager import SeasonCollector from module.models import APIResponse, Bangumi, RSSItem, RSSUpdate, Torrent -from module.rss import RSSAnalyser, RSSEngine +from module.rss import RSSAnalyser, RSSEngine, RSSManager from module.security.api import UNAUTHORIZED, get_current_user from .response import u_response router = APIRouter(prefix="/rss", tags=["rss"]) +engine = RSSEngine() +analyser = RSSAnalyser() @router.get( path="", response_model=list[RSSItem], dependencies=[Depends(get_current_user)] ) async def get_rss(): - with RSSEngine() as engine: - return engine.rss.search_all() + with RSSManager() as manager: + return manager.rss.search_all() @router.post( path="/add", response_model=APIResponse, dependencies=[Depends(get_current_user)] ) async def add_rss(rss: RSSItem): - with RSSEngine() as engine: - result = engine.add_rss(rss.url, rss.name, rss.aggregate, rss.parser) + with RSSManager() as manager: + result = await manager.add_rss(rss.url, rss.name, rss.aggregate, rss.parser) return u_response(result) @@ -37,8 +39,8 @@ async def add_rss(rss: RSSItem): async def enable_many_rss( rss_ids: list[int], ): - with RSSEngine() as engine: - result = engine.enable_list(rss_ids) + with RSSManager() as manager: + result = manager.enable_list(rss_ids) return u_response(result) @@ -48,8 +50,8 @@ async def enable_many_rss( dependencies=[Depends(get_current_user)], ) async def delete_rss(rss_id: int): - with RSSEngine() as engine: - if engine.rss.delete(rss_id): + with RSSManager() as manager: + if manager.rss.delete(rss_id): return JSONResponse( status_code=200, content={"msg_en": "Delete RSS successfully.", "msg_zh": "删除 RSS 成功。"}, @@ -69,8 +71,8 @@ async def delete_rss(rss_id: int): async def delete_many_rss( rss_ids: list[int], ): - with RSSEngine() as engine: - result = engine.delete_list(rss_ids) + with RSSManager() as manager: + result = manager.delete_list(rss_ids) return u_response(result) @@ -80,8 +82,8 @@ async def delete_many_rss( dependencies=[Depends(get_current_user)], ) async def disable_rss(rss_id: int): - with RSSEngine() as engine: - if engine.rss.disable(rss_id): + with RSSManager() as manager: + if manager.rss.disable(rss_id): return JSONResponse( status_code=200, content={"msg_en": "Disable RSS successfully.", "msg_zh": "禁用 RSS 成功。"}, @@ -99,8 +101,8 @@ async def disable_rss(rss_id: int): dependencies=[Depends(get_current_user)], ) async def disable_many_rss(rss_ids: list[int]): - with RSSEngine() as engine: - result = engine.disable_list(rss_ids) + with RSSManager() as manager: + result = manager.disable_list(rss_ids) return u_response(result) @@ -114,8 +116,8 @@ async def update_rss( ): if not current_user: raise UNAUTHORIZED - with RSSEngine() as engine: - if engine.rss.update(rss_id, data): + with RSSManager() as manager: + if manager.rss.update(rss_id, data): return JSONResponse( status_code=200, content={"msg_en": "Update RSS successfully.", "msg_zh": "更新 RSS 成功。"}, @@ -133,8 +135,8 @@ async def update_rss( dependencies=[Depends(get_current_user)], ) async def refresh_all(): - with RSSEngine() as engine, DownloadClient() as client: - engine.refresh_rss(client) + async with DownloadClient() as client: + await engine.refresh_rss(client) return JSONResponse( status_code=200, content={"msg_en": "Refresh all RSS successfully.", "msg_zh": "刷新 RSS 成功。"}, @@ -147,8 +149,8 @@ async def refresh_all(): dependencies=[Depends(get_current_user)], ) async def refresh_rss(rss_id: int): - with RSSEngine() as engine, DownloadClient() as client: - engine.refresh_rss(client, rss_id) + async with DownloadClient() as client: + await engine.refresh_rss(client=client, rss_id=rss_id) return JSONResponse( status_code=200, content={"msg_en": "Refresh RSS successfully.", "msg_zh": "刷新 RSS 成功。"}, @@ -163,12 +165,8 @@ async def refresh_rss(rss_id: int): async def get_torrent( rss_id: int, ): - with RSSEngine() as engine: - return engine.get_rss_torrents(rss_id) - - -# Old API -analyser = RSSAnalyser() + with RSSManager() as manager: + return manager.get_rss_torrents(rss_id) @router.post( diff --git a/backend/src/module/core/aiocore.py b/backend/src/module/core/aiocore.py new file mode 100644 index 000000000..49b4b32df --- /dev/null +++ b/backend/src/module/core/aiocore.py @@ -0,0 +1,64 @@ +import asyncio + +from module.downloader import DownloadClient +from module.manager import Renamer +from module.conf import settings +from module.rss import RSSEngine +from module.database import Database +from module.models import Bangumi, RSSItem, Torrent + + + +rss_item_pool = [] +torrent_pool: list[tuple[Bangumi, list[Torrent]]] = [] + + +class AsyncProgram: + def __init__(self): + self.renamer = Renamer() + self.engine = RSSEngine() + self.event = asyncio.Event() + + async def run(self): + self.event.clear() + task = [] + if settings.bangumi_manage.enable: + task.append(self.rename_task()) + if settings.rss_parser.enable: + task.append(self.rss_task()) + await asyncio.gather(*task) + + async def rename_task(self): + while not self.event.is_set(): + async with DownloadClient() as client: + await self.check_downloader(client) + await self.renamer.rename(client) + await asyncio.sleep(settings.program.rename_time) + + async def rss_task(self): + while not self.event.is_set(): + await self.engine.rss_poller(process_rss) + await asyncio.sleep(settings.program.rss_time) + + +async def rename_task(): + connected = False + renamer = Renamer() + async with DownloadClient() as client: + while not connected: + connected = await client.auth() + if not connected: + await asyncio.sleep(30) + for bangumi, torrents in torrent_pool: + client.add_torrent(torrents, bangumi) + renamer.rename(client) + await asyncio.sleep(settings.program.rename_time) + + +async def rss_task(): + # GET RSS FROM DATABASE + with Database() as db: + rss_items = db.rss.search_active() + for rss_item in rss_items: + rss_item_pool.append(rss_item) + pass diff --git a/backend/src/module/core/program.py b/backend/src/module/core/program.py index ee73c5f58..967e9158f 100644 --- a/backend/src/module/core/program.py +++ b/backend/src/module/core/program.py @@ -2,7 +2,13 @@ from module.conf import VERSION, settings from module.models import ResponseModel -from module.update import data_migration, first_run, from_30_to_31, start_up, cache_image +from module.update import ( + data_migration, + first_run, + from_30_to_31, + start_up, + cache_image, +) from .sub_thread import RenameThread, RSSThread diff --git a/backend/src/module/core/sub_thread.py b/backend/src/module/core/sub_thread.py index 4968a2c29..34107be02 100644 --- a/backend/src/module/core/sub_thread.py +++ b/backend/src/module/core/sub_thread.py @@ -1,5 +1,6 @@ import threading import time +import asyncio from module.conf import settings from module.downloader import DownloadClient @@ -13,38 +14,11 @@ class RSSThread(ProgramStatus): def __init__(self): super().__init__() - self._rss_thread = threading.Thread( - target=self.rss_loop, - ) + self._rss_loop = asyncio.new_event_loop() self.analyser = RSSAnalyser() - def rss_loop(self): - while not self.stop_event.is_set(): - with DownloadClient() as client, RSSEngine() as engine: - # Analyse RSS - rss_list = engine.rss.search_aggregate() - for rss in rss_list: - self.analyser.rss_to_data(rss, engine) - # Run RSS Engine - engine.refresh_rss(client) - if settings.bangumi_manage.eps_complete: - eps_complete() - self.stop_event.wait(settings.program.rss_time) - - def rss_start(self): - self.rss_thread.start() - - def rss_stop(self): - if self._rss_thread.is_alive(): - self._rss_thread.join() - - @property - def rss_thread(self): - if not self._rss_thread.is_alive(): - self._rss_thread = threading.Thread( - target=self.rss_loop, - ) - return self._rss_thread + async def rss_loop(self): + pass class RenameThread(ProgramStatus): diff --git a/backend/src/module/database/bangumi.py b/backend/src/module/database/bangumi.py index b484c6b0e..ce792d15e 100644 --- a/backend/src/module/database/bangumi.py +++ b/backend/src/module/database/bangumi.py @@ -129,10 +129,11 @@ def match_list(self, torrent_list: list, rss_link: str) -> list: i += 1 return torrent_list - def match_torrent(self, torrent_name: str) -> Optional[Bangumi]: + def match_torrent(self, torrent_name: str, rss_link: str) -> Optional[Bangumi]: statement = select(Bangumi).where( and_( func.instr(torrent_name, Bangumi.title_raw) > 0, + func.instr(Bangumi.rss_link, rss_link), # use `false()` to avoid E712 checking # see: https://docs.astral.sh/ruff/rules/true-false-comparison/ Bangumi.deleted == false(), diff --git a/backend/src/module/downloader/client/qb_downloader.py b/backend/src/module/downloader/client/qb_downloader.py index fe6805f5c..c8d86a8b5 100644 --- a/backend/src/module/downloader/client/qb_downloader.py +++ b/backend/src/module/downloader/client/qb_downloader.py @@ -1,151 +1,164 @@ import logging -import time +import httpx +import asyncio -from qbittorrentapi import Client, LoginFailed -from qbittorrentapi.exceptions import ( - APIConnectionError, - Conflict409Error, - Forbidden403Error, -) - -from module.ab_decorator import qb_connect_failed_wait +from ..exceptions import ConflictError, AuthorizationError logger = logging.getLogger(__name__) +QB_API_URL = { + "login": "/api/v2/auth/login", + "logout": "/api/v2/auth/logout", + "version": "/api/v2/app/version", + "setPreferences": "/api/v2/app/setPreferences", + "createCategory": "/api/v2/torrents/createCategory", + "info": "/api/v2/torrents/info", + "add": "/api/v2/torrents/add", + "delete": "/api/v2/torrents/delete", + "renameFile": "/api/v2/torrents/renameFile", + "setLocation": "/api/v2/torrents/setLocation", + "setCategory": "/api/v2/torrents/setCategory", + "addTags": "/api/v2/torrents/addTags", +} + class QbDownloader: def __init__(self, host: str, username: str, password: str, ssl: bool): - self._client: Client = Client( - host=host, - username=username, - password=password, - VERIFY_WEBUI_CERTIFICATE=ssl, - DISABLE_LOGGING_DEBUG_OUTPUT=True, - REQUESTS_ARGS={"timeout": (3.1, 10)}, - ) - self.host = host + self.host = host if "://" in host else "http://" + host self.username = username + self.password = password + self.ssl = ssl + + async def auth(self): + resp = await self._client.post( + url=QB_API_URL["login"], + data={"username": self.username, "password": self.password}, + timeout=5, + ) + return resp.text == "Ok." + + async def logout(self): + resp = await self._client.post(url=QB_API_URL["logout"], timeout=5) + return resp.text - def auth(self, retry=3): - times = 0 - while times < retry: - try: - self._client.auth_log_in() - return True - except LoginFailed: - logger.error( - f"Can't login qBittorrent Server {self.host} by {self.username}, retry in {5} seconds." - ) - time.sleep(5) - times += 1 - except Forbidden403Error: - logger.error("Login refused by qBittorrent Server") - logger.info("Please release the IP in qBittorrent Server") - break - except APIConnectionError: - logger.error("Cannot connect to qBittorrent Server") - logger.info("Please check the IP and port in WebUI settings") - time.sleep(10) - times += 1 - except Exception as e: - logger.error(f"Unknown error: {e}") - break - return False - - def logout(self): - self._client.auth_log_out() - - def check_host(self): + async def check_host(self): try: - self._client.app_version() + await self._client.get(url=QB_API_URL["version"], timeout=5) return True - except APIConnectionError: + except httpx.RequestError or httpx.TimeoutException: return False - def check_rss(self, rss_link: str): - pass + async def prefs_init(self, prefs): + await self._client.post(url=QB_API_URL["setPreferences"], data=prefs) - @qb_connect_failed_wait - def prefs_init(self, prefs): - return self._client.app_set_preferences(prefs=prefs) - - @qb_connect_failed_wait - def get_app_prefs(self): - return self._client.app_preferences() - - def add_category(self, category): - return self._client.torrents_createCategory(name=category) - - @qb_connect_failed_wait - def torrents_info(self, status_filter, category, tag=None): - return self._client.torrents_info( - status_filter=status_filter, category=category, tag=tag + async def add_category(self, category): + await self._client.post( + url=QB_API_URL["createCategory"], + data={"category": category}, + timeout=5, ) - def add_torrents(self, torrent_urls, torrent_files, save_path, category): - resp = self._client.torrents_add( - is_paused=False, - urls=torrent_urls, - torrent_files=torrent_files, - save_path=save_path, - category=category, - use_auto_torrent_management=False, + async def torrents_info(self, status_filter, category, tag=None): + data = { + "filter": status_filter, + "category": category, + "tag": tag, + } + torrent_info = await self._client.get( + url=QB_API_URL["info"], + params=data, ) - return resp == "Ok." - - def torrents_delete(self, hash): - return self._client.torrents_delete(delete_files=True, torrent_hashes=hash) + return torrent_info.json() + + async def add(self, torrent_urls, torrent_files, save_path, category): + data = { + "urls": torrent_urls, + "torrent_files": torrent_files, + "save_path": save_path, + "category": category, + "is_paused": False, + "use_auto_torrent_management": False, + } + resp = await self._client.post( + url=QB_API_URL["add"], + data=data, + ) + return resp.status_code == 200 + + async def delete(self, _hash): + data = { + "hashes": _hash, + "deleteFiles": True, + } + resp = await self._client.post( + url=QB_API_URL["delete"], + data=data, + ) + return resp.status_code == 200 + + async def rename(self, torrent_hash, old_path, new_path) -> bool: + data = { + "hash": torrent_hash, + "oldPath": old_path, + "newPath": new_path, + } + resp = await self._client.post( + url=QB_API_URL["renameFile"], + data=data, + ) + return resp.status_code == 200 + + async def move(self, hashes, new_location): + data = { + "hashes": hashes, + "location": new_location, + } + resp = await self._client.post( + url=QB_API_URL["setLocation"], + data=data, + ) + return resp.status_code == 200 + + async def set_category(self, _hash, category): + data = { + "category": category, + "hashes": _hash, + } + resp = await self._client.post( + url=QB_API_URL["setCategory"], + data=data, + ) + return resp.status_code == 200 + + async def add_tag(self, _hash, tag): + data = { + "hashes": _hash, + "tags": tag, + } + resp = await self._client.post( + url=QB_API_URL["addTags"], + data=data, + ) + return resp.status_code == 200 - def torrents_rename_file(self, torrent_hash, old_path, new_path) -> bool: - try: - self._client.torrents_rename_file( - torrent_hash=torrent_hash, old_path=old_path, new_path=new_path + async def __aenter__(self): + self._client = httpx.AsyncClient( + base_url=self.host, + trust_env=self.ssl, + ) + while not await self.check_host(): + logger.warning( + f"[Downloader] Failed to connect to {self.host}, retry in 30 seconds." ) - return True - except Conflict409Error: - logger.debug(f"Conflict409Error: {old_path} >> {new_path}") - return False - - def rss_add_feed(self, url, item_path): - try: - self._client.rss_add_feed(url, item_path) - except Conflict409Error: - logger.warning(f"[Downloader] RSS feed {url} already exists") - - def rss_remove_item(self, item_path): - try: - self._client.rss_remove_item(item_path) - except Conflict409Error: - logger.warning(f"[Downloader] RSS item {item_path} does not exist") - - def rss_get_feeds(self): - return self._client.rss_items() - - def rss_set_rule(self, rule_name, rule_def): - self._client.rss_set_rule(rule_name, rule_def) - - def move_torrent(self, hashes, new_location): - self._client.torrents_set_location(new_location, hashes) - - def get_download_rule(self): - return self._client.rss_rules() - - def get_torrent_path(self, _hash): - return self._client.torrents_info(hashes=_hash)[0].save_path - - def set_category(self, _hash, category): - try: - self._client.torrents_set_category(category, hashes=_hash) - except Conflict409Error: - logger.warning(f"[Downloader] Category {category} does not exist") - self.add_category(category) - self._client.torrents_set_category(category, hashes=_hash) - - def check_connection(self): - return self._client.app_version() - - def remove_rule(self, rule_name): - self._client.rss_remove_rule(rule_name) + await asyncio.sleep(30) + if not await self.auth(): + await self._client.aclose() + logger.error( + f"[Downloader] Downloader authorize error. Please check your username/password." + ) + raise AuthorizationError("Failed to login to qbittorrent.") + return self - def add_tag(self, _hash, tag): - self._client.torrents_add_tags(tags=tag, hashes=_hash) + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.logout() + await self._client.aclose() diff --git a/backend/src/module/downloader/client/tr_downloader.py b/backend/src/module/downloader/client/tr_downloader.py index e69de29bb..d927cf6d5 100644 --- a/backend/src/module/downloader/client/tr_downloader.py +++ b/backend/src/module/downloader/client/tr_downloader.py @@ -0,0 +1,197 @@ +import logging +import httpx +import base64 +import asyncio + +from ..exceptions import AuthorizationError + +logger = logging.getLogger(__name__) + + +class TrDownloader: + def __init__(self, host, username, password, ssl): + self.host = host if "://" in host else "http://" + host + self.username = username + self.password = password + self.ssl = ssl + self.authkey = base64.b64encode( + f"{self.username}:{self.password}".encode() + ).decode() + + self._client = httpx.AsyncClient( + base_url=self.host, + auth=(self.username, self.password), + timeout=5, + ) + + async def __aenter__(self): + self._client = httpx.AsyncClient( + base_url=self.host, + ) + + while not await self.check_host(): + logger.warning( + f"[Downloader] Failed to connect to {self.host}, retry in 30 seconds." + ) + await asyncio.sleep(30) + if not await self.auth(): + await self._client.aclose() + raise AuthorizationError("Failed to login to transmission.") + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.logout() + await self._client.aclose() + + async def auth(self): + # NOTE: Transmission will return 409 when first login + if self.username and self.password: + self._client.headers.update({"Authorization": f"Basic {self.authkey}"}) + + resp = await self._client.post("/transmission/rpc") + + if resp.status_code == 409 and "X-Transmission-Session-Id" in resp.headers: + self._client.headers.update( + {"X-Transmission-Session-Id": resp.headers["X-Transmission-Session-Id"]} + ) + resp = await self._client.post("/transmission/rpc") + elif resp.status_code == 401: + logger.error("Transmission: Authentication failed") + return False + + return resp.status_code == 200 + + def logout(self): + self._client.headers.pop("Authorization") + + async def check_host(self): + try: + await self._client.get("/transmission/web/") + return True + except httpx.RequestError: + return False + + async def add_torrent( + self, download_link=None, torrent_path=None, save_path=None, **kwargs + ): + if not download_link and not torrent_path: + # WARNING: Regard no torrent as success + return True + request_data = { + "method": "torrent-add", + "arguments": {"download-dir": save_path, "paused": False, **kwargs}, + } + + if torrent_path: + try: + with open(torrent_path, "rb") as file: + file_content = file.read() + metainfo = base64.b64encode(file_content).decode() + except FileNotFoundError: + logger.error(f"File not found: {torrent_path}") + return False + + request_data["arguments"].update({"metainfo": metainfo}) + else: + request_data["arguments"].update({"filename": download_link}) + + resp = await self._client.post("/transmission/rpc", json=request_data) + + return resp.status_code == 200 + + async def add(self, torrent_urls, torrent_files, save_path, category): + result = True + for torrent_url in torrent_urls: + result = result and await self.add_torrent( + download_link=torrent_url, save_path=save_path, labels=[category] + ) + + for torrent_file in torrent_files: + result = result and await self.add_torrent( + torrent_path=torrent_file, save_path=save_path + ) + + return result + + async def delete(self, _hash): + request_data = { + "method": "torrent-remove", + "arguments": {"ids": [_hash], "delete-local-data": True}, + } + resp = await self._client.post("/transmission/rpc", json=request_data) + return resp.status_code == 200 + + async def move(self, hashes, new_location): + request_data = { + "method": "torrent-set-location", + "arguments": {"ids": hashes, "location": new_location}, + } + resp = await self._client.post("/transmission/rpc", json=request_data) + return resp.status_code == 200 + + async def rename(self, torrent_hash, old_path, new_path) -> bool: + request_data = { + "method": "torrent-rename-path", + "arguments": {"ids": [torrent_hash], "path": old_path, "name": new_path}, + } + resp = await self._client.post("/transmission/rpc", json=request_data) + return resp.status_code == 200 + + async def torrents_info(self, status_filter, category, tag=None): + KEY_MAP = {"hashString": "hash", "downloadDir": "save_path"} + # Map transmission key to qbittorrent + + request_data = { + "method": "torrent-get", + "arguments": { + "fields": [ + "id", + "name", + "hashString", + "downloadDir", + "status", + "labels", + ], + }, + "format": "object", + } + resp = await self._client.post("/transmission/rpc", json=request_data) + data = resp.json() + torrents_info = data["arguments"].get("torrents") + for torrent_info in torrents_info: + for old_key, new_key in KEY_MAP.items(): + torrent_info[new_key] = torrent_info.pop(old_key) + + torrents_info = self._filter_status(torrents_info, status_filter) + if category: + torrents_info = [ + torrent for torrent in torrents_info if category in torrent["labels"] + ] + # NOTE: To compatible with qbittorrent api we use category as label + + return torrents_info + + async def set_category(self, torrent_hashes, category): + request_data = { + "method": "torrent-set", + "arguments": {"ids": torrent_hashes, "labels": [category]}, + } + + # NOTE: To compatible with qbittorrent api we use category as label + resp = await self._client.post("/transmission/rpc", json=request_data) + return resp.status_code == 200 + + def _filter_status(self, torrents_info, status_filter: str): + """ + Filter torrents by status + Docs: https://github.com/transmission/transmission/blob/main/docs/rpc-spec.md#33-torrent-accessor-torrent-get + """ + if status_filter == "completed": + # We regard torrents queue to seed as completed + return [torrent for torrent in torrents_info if torrent["status"] >= 5] + elif status_filter == "downloading": + return [torrent for torrent in torrents_info if torrent["status"] == 4] + elif status_filter == "inactive": + return [torrent for torrent in torrents_info if torrent["status"] <= 3] + + return torrents_info diff --git a/backend/src/module/downloader/download_client.py b/backend/src/module/downloader/download_client.py index d01d4fa3c..b9b66593a 100644 --- a/backend/src/module/downloader/download_client.py +++ b/backend/src/module/downloader/download_client.py @@ -1,4 +1,5 @@ import logging +import asyncio from module.conf import settings from module.models import Bangumi, Torrent @@ -9,166 +10,76 @@ logger = logging.getLogger(__name__) -class DownloadClient(TorrentPath): - def __init__(self): - super().__init__() - self.client = self.__getClient() - self.authed = False +def getClient(): + # TODO 多下载器支持 + if settings.downloader.type == "qbittorrent": + from .client.qb_downloader import QbDownloader - @staticmethod - def __getClient(): - # TODO 多下载器支持 - type = settings.downloader.type - host = settings.downloader.host - username = settings.downloader.username - password = settings.downloader.password - ssl = settings.downloader.ssl - if type == "qbittorrent": - from .client.qb_downloader import QbDownloader + return QbDownloader + elif type == "transmission": + from .client.tr_downloader import TrDownloader - return QbDownloader(host, username, password, ssl) - else: - logger.error(f"[Downloader] Unsupported downloader type: {type}") - raise Exception(f"Unsupported downloader type: {type}") + return TrDownloader + else: + logger.error(f"[Downloader] Unsupported downloader type: {type}") + raise Exception(f"Unsupported downloader type: {type}") - def __enter__(self): - if not self.authed: - self.auth() - else: - logger.error("[Downloader] Already authed.") - return self - def __exit__(self, exc_type, exc_val, exc_tb): - if self.authed: - self.client.logout() - self.authed = False - - def auth(self): - self.authed = self.client.auth() - if self.authed: - logger.debug("[Downloader] Authed.") - else: - logger.error("[Downloader] Auth failed.") - - def check_host(self): - return self.client.check_host() - - def init_downloader(self): - prefs = { - "rss_auto_downloading_enabled": True, - "rss_max_articles_per_feed": 500, - "rss_processing_enabled": True, - "rss_refresh_interval": 30, - } - self.client.prefs_init(prefs=prefs) - try: - self.client.add_category("BangumiCollection") - except Exception: - logger.debug("[Downloader] Cannot add new category, maybe already exists.") - if settings.downloader.path == "": - prefs = self.client.get_app_prefs() - settings.downloader.path = self._join_path(prefs["save_path"], "Bangumi") - - def set_rule(self, data: Bangumi): - data.rule_name = self._rule_name(data) - data.save_path = self._gen_save_path(data) - rule = { - "enable": True, - "mustContain": data.title_raw, - "mustNotContain": "|".join(data.filter), - "useRegex": True, - "episodeFilter": "", - "smartFilter": False, - "previouslyMatchedEpisodes": [], - "affectedFeeds": data.rss_link, - "ignoreDays": 0, - "lastMatch": "", - "addPaused": False, - "assignedCategory": "Bangumi", - "savePath": data.save_path, - } - self.client.rss_set_rule(rule_name=data.rule_name, rule_def=rule) - data.added = True - logger.info( - f"[Downloader] Add {data.official_title} Season {data.season} to auto download rules." +class DownloadClient(getClient(), TorrentPath): + def __init__(self): + super().__init__( + host=settings.downloader.host, + username=settings.downloader.username, + password=settings.downloader.password, + ssl=settings.downloader.ssl, ) - def set_rules(self, bangumi_info: list[Bangumi]): - logger.debug("[Downloader] Start adding rules.") - for info in bangumi_info: - self.set_rule(info) - logger.debug("[Downloader] Finished.") - - def get_torrent_info(self, category="Bangumi", status_filter="completed", tag=None): - return self.client.torrents_info( + async def get_torrent_info( + self, category="Bangumi", status_filter="completed", tag=None + ): + return await self.torrents_info( status_filter=status_filter, category=category, tag=tag ) - def rename_torrent_file(self, _hash, old_path, new_path) -> bool: + async def rename_torrent_file(self, _hash, old_path, new_path) -> bool: logger.info(f"{old_path} >> {new_path}") - return self.client.torrents_rename_file( + return await self.rename( torrent_hash=_hash, old_path=old_path, new_path=new_path ) - def delete_torrent(self, hashes): - self.client.torrents_delete(hashes) + async def delete_torrent(self, hashes): + resp = await self.delete(hashes) logger.info("[Downloader] Remove torrents.") + return resp - def add_torrent(self, torrent: Torrent | list, bangumi: Bangumi) -> bool: + async def add_torrents(self, torrents: list[Torrent], bangumi: Bangumi) -> bool: if not bangumi.save_path: bangumi.save_path = self._gen_save_path(bangumi) - with RequestContent() as req: - if isinstance(torrent, list): - if len(torrent) == 0: - logger.debug(f"[Downloader] No torrent found: {bangumi.official_title}") - return False - if "magnet" in torrent[0].url: - torrent_url = [t.url for t in torrent] - torrent_file = None - else: - torrent_file = [req.get_content(t.url) for t in torrent] - torrent_url = None + async with RequestContent() as req: + if "magnet" in torrents[0].url: + torrent_url = [t.url for t in torrents] + torrent_file = None else: - if "magnet" in torrent.url: - torrent_url = torrent.url - torrent_file = None - else: - torrent_file = req.get_content(torrent.url) - torrent_url = None - if self.client.add_torrents( + tasks = [] + for t in torrents: + tasks.append(req.get_content(t.url)) + torrent_file = asyncio.gather(*tasks) + torrent_url = None + result = await self.add( torrent_urls=torrent_url, torrent_files=torrent_file, save_path=bangumi.save_path, category="Bangumi", - ): + ) + if result: logger.debug(f"[Downloader] Add torrent: {bangumi.official_title}") return True else: logger.debug(f"[Downloader] Torrent added before: {bangumi.official_title}") return False - def move_torrent(self, hashes, location): - self.client.move_torrent(hashes=hashes, new_location=location) - - # RSS Parts - def add_rss_feed(self, rss_link, item_path="Mikan_RSS"): - self.client.rss_add_feed(url=rss_link, item_path=item_path) - - def remove_rss_feed(self, item_path): - self.client.rss_remove_item(item_path=item_path) - - def get_rss_feed(self): - return self.client.rss_get_feeds() - - def get_download_rules(self): - return self.client.get_download_rule() - - def get_torrent_path(self, hashes): - return self.client.get_torrent_path(hashes) - - def set_category(self, hashes, category): - self.client.set_category(hashes, category) + async def move_torrent(self, hashes, location): + await self.move(hashes=hashes, new_location=location) - def remove_rule(self, rule_name): - self.client.remove_rule(rule_name) - logger.info(f"[Downloader] Delete rule: {rule_name}") + async def set_category(self, hashes, category): + await self.set_category(hashes, category) diff --git a/backend/src/module/downloader/exceptions.py b/backend/src/module/downloader/exceptions.py index 7ec28c73a..c142bdbb4 100644 --- a/backend/src/module/downloader/exceptions.py +++ b/backend/src/module/downloader/exceptions.py @@ -1,2 +1,6 @@ class ConflictError(Exception): pass + + +class AuthorizationError(Exception): + pass diff --git a/backend/src/module/manager/renamer.py b/backend/src/module/manager/renamer.py index 3691d493c..44e530b7a 100644 --- a/backend/src/module/manager/renamer.py +++ b/backend/src/module/manager/renamer.py @@ -5,15 +5,16 @@ from module.downloader import DownloadClient from module.models import EpisodeFile, Notification, SubtitleFile from module.parser import TitleParser +from module.downloader.path import TorrentPath logger = logging.getLogger(__name__) -class Renamer(DownloadClient): +class Renamer(TorrentPath): def __init__(self): super().__init__() self._parser = TitleParser() - self.check_pool = {} + self._check_pool = {} @staticmethod def print_result(torrent_count, rename_count): @@ -25,7 +26,7 @@ def print_result(torrent_count, rename_count): @staticmethod def gen_path( - file_info: EpisodeFile | SubtitleFile, bangumi_name: str, method: str + file_info: EpisodeFile | SubtitleFile, bangumi_name: str, method: str ) -> str: season = f"0{file_info.season}" if file_info.season < 10 else file_info.season episode = ( @@ -48,15 +49,16 @@ def gen_path( logger.error(f"[Renamer] Unknown rename method: {method}") return file_info.media_path - def rename_file( - self, - torrent_name: str, - media_path: str, - bangumi_name: str, - method: str, - season: int, - _hash: str, - **kwargs, + async def rename_file( + self, + torrent_name: str, + media_path: str, + bangumi_name: str, + method: str, + season: int, + _hash: str, + client: DownloadClient, + **kwargs, ): ep = self._parser.torrent_parser( torrent_name=torrent_name, @@ -65,30 +67,33 @@ def rename_file( ) if ep: new_path = self.gen_path(ep, bangumi_name, method=method) - if media_path != new_path: - if new_path not in self.check_pool.keys(): - if self.rename_torrent_file( - _hash=_hash, old_path=media_path, new_path=new_path - ): - return Notification( - official_title=bangumi_name, - season=ep.season, - episode=ep.episode, - ) + success = await self._rename_file_internal( + original_path=media_path, + new_path=new_path, + _hash=_hash, + client=client, + ) + if success: + return Notification( + official_title=bangumi_name, + season=ep.season, + episode=ep.episode, + ) else: logger.warning(f"[Renamer] {media_path} parse failed") if settings.bangumi_manage.remove_bad_torrent: - self.delete_torrent(hashes=_hash) + await client.delete_torrent(hashes=_hash) return None - def rename_collection( - self, - media_list: list[str], - bangumi_name: str, - season: int, - method: str, - _hash: str, - **kwargs, + async def rename_collection( + self, + media_list: list[str], + bangumi_name: str, + season: int, + method: str, + _hash: str, + client: DownloadClient, + **kwargs, ): for media_path in media_list: if self.is_ep(media_path): @@ -98,26 +103,25 @@ def rename_collection( ) if ep: new_path = self.gen_path(ep, bangumi_name, method=method) - if media_path != new_path: - renamed = self.rename_torrent_file( - _hash=_hash, old_path=media_path, new_path=new_path - ) - if not renamed: - logger.warning(f"[Renamer] {media_path} rename failed") - # Delete bad torrent. - if settings.bangumi_manage.remove_bad_torrent: - self.delete_torrent(_hash) - break + success = await self._rename_file_internal( + original_path=media_path, + new_path=new_path, + _hash=_hash, + client=client, + ) + if not success: + break - def rename_subtitles( - self, - subtitle_list: list[str], - torrent_name: str, - bangumi_name: str, - season: int, - method: str, - _hash, - **kwargs, + async def rename_subtitles( + self, + subtitle_list: list[str], + torrent_name: str, + bangumi_name: str, + season: int, + method: str, + _hash, + client: DownloadClient, + **kwargs, ): method = "subtitle_" + method for subtitle_path in subtitle_list: @@ -129,54 +133,75 @@ def rename_subtitles( ) if sub: new_path = self.gen_path(sub, bangumi_name, method=method) - if subtitle_path != new_path: - renamed = self.rename_torrent_file( - _hash=_hash, old_path=subtitle_path, new_path=new_path - ) - if not renamed: - logger.warning(f"[Renamer] {subtitle_path} rename failed") + success = await self._rename_file_internal( + original_path=subtitle_path, + new_path=new_path, + _hash=_hash, + client=client, + ) + if not success: + break - def rename(self) -> list[Notification]: + async def rename(self, client: DownloadClient) -> list[Notification]: # Get torrent info logger.debug("[Renamer] Start rename process.") rename_method = settings.bangumi_manage.rename_method - torrents_info = self.get_torrent_info() + torrents_info = await client.get_torrent_info() renamed_info: list[Notification] = [] for info in torrents_info: - media_list, subtitle_list = self.check_files(info) - bangumi_name, season = self._path_to_bangumi(info.save_path) + media_list, subtitle_list = await client.check_files(info) + bangumi_name, season = await client._path_to_bangumi(info.save_path) kwargs = { "torrent_name": info.name, "bangumi_name": bangumi_name, "method": rename_method, "season": season, "_hash": info.hash, + "client": client, } # Rename single media file if len(media_list) == 1: - notify_info = self.rename_file(media_path=media_list[0], **kwargs) + notify_info = await self.rename_file(media_path=media_list[0], **kwargs) if notify_info: renamed_info.append(notify_info) # Rename subtitle file if len(subtitle_list) > 0: - self.rename_subtitles(subtitle_list=subtitle_list, **kwargs) + await self.rename_subtitles(subtitle_list=subtitle_list, **kwargs) # Rename collection elif len(media_list) > 1: logger.info("[Renamer] Start rename collection") - self.rename_collection(media_list=media_list, **kwargs) + await self.rename_collection(media_list=media_list, **kwargs) if len(subtitle_list) > 0: - self.rename_subtitles(subtitle_list=subtitle_list, **kwargs) - self.set_category(info.hash, "BangumiCollection") + await self.rename_subtitles(subtitle_list=subtitle_list, **kwargs) + await client.set_category(info.hash, "BangumiCollection") else: logger.warning(f"[Renamer] {info.name} has no media file") logger.debug("[Renamer] Rename process finished.") return renamed_info - def compare_ep_version(self, torrent_name: str, torrent_hash: str): + async def compare_ep_version(self, torrent_name: str, torrent_hash: str, client: DownloadClient): if re.search(r"v\d.", torrent_name): pass else: - self.delete_torrent(hashes=torrent_hash) + await client.delete_torrent(hashes=torrent_hash) + + @staticmethod + async def _rename_file_internal( + original_path: str, + new_path: str, + _hash: str, + client: DownloadClient, + ) -> bool: + if original_path != new_path: + renamed = await client.rename_torrent_file( + _hash=_hash, old_path=original_path, new_path=new_path + ) + if not renamed: + logger.warning(f"[Renamer] {original_path} rename failed") + if settings.bangumi_manage.remove_bad_torrent: + await client.delete_torrent(_hash) + return False + return True if __name__ == "__main__": diff --git a/backend/src/module/models/config.py b/backend/src/module/models/config.py index 49044fa3d..fa78cd236 100644 --- a/backend/src/module/models/config.py +++ b/backend/src/module/models/config.py @@ -20,6 +20,17 @@ class Downloader(BaseModel): path: str = Field("/downloads/Bangumi", description="Downloader path") ssl: bool = Field(False, description="Downloader ssl") + +class QbDownloader(Downloader): + type: str = Field("qbittorrent", description="Downloader type") + host_: str = Field("172.17.0.1:8080", alias="host", description="Downloader host") + username_: str = Field("admin", alias="username", description="Downloader username") + password_: str = Field( + "adminadmin", alias="password", description="Downloader password" + ) + path: str = Field("/downloads/Bangumi", description="Downloader path") + ssl: bool = Field(False, description="Downloader ssl") + @property def host(self): return expandvars(self.host_) @@ -33,6 +44,15 @@ def password(self): return expandvars(self.password_) +class TrDownloader(Downloader): + type: str = Field("transmission", description="Downloader type") + host_: str = Field("172.17.0.1:9091", alias="host", description="Downloader host") + username_: str = Field("admin", alias="username", description="Downloader username") + password_: str = Field("admin", alias="password", description="Downloader password") + path: str = Field("/downloads/Bangumi", description="Downloader path") + ssl: bool = Field(False, description="Downloader ssl") + + class RSSParser(BaseModel): enable: bool = Field(True, description="Enable RSS parser") filter: list[str] = Field(["720", r"\d+-\d"], description="Filter") diff --git a/backend/src/module/models/response.py b/backend/src/module/models/response.py index 9bd352724..81b7938b2 100644 --- a/backend/src/module/models/response.py +++ b/backend/src/module/models/response.py @@ -11,4 +11,4 @@ class ResponseModel(BaseModel): class APIResponse(BaseModel): status: bool = Field(..., example=True) msg_en: str = Field(..., example="Success") - msg_zh: str = Field(..., example="成功") \ No newline at end of file + msg_zh: str = Field(..., example="成功") diff --git a/backend/src/module/network/proxy.py b/backend/src/module/network/proxy.py new file mode 100644 index 000000000..6ac163e45 --- /dev/null +++ b/backend/src/module/network/proxy.py @@ -0,0 +1,20 @@ +from module.conf import settings + + +@property +def set_proxy(): + auth = ( + f"{settings.proxy.username}:{settings.proxy.password}@" + if settings.proxy.username + else "" + ) + if "http" in settings.proxy.type: + proxy = ( + f"{settings.proxy.type}://{auth}{settings.proxy.host}:{settings.proxy.port}" + ) + elif settings.proxy.type == "socks5": + proxy = f"socks5://{auth}{settings.proxy.host}:{settings.proxy.port}" + else: + proxy = None + logger.error(f"[Network] Unsupported proxy type: {settings.proxy.type}") + return proxy diff --git a/backend/src/module/network/request_contents.py b/backend/src/module/network/request_contents.py index 05abca023..13909e6f4 100644 --- a/backend/src/module/network/request_contents.py +++ b/backend/src/module/network/request_contents.py @@ -12,16 +12,17 @@ class RequestContent(RequestURL): - def get_torrents( + async def get_torrents( self, _url: str, - _filter: str = "|".join(settings.rss_parser.filter), + _filter: str = None, limit: int = None, retry: int = 3, ) -> list[Torrent]: - soup = self.get_xml(_url, retry) - if soup: - torrent_titles, torrent_urls, torrent_homepage = rss_parser(soup) + feeds = await self.get_xml(_url, retry) + _filter = _filter if _filter else "|".join(settings.rss_parser.filter) + if feeds: + torrent_titles, torrent_urls, torrent_homepage = rss_parser(feeds) torrents: list[Torrent] = [] for _title, torrent_url, homepage in zip( torrent_titles, torrent_urls, torrent_homepage @@ -30,46 +31,39 @@ def get_torrents( torrents.append( Torrent(name=_title, url=torrent_url, homepage=homepage) ) - if isinstance(limit, int): - if len(torrents) >= limit: - break - return torrents + return torrents if limit is None else torrents[:limit] else: - logger.warning(f"[Network] Failed to get torrents: {_url}") + logger.error(f"[Network] Torrents list is empty: {_url}") return [] - def get_xml(self, _url, retry: int = 3) -> xml.etree.ElementTree.Element: - req = self.get_url(_url, retry) + async def get_xml(self, _url, retry: int = 3) -> xml.etree.ElementTree.Element: + req = await self.get_url(_url, retry) if req: return xml.etree.ElementTree.fromstring(req.text) # API JSON - def get_json(self, _url) -> dict: - req = self.get_url(_url) + async def get_json(self, _url) -> dict: + req = await self.get_url(_url) if req: return req.json() - def post_json(self, _url, data: dict) -> dict: - return self.post_url(_url, data).json() - - def post_data(self, _url, data: dict) -> dict: - return self.post_url(_url, data) + async def post_data(self, _url, data: dict, files: dict[str, bytes]) -> dict: + return await self.post_url(_url, data, files) - def post_files(self, _url, data: dict, files: dict) -> dict: - return self.post_form(_url, data, files) - - def get_html(self, _url): - return self.get_url(_url).text + async def get_html(self, _url): + req = await self.get_url(_url) + if req: + return req.text - def get_content(self, _url): - req = self.get_url(_url) + async def get_content(self, _url): + req = await self.get_url(_url) if req: return req.content - def check_connection(self, _url): - return self.check_url(_url) + async def check_connection(self, _url): + return await self.check_url(_url) - def get_rss_title(self, _url): - soup = self.get_xml(_url) + async def get_rss_title(self, _url): + soup = await self.get_xml(_url) if soup: return soup.find("./channel/title").text diff --git a/backend/src/module/network/request_url.py b/backend/src/module/network/request_url.py index 0b85e77e0..238be73a0 100644 --- a/backend/src/module/network/request_url.py +++ b/backend/src/module/network/request_url.py @@ -1,9 +1,9 @@ +import asyncio import logging -import socket -import time -import requests -import socks +import httpx + +from .proxy import set_proxy from module.conf import settings @@ -13,112 +13,60 @@ class RequestURL: def __init__(self): self.header = {"user-agent": "Mozilla/5.0", "Accept": "application/xml"} - self._socks5_proxy = False + self.proxy = set_proxy if settings.proxy.enable else None - def get_url(self, url, retry=3): - try_time = 0 - while True: + async def get_url(self, url, retry=3): + for _ in range(retry): try: - req = self.session.get(url=url, headers=self.header, timeout=5) - logger.debug(f"[Network] Successfully connected to {url}. Status: {req.status_code}") - req.raise_for_status() + req = await self.client.get(url=url) return req - except requests.RequestException: + except httpx.RequestError: + logger.debug(f"[Network] Cannot connect to {url}. Wait for 5 seconds.") + except httpx.TimeoutException: logger.debug( - f"[Network] Cannot connect to {url}. Wait for 5 seconds." + f"[Network] Timeout. Cannot connect to {url}. Wait for 5 seconds." ) - try_time += 1 - if try_time >= retry: - break - time.sleep(5) except Exception as e: logger.debug(e) + logger.error(f"[Network] Cannot connect to {url}") break - logger.error(f"[Network] Unable to connect to {url}, Please check your network settings") - return None + await asyncio.sleep(5) - def post_url(self, url: str, data: dict, retry=3): - try_time = 0 - while True: + async def post_url( + self, url: str, data: dict, files: dict[str, bytes] = None, retry: int = 3 + ): + for _ in range(retry): try: - req = self.session.post( - url=url, headers=self.header, data=data, timeout=5 - ) - req.raise_for_status() + req = await self.client.post(url=url, data=data, files=files) return req - except requests.RequestException: - logger.warning( - f"[Network] Cannot connect to {url}. Wait for 5 seconds." + except httpx.RequestError: + logger.debug(f"[Network] Cannot connect to {url}. Wait for 5 seconds.") + except httpx.TimeoutException: + logger.debug( + f"[Network] Timeout. Cannot connect to {url}. Wait for 5 seconds." ) - try_time += 1 - if try_time >= retry: - break - time.sleep(5) except Exception as e: logger.debug(e) + logger.error(f"[Network] Cannot connect to {url}") break - logger.error(f"[Network] Failed connecting to {url}") - logger.warning("[Network] Please check DNS/Connection settings") - return None + await asyncio.sleep(5) - def check_url(self, url: str): + async def check_url(self, url: str): if "://" not in url: url = f"http://{url}" try: - req = requests.head(url=url, headers=self.header, timeout=5) + req = await self.client.get(url=url) req.raise_for_status() return True - except requests.RequestException: + except httpx.RequestError: logger.debug(f"[Network] Cannot connect to {url}.") return False - def post_form(self, url: str, data: dict, files): - try: - req = self.session.post( - url=url, headers=self.header, data=data, files=files, timeout=5 - ) - req.raise_for_status() - return req - except requests.RequestException: - logger.warning(f"[Network] Cannot connect to {url}.") - return None - - def __enter__(self): - self.session = requests.Session() - if settings.proxy.enable: - if "http" in settings.proxy.type: - if settings.proxy.username: - username=settings.proxy.username - password=settings.proxy.password - url = f"http://{username}:{password}@{settings.proxy.host}:{settings.proxy.port}" - self.session.proxies = { - "http": url, - "https": url, - } - else: - url = f"http://{settings.proxy.host}:{settings.proxy.port}" - self.session.proxies = { - "http": url, - "https": url, - } - elif settings.proxy.type == "socks5": - self._socks5_proxy = True - socks.set_default_proxy( - socks.SOCKS5, - addr=settings.proxy.host, - port=settings.proxy.port, - rdns=True, - username=settings.proxy.username, - password=settings.proxy.password, - ) - socket.socket = socks.socksocket - else: - logger.error(f"[Network] Unsupported proxy type: {settings.proxy.type}") + async def __aenter__(self): + self.client = httpx.AsyncClient( + http2=True, proxies=self.proxy, headers=self.header, timeout=5 + ) return self - def __exit__(self, exc_type, exc_val, exc_tb): - if self._socks5_proxy: - socks.set_default_proxy() - socket.socket = socks.socksocket - self._socks5_proxy = False - self.session.close() + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.client.aclose() diff --git a/backend/src/module/notification/plugin/bark.py b/backend/src/module/notification/plugin/bark.py index 0574db1e6..e758ef3a7 100644 --- a/backend/src/module/notification/plugin/bark.py +++ b/backend/src/module/notification/plugin/bark.py @@ -21,7 +21,12 @@ def gen_message(notify: Notification) -> str: def post_msg(self, notify: Notification) -> bool: text = self.gen_message(notify) - data = {"title": notify.official_title, "body": text, "icon": notify.poster_path, "device_key": self.token} + data = { + "title": notify.official_title, + "body": text, + "icon": notify.poster_path, + "device_key": self.token, + } resp = self.post_data(self.notification_url, data) logger.debug(f"Bark notification: {resp.status_code}") return resp.status_code == 200 diff --git a/backend/src/module/parser/analyser/tmdb_parser.py b/backend/src/module/parser/analyser/tmdb_parser.py index 3b930dbb0..534a7e2d8 100644 --- a/backend/src/module/parser/analyser/tmdb_parser.py +++ b/backend/src/module/parser/analyser/tmdb_parser.py @@ -31,13 +31,12 @@ def info_url(e, key): return f"{TMDB_URL}/3/tv/{e}?api_key={TMDB_API}&language={LANGUAGE[key]}" -def is_animation(tv_id, language) -> bool: +async def is_animation(tv_id, language, req) -> bool: url_info = info_url(tv_id, language) - with RequestContent() as req: - type_id = req.get_json(url_info)["genres"] - for type in type_id: - if type.get("id") == 16: - return True + type_ids = await req.get_json(url_info) + for type in type_ids["genres"]: + if type.get("id") == 16: + return True return False @@ -56,10 +55,11 @@ def get_season(seasons: list) -> tuple[int, str]: return len(ss), ss[-1].get("poster_path") -def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None: - with RequestContent() as req: +async def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None: + async with RequestContent() as req: url = search_url(title) - contents = req.get_json(url).get("results") + json_contents = await req.get_json(url) + contents = json_contents.get("results") if contents.__len__() == 0: url = search_url(title.replace(" ", "")) contents = req.get_json(url).get("results") @@ -67,10 +67,10 @@ def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None: if contents: for content in contents: id = content["id"] - if is_animation(id, language): + if await is_animation(id, language, req): break url_info = info_url(id, language) - info_content = req.get_json(url_info) + info_content = await req.get_json(url_info) season = [ { "season": s.get("name"), @@ -87,7 +87,9 @@ def tmdb_parser(title, language, test: bool = False) -> TMDBInfo | None: year_number = info_content.get("first_air_date").split("-")[0] if poster_path: if not test: - img = req.get_content(f"https://image.tmdb.org/t/p/w780{poster_path}") + img = await req.get_content( + f"https://image.tmdb.org/t/p/w780{poster_path}" + ) poster_link = save_image(img, "jpg") else: poster_link = "https://image.tmdb.org/t/p/w780" + poster_path diff --git a/backend/src/module/rss/__init__.py b/backend/src/module/rss/__init__.py index 70406ee39..01b68d04c 100644 --- a/backend/src/module/rss/__init__.py +++ b/backend/src/module/rss/__init__.py @@ -1,2 +1,3 @@ from .analyser import RSSAnalyser from .engine import RSSEngine +from .manager import RSSManager diff --git a/backend/src/module/rss/analyser.py b/backend/src/module/rss/analyser.py index 457098e69..efc984013 100644 --- a/backend/src/module/rss/analyser.py +++ b/backend/src/module/rss/analyser.py @@ -99,4 +99,3 @@ def link_to_data(self, rss: RSSItem) -> Bangumi | ResponseModel: msg_en="Cannot parse this link.", msg_zh="无法解析此链接。", ) - diff --git a/backend/src/module/rss/engine.py b/backend/src/module/rss/engine.py index 244a6ab55..e2c1aba9e 100644 --- a/backend/src/module/rss/engine.py +++ b/backend/src/module/rss/engine.py @@ -1,106 +1,53 @@ import logging +import asyncio import re -from typing import Optional +from typing import Optional, Callable from module.database import Database, engine from module.downloader import DownloadClient from module.models import Bangumi, ResponseModel, RSSItem, Torrent from module.network import RequestContent +from module.conf import settings logger = logging.getLogger(__name__) -class RSSEngine(Database): - def __init__(self, _engine=engine): - super().__init__(_engine) +class RSSEngine: + def __init__(self): self._to_refresh = False + self.db_status = False + + async def rss_poller(self, callback: Callable = None): + with Database() as database: + rss_items = database.rss.search_active() + if rss_items: + tasks = [] + for item in rss_items: + tasks.append(self.pull_rss(item, database, callback)) + await asyncio.gather(*tasks) @staticmethod - def _get_torrents(rss: RSSItem) -> list[Torrent]: - with RequestContent() as req: - torrents = req.get_torrents(rss.url) + async def _get_torrents(rss: RSSItem) -> list[Torrent]: + async with RequestContent() as req: + torrents = await req.get_torrents(rss.url) # Add RSS ID for torrent in torrents: torrent.rss_id = rss.id - return torrents - - def get_rss_torrents(self, rss_id: int) -> list[Torrent]: - rss = self.rss.search_id(rss_id) - if rss: - return self.torrent.search_rss(rss_id) - else: - return [] - - def add_rss( - self, - rss_link: str, - name: str | None = None, - aggregate: bool = True, - parser: str = "mikan", - ): - if not name: - with RequestContent() as req: - name = req.get_rss_title(rss_link) - if not name: - return ResponseModel( - status=False, - status_code=406, - msg_en="Failed to get RSS title.", - msg_zh="无法获取 RSS 标题。", - ) - rss_data = RSSItem(name=name, url=rss_link, aggregate=aggregate, parser=parser) - if self.rss.add(rss_data): - return ResponseModel( - status=True, - status_code=200, - msg_en="RSS added successfully.", - msg_zh="RSS 添加成功。", - ) - else: - return ResponseModel( - status=False, - status_code=406, - msg_en="RSS added failed.", - msg_zh="RSS 添加失败。", - ) - - def disable_list(self, rss_id_list: list[int]): - for rss_id in rss_id_list: - self.rss.disable(rss_id) - return ResponseModel( - status=True, - status_code=200, - msg_en="Disable RSS successfully.", - msg_zh="禁用 RSS 成功。", - ) + return torrents - def enable_list(self, rss_id_list: list[int]): - for rss_id in rss_id_list: - self.rss.enable(rss_id) - return ResponseModel( - status=True, - status_code=200, - msg_en="Enable RSS successfully.", - msg_zh="启用 RSS 成功。", - ) - - def delete_list(self, rss_id_list: list[int]): - for rss_id in rss_id_list: - self.rss.delete(rss_id) - return ResponseModel( - status=True, - status_code=200, - msg_en="Delete RSS successfully.", - msg_zh="删除 RSS 成功。", - ) - - def pull_rss(self, rss_item: RSSItem) -> list[Torrent]: - torrents = self._get_torrents(rss_item) - new_torrents = self.torrent.check_new(torrents) + async def pull_rss( + self, rss_item: RSSItem, database: Database = None, callback: Callable = None, **kwargs + ) -> list[Torrent]: + torrents = await self._get_torrents(rss_item) + new_torrents = database.torrent.check_new(torrents) + if callback: + await callback(rss_item, new_torrents, **kwargs) + database.torrent.add_all(new_torrents) return new_torrents - def match_torrent(self, torrent: Torrent) -> Optional[Bangumi]: - matched: Bangumi = self.bangumi.match_torrent(torrent.name) + @staticmethod + def match_torrent(torrent: Torrent, database: Database) -> Optional[Bangumi]: + matched: Bangumi = database.bangumi.match_torrent(torrent.name) if matched: if matched.filter == "": return matched @@ -110,36 +57,49 @@ def match_torrent(self, torrent: Torrent) -> Optional[Bangumi]: return matched return None - def refresh_rss(self, client: DownloadClient, rss_id: Optional[int] = None): + async def refresh_rss( + self, + client: DownloadClient, + database: Database = None, + rss_id: Optional[int] = None, + callback: Callable = None, + ): + # Connect to Database if not connected + if not database: + database = self.__connect_database() + self.db_status = True # Get All RSS Items if not rss_id: - rss_items: list[RSSItem] = self.rss.search_active() + rss_items: list[RSSItem] = database.rss.search_active() else: - rss_item = self.rss.search_id(rss_id) - rss_items = [rss_item] if rss_item else [] + rss_items = [database.rss.search_id(rss_id)] # From RSS Items, get all torrents logger.debug(f"[Engine] Get {len(rss_items)} RSS items") + tasks = [] for rss_item in rss_items: - new_torrents = self.pull_rss(rss_item) - # Get all enabled bangumi data - for torrent in new_torrents: - matched_data = self.match_torrent(torrent) - if matched_data: - if client.add_torrent(torrent, matched_data): - logger.debug(f"[Engine] Add torrent {torrent.name} to client") - torrent.downloaded = True - # Add all torrents to database - self.torrent.add_all(new_torrents) + tasks.append( + self.pull_rss( + rss_item=rss_item, + database=database, + callback=callback, + client=client, + ) + ) + await asyncio.gather(*tasks) + # Close Database if not connected + if self.db_status: + database.close() - def download_bangumi(self, bangumi: Bangumi): - with RequestContent() as req: - torrents = req.get_torrents( + @staticmethod + async def download_bangumi(bangumi: Bangumi, database: Database): + async with RequestContent() as req: + torrents = await req.get_torrents( bangumi.rss_link, bangumi.filter.replace(",", "|") ) if torrents: - with DownloadClient() as client: - client.add_torrent(torrents, bangumi) - self.torrent.add_all(torrents) + async with DownloadClient() as client: + await client.add_torrents(torrents, bangumi) + database.torrent.add_all(torrents) return ResponseModel( status=True, status_code=200, @@ -153,3 +113,9 @@ def download_bangumi(self, bangumi: Bangumi): msg_en=f"[Engine] Download {bangumi.official_title} failed.", msg_zh=f"[Engine] 下载 {bangumi.official_title} 失败。", ) + + @staticmethod + def __connect_database(): + return Database(engine) + + diff --git a/backend/src/module/rss/manager.py b/backend/src/module/rss/manager.py new file mode 100644 index 000000000..a65b8458d --- /dev/null +++ b/backend/src/module/rss/manager.py @@ -0,0 +1,80 @@ +import re + +from module.database import Database, engine +from module.network import RequestContent +from module.models import ResponseModel, RSSItem, Torrent + + +class RSSManager(Database): + def __init__(self, _engine=engine): + super().__init__(engine=_engine) + + async def add_rss( + self, + rss_link: str, + name: str | None = None, + aggregate: bool = True, + parser: str = "mikan", + ): + if not name: + async with RequestContent() as req: + name = await req.get_rss_title(rss_link) + if not name: + return ResponseModel( + status=False, + status_code=406, + msg_en="Failed to get RSS title.", + msg_zh="无法获取 RSS 标题。", + ) + rss_data = RSSItem(name=name, url=rss_link, aggregate=aggregate, parser=parser) + if self.rss.add(rss_data): + return ResponseModel( + status=True, + status_code=200, + msg_en="RSS added successfully.", + msg_zh="RSS 添加成功。", + ) + else: + return ResponseModel( + status=False, + status_code=406, + msg_en="RSS added failed.", + msg_zh="RSS 添加失败。", + ) + + def disable_list(self, rss_id_list: list[int]): + for rss_id in rss_id_list: + self.rss.disable(rss_id) + return ResponseModel( + status=True, + status_code=200, + msg_en="Disable RSS successfully.", + msg_zh="禁用 RSS 成功。", + ) + + def enable_list(self, rss_id_list: list[int]): + for rss_id in rss_id_list: + self.rss.enable(rss_id) + return ResponseModel( + status=True, + status_code=200, + msg_en="Enable RSS successfully.", + msg_zh="启用 RSS 成功。", + ) + + def delete_list(self, rss_id_list: list[int]): + for rss_id in rss_id_list: + self.rss.delete(rss_id) + return ResponseModel( + status=True, + status_code=200, + msg_en="Delete RSS successfully.", + msg_zh="删除 RSS 成功。", + ) + + def get_rss_torrents(self, rss_id: int) -> list[Torrent]: + rss = self.rss.search_id(rss_id) + if rss: + return self.torrent.search_rss(rss_id) + else: + return [] diff --git a/backend/src/module/rss/pool.py b/backend/src/module/rss/pool.py new file mode 100644 index 000000000..b66f9d00e --- /dev/null +++ b/backend/src/module/rss/pool.py @@ -0,0 +1,23 @@ +import asyncio +from typing import Callable + +from module.models import RSSItem, Torrent +from module.network import RequestContent +from module.conf import settings + + +async def rss_checker(rss: list[RSSItem], callback: Callable[[list[Torrent]], None]): + torrent_pool = [] + torrent_name_pool = [] + while 1: + async with RequestContent() as req: + for item in rss: + torrents = await req.get_torrents(item.url) + for torrent in torrents: + if torrent.name not in torrent_name_pool: + torrent_pool.append(torrent) + torrent_name_pool.append(torrent.name) + if torrent_pool: + callback(torrent_pool) + torrent_pool.clear() + await asyncio.sleep(settings.rss.interval) diff --git a/backend/src/module/update/cross_version.py b/backend/src/module/update/cross_version.py index 701241b2b..28a8b24c3 100644 --- a/backend/src/module/update/cross_version.py +++ b/backend/src/module/update/cross_version.py @@ -2,13 +2,13 @@ from urllib3.util import parse_url -from module.rss import RSSEngine +from module.rss import RSSManager from module.utils import save_image from module.network import RequestContent -def from_30_to_31(): - with RSSEngine() as db: +async def from_30_to_31(): + async with RSSManager() as db: db.migrate() # Update poster link bangumis = db.bangumi.search_all() @@ -29,18 +29,17 @@ def from_30_to_31(): aggregate = True else: aggregate = False - db.add_rss(rss_link=rss, aggregate=aggregate) + await db.add_rss(rss_link=rss, aggregate=aggregate) -def cache_image(): - with RSSEngine() as db, RequestContent() as req: +async def cache_image(): + async with RSSManager() as db, RequestContent() as req: bangumis = db.bangumi.search_all() for bangumi in bangumis: if bangumi.poster_link: # Hash local path - img = req.get_content(bangumi.poster_link) + img = await req.get_content(bangumi.poster_link) suffix = bangumi.poster_link.split(".")[-1] img_path = save_image(img, suffix) bangumi.poster_link = img_path db.bangumi.update_all(bangumis) - diff --git a/backend/src/module/utils/__init__.py b/backend/src/module/utils/__init__.py index a95499d60..d885cf723 100644 --- a/backend/src/module/utils/__init__.py +++ b/backend/src/module/utils/__init__.py @@ -1 +1 @@ -from .cache_image import save_image, load_image \ No newline at end of file +from .cache_image import save_image, load_image diff --git a/backend/src/test/test_database.py b/backend/src/test/test_database.py index 5ee7ad93b..8efa515a4 100644 --- a/backend/src/test/test_database.py +++ b/backend/src/test/test_database.py @@ -46,7 +46,8 @@ def test_bangumi_database(): # match torrent result = db.bangumi.match_torrent( - "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]" + "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]", + "test", ) assert result.official_title == "无职转生,到了异世界就拿出真本事II" diff --git a/backend/src/test/test_path_parser.py b/backend/src/test/test_path_parser.py index 9e14ea174..d06455f01 100644 --- a/backend/src/test/test_path_parser.py +++ b/backend/src/test/test_path_parser.py @@ -4,9 +4,8 @@ def test_path_to_bangumi(): # Test for unix-like path from module.downloader.path import TorrentPath + path = "Downloads/Bangumi/Kono Subarashii Sekai ni Shukufuku wo!/Season 2/" bangumi_name, season = TorrentPath()._path_to_bangumi(path) assert bangumi_name == "Kono Subarashii Sekai ni Shukufuku wo!" assert season == 2 - - diff --git a/backend/src/test/test_raw_parser.py b/backend/src/test/test_raw_parser.py index 574f38a58..85903ca71 100644 --- a/backend/src/test/test_raw_parser.py +++ b/backend/src/test/test_raw_parser.py @@ -70,7 +70,9 @@ def test_raw_parser(): assert info.episode == 5 assert info.season == 1 - content = "【喵萌奶茶屋】★07月新番★[银砂糖师与黑妖精 ~ Sugar Apple Fairy Tale ~][13][1080p][简日双语][招募翻译]" + content = ( + "【喵萌奶茶屋】★07月新番★[银砂糖师与黑妖精 ~ Sugar Apple Fairy Tale ~][13][1080p][简日双语][招募翻译]" + ) info = raw_parser(content) assert info.group == "喵萌奶茶屋" assert info.title_zh == "银砂糖师与黑妖精" @@ -79,13 +81,12 @@ def test_raw_parser(): assert info.episode == 13 assert info.season == 1 - content = "[ANi] 16bit 的感动 ANOTHER LAYER - 01 [1080P][Baha][WEB-DL][AAC AVC][CHT][MP4]" + content = ( + "[ANi] 16bit 的感动 ANOTHER LAYER - 01 [1080P][Baha][WEB-DL][AAC AVC][CHT][MP4]" + ) info = raw_parser(content) assert info.group == "ANi" assert info.title_zh == "16bit 的感动 ANOTHER LAYER" assert info.resolution == "1080P" assert info.episode == 1 assert info.season == 1 - - - diff --git a/backend/src/test/test_rss_engine.py b/backend/src/test/test_rss_engine.py index cda69f6ed..b16022a4b 100644 --- a/backend/src/test/test_rss_engine.py +++ b/backend/src/test/test_rss_engine.py @@ -1,18 +1,24 @@ -from module.rss.engine import RSSEngine +import pytest +from module.rss import RSSEngine, RSSManager from .test_database import engine as e -def test_rss_engine(): - with RSSEngine(e) as engine: +@pytest.mark.asyncio +async def test_rss_engine(): + engine = RSSEngine() + with RSSManager(e) as manager: rss_link = "https://mikanani.me/RSS/Bangumi?bangumiId=2353&subgroupid=552" - engine.add_rss(rss_link, aggregate=False) + resp = await manager.add_rss(rss_link, aggregate=False) + assert resp.status - result = engine.rss.search_active() + result = manager.rss.search_active() assert result[1].name == "Mikan Project - 无职转生~到了异世界就拿出真本事~" - new_torrents = engine.pull_rss(result[1]) + new_torrents = await engine.pull_rss(result[1], database=manager) torrent = new_torrents[0] - assert torrent.name == "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]" - + assert ( + torrent.name + == "[Lilith-Raws] 无职转生,到了异世界就拿出真本事 / Mushoku Tensei - 11 [Baha][WEB-DL][1080p][AVC AAC][CHT][MP4]" + ) diff --git a/backend/src/test/test_tmdb.py b/backend/src/test/test_tmdb.py index 03724da43..90b694e9e 100644 --- a/backend/src/test/test_tmdb.py +++ b/backend/src/test/test_tmdb.py @@ -1,12 +1,15 @@ +import pytest + from module.parser.analyser.tmdb_parser import tmdb_parser -def test_tmdb_parser(): +@pytest.mark.asyncio +async def test_tmdb_parser(): bangumi_title = "海盗战记" bangumi_year = "2019" bangumi_season = 2 - tmdb_info = tmdb_parser(bangumi_title, "zh", test=True) + tmdb_info = await tmdb_parser(bangumi_title, "zh", test=True) assert tmdb_info.title == "冰海战记" assert tmdb_info.year == bangumi_year diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 4521b494f..000000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "pythonPath": "/opt/homebrew/Caskroom/miniforge/base/envs/auto_bangumi/bin/python", - "root": "backend/src", - "venvPath": "/opt/homebrew/Caskroom/miniforge/base/envs", - "venv": "auto_bangumi", - "typeCheckingMode": "basic", - "reportMissingImports": true -}