Skip to content

Commit 696cf16

Browse files
authored
Merge pull request #96 from shinonomeow/shino_aio
update pydantic to v2
2 parents a3c2851 + 709633d commit 696cf16

4 files changed

Lines changed: 38 additions & 146 deletions

File tree

backend/src/module/conf/config.py

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,32 +34,12 @@
3434
CONFIG_PATH = (CONFIG_ROOT / CONFIX_NAME).resolve()
3535

3636

37-
def model_dump(
38-
model: BaseModel,
39-
include: set[str] | None = None,
40-
exclude: set[str] | None = None,
41-
by_alias: bool = False,
42-
exclude_unset: bool = False,
43-
exclude_defaults: bool = False,
44-
exclude_none: bool = False,
45-
) -> dict[str, Any]:
46-
return model.dict(
47-
include=include,
48-
exclude=exclude,
49-
by_alias=by_alias,
50-
exclude_unset=exclude_unset,
51-
exclude_defaults=exclude_defaults,
52-
exclude_none=exclude_none,
53-
)
54-
55-
5637
# 判断给定的 data 的 key 是否在 setting 中
5738
def check_config_key(
5839
data: dict | BaseModel, updated_data: BaseModel, config_name: str
5940
) -> bool:
6041
if isinstance(data, BaseModel):
6142
data = data.model_dump()
62-
# data = data.dict()
6343

6444
updated_data = updated_data.model_dump()
6545
for key in updated_data.keys():
@@ -70,7 +50,7 @@ def check_config_key(
7050

7151
def get_plugin_config(config: T_BaseModel, config_name: str) -> T_BaseModel:
7252
"""从全局配置获取当前插件需要的配置项,更新 data 中的缺失项。"""
73-
global_data = model_dump(settings)
53+
global_data = settings.model_dump()
7454
data = global_data.get(config_name, {})
7555
# data 可能是 dict 和 BaseModel 的实例
7656
# 尝试从配置文件中读取配置, 如果不足则使用默认配置
@@ -85,8 +65,7 @@ def get_plugin_config(config: T_BaseModel, config_name: str) -> T_BaseModel:
8565

8666
def type_validate_python(type_: BaseModel, data: Any) -> BaseModel:
8767
"""Validate data with given type, checking required fields exist."""
88-
validated_data = type_.__class__.validate(data)
89-
68+
validated_data = type_.__class__.model_validate(data)
9069
return validated_data
9170

9271

@@ -99,11 +78,9 @@ def update_config(base_config: BaseModel | dict, data: dict):
9978
# # 获取 baseconfig 的当前字段数据
10079
if isinstance(base_config, BaseModel):
10180
updated_data = base_config.model_dump()
102-
# updated_data = base_config.dict()
10381
updated_data = deep_update(updated_data, data)
104-
updated_instance = base_config.__class__.validate(updated_data)
82+
updated_instance = base_config.__class__.model_validate(updated_data)
10583
updata_dict = updated_instance.model_dump()
106-
# updata_dict = updated_instance.dict()
10784
else:
10885
# 当 baseconfig 是 dict 类型时, 直接更新
10986
updated_data = base_config
@@ -170,7 +147,7 @@ def __load_from_env(self):
170147
else:
171148
attr_name = attr[0] if isinstance(attr, tuple) else attr
172149
config_dict[key][attr_name] = self.__val_from_env(env, attr)
173-
config_obj = Config.validate(config_dict)
150+
config_obj = Config.model_validate(config_dict)
174151
self.__dict__.update(config_obj.__dict__)
175152
logger.info("Config loaded from env")
176153

@@ -182,9 +159,5 @@ def __val_from_env(env: str, attr: tuple[str, Callable[..., Any]] | str):
182159
else:
183160
return os.environ[env]
184161

185-
@property
186-
def group_rules(self):
187-
return self.__dict__["group_rules"]
188-
189162

190163
settings = Settings()

backend/src/module/conf/log.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,6 @@ def setup_logger(level: int = logging.INFO, reset: bool = False):
5959
for logger_name in loggers_to_silence:
6060
logger = logging.getLogger(logger_name)
6161
logger.setLevel(logging.WARNING)
62-
# logger.propagate = False
63-
# logger.handlers = [NullHandler()]
6462

6563
# 对于 bcrypt 有一个 适配的问题,就藏起来吧
6664
logging.getLogger("passlib").setLevel(logging.ERROR)
67-
# 完全抑制 httpx 的日志输出
68-
# httpx_logger = logging.getLogger("httpx")
69-
# http_coro_logger = logging.getLogger("httpcore")
70-
# httpx_logger.setLevel(logging.WARNING)
71-
# http_coro_logger.setLevel(logging.WARNING)
72-
# httpx_logger.addHandler(NullHandler())
73-
# http_coro_logger.addHandler(NullHandler())
74-
# httpx_logger.propagate = False
75-
# http_coro_logger.propagate = False
Lines changed: 32 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,35 @@
1-
from typing import Literal
1+
from typing import Self
22

3-
from pydantic import BaseModel, Field, validator
4-
from pydantic import ConfigDict as ConfigDict
3+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
54

65

76
class Program(BaseModel):
87
# rss_time must be greater than 300,if less than 300,it will be set to 300
98
rss_time: int = Field(default=900, description="Sleep time")
10-
# rename_time must be greater than 30,if less than 0,it will be set to 30
11-
rename_time: int = Field(default=60, description="Rename times in one loop")
129
webui_port: int = Field(default=7892, description="WebUI port")
1310

14-
@validator("rss_time")
15-
def validate_rss_time(cls, v: int) -> int:
16-
if v < 300:
17-
return 300
18-
return v
19-
20-
@validator("rename_time")
21-
def validate_rename_time(cls, v: int) -> int:
22-
if v < 30:
23-
return 30
24-
return v
11+
@model_validator(mode="after")
12+
def validate_rss_time(self) -> Self:
13+
self.rss_time = max(self.rss_time, 300)
14+
return self
2515

2616

2717
class Downloader(BaseModel):
2818
type: str = Field(default="qbittorrent", description="Downloader type")
2919
path: str = Field(default="/downloads/Bangumi", description="Downloader path")
30-
host: str = Field(default="172.17.0.1:8080", alias="host", description="Downloader host")
31-
ssl: bool = Field(default=False, description="Downloader ssl")
32-
33-
class Config:
34-
extra: str = "allow" # This allows extra fields not defined in the model
35-
36-
@validator("host", pre=True)
37-
def validate_host(cls, value: str) -> str:
38-
# 如果输入值没有以 http:// 或 https:// 开头,自动加上 http://
39-
if not value.startswith(("http://", "https://")):
40-
value = f"http://{value}"
41-
return value
42-
43-
# username: str = Field("admin", alias="username", description="Downloader username")
44-
# password: str = Field(
45-
# "adminadmin", alias="password", description="Downloader password"
46-
# )
47-
48-
49-
class QbDownloader(Downloader):
50-
type: str = Field(default="qbittorrent", description="Downloader type")
5120
host: str = Field(
5221
default="172.17.0.1:8080", alias="host", description="Downloader host"
5322
)
54-
username: str = Field(
55-
default="admin", alias="username", description="Downloader username"
56-
)
57-
password: str = Field(
58-
default="adminadmin", alias="password", description="Downloader password"
59-
)
60-
path: str = Field(default="/downloads/Bangumi", description="Downloader path")
6123
ssl: bool = Field(default=False, description="Downloader ssl")
6224

25+
model_config = ConfigDict(extra="allow")
6326

64-
class TrDownloader(Downloader):
65-
type: str = Field(default="transmission", description="Downloader type")
66-
host_: str = Field(
67-
default="172.17.0.1:9091", alias="host", description="Downloader host"
68-
)
69-
username_: str = Field(
70-
default="admin", alias="username", description="Downloader username"
71-
)
72-
password_: str = Field(
73-
default="admin", alias="password", description="Downloader password"
74-
)
75-
path: str = Field(default="/downloads/Bangumi", description="Downloader path")
76-
ssl: bool = Field(default=False, description="Downloader ssl")
27+
@model_validator(mode="after")
28+
def validate_host(self) -> Self:
29+
# 如果输入值没有以 http:// 或 https:// 开头,自动加上 http://
30+
if not self.host.startswith(("http://", "https://")):
31+
self.host = f"http://{self.host}"
32+
return self
7733

7834

7935
class RSSParser(BaseModel):
@@ -108,14 +64,6 @@ class Proxy(BaseModel):
10864
username: str = Field(default="", alias="username", description="Proxy username")
10965
password: str = Field(default="", alias="password", description="Proxy password")
11066

111-
# @property
112-
# def username(self):
113-
# return expandvars(self.username_)
114-
#
115-
# @property
116-
# def password(self):
117-
# return expandvars(self.password_)
118-
11967

12068
class Notification(BaseModel):
12169
enable: bool = Field(default=False, description="Enable notification")
@@ -125,33 +73,25 @@ class Notification(BaseModel):
12573
default="", alias="chat_id", description="Notification chat id"
12674
)
12775

128-
# @property
129-
# def token(self):
130-
# return expandvars(self.token_)
131-
#
132-
# @property
133-
# def chat_id(self):
134-
# return expandvars(self.chat_id_)
135-
13676

137-
class ExperimentalOpenAI(BaseModel):
138-
enable: bool = Field(False, description="Enable experimental OpenAI")
139-
api_key: str = Field("", description="OpenAI api key")
140-
api_base: str = Field(
141-
"https://api.openai.com/v1", description="OpenAI api base url"
142-
)
143-
api_type: Literal["azure", "openai"] = Field(
144-
"openai", description="OpenAI api type, usually for azure"
145-
)
146-
api_version: str = Field(
147-
"2023-05-15", description="OpenAI api version, only for Azure"
148-
)
149-
model: str = Field(
150-
"gpt-3.5-turbo", description="OpenAI model, ignored when api type is azure"
151-
)
152-
deployment_id: str = Field(
153-
"", description="Azure OpenAI deployment id, ignored when api type is openai"
154-
)
77+
# class ExperimentalOpenAI(BaseModel):
78+
# enable: bool = Field(False, description="Enable experimental OpenAI")
79+
# api_key: str = Field("", description="OpenAI api key")
80+
# api_base: str = Field(
81+
# "https://api.openai.com/v1", description="OpenAI api base url"
82+
# )
83+
# api_type: Literal["azure", "openai"] = Field(
84+
# "openai", description="OpenAI api type, usually for azure"
85+
# )
86+
# api_version: str = Field(
87+
# "2023-05-15", description="OpenAI api version, only for Azure"
88+
# )
89+
# model: str = Field(
90+
# "gpt-3.5-turbo", description="OpenAI model, ignored when api type is azure"
91+
# )
92+
# deployment_id: str = Field(
93+
# "", description="Azure OpenAI deployment id, ignored when api type is openai"
94+
# )
15595

15696

15797
class Config(BaseModel):
@@ -163,16 +103,6 @@ class Config(BaseModel):
163103
proxy: Proxy = Proxy()
164104
notification: Notification = Notification()
165105

166-
class Config:
167-
extra = "allow" # This allows extra fields not defined in the model
106+
model_config = ConfigDict(extra="allow")
168107

169108
# experimental_openai: ExperimentalOpenAI = ExperimentalOpenAI()
170-
171-
# @override
172-
# def model_dump(self, *args, by_alias=True, **kwargs):
173-
# return super().model_dump(*args, by_alias=by_alias, **kwargs)
174-
175-
176-
if __name__ == "__main__":
177-
pass
178-
# t = Program(rss_time="1")

backend/src/module/parser/analyser/torrent_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ def torrent_parser(
3636
suffix = Path(torrent_name).suffix
3737
if file_type == "media":
3838
return EpisodeFile(
39-
**media_info.dict(),
39+
**media_info.model_dump(),
4040
media_path=torrent_name,
4141
title=media_info.get_title(),
4242
suffix=suffix,
4343
)
4444
else:
4545
language = get_subtitle_lang(torrent_name)
4646
return SubtitleFile(
47-
**media_info.dict(),
47+
**media_info.model_dump(),
4848
media_path=torrent_name,
4949
language=language,
5050
title=media_info.get_title(),

0 commit comments

Comments
 (0)