Skip to content

Commit 3d4c4ed

Browse files
authored
fix: validate plugin install sources (#9061)
1 parent c1cc74b commit 3d4c4ed

21 files changed

Lines changed: 1663 additions & 256 deletions

File tree

astrbot/core/star/star_manager.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
from .filter.permission import PermissionType, PermissionTypeFilter
5353
from .star import star_map, star_registry
5454
from .star_handler import EventType, star_handlers_registry
55-
from .updator import PluginUpdator
55+
from .updator import PLUGIN_METADATA_FILENAMES, PluginUpdator
5656

5757
try:
5858
from watchfiles import PythonFilter, awatch
@@ -465,38 +465,44 @@ async def _import_plugin_with_dependency_recovery(
465465

466466
@staticmethod
467467
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
468-
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
468+
"""Load plugin metadata from metadata.yaml or metadata.yml.
469469
470-
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
470+
Args:
471+
plugin_path: Plugin directory path.
472+
plugin_obj: Deprecated compatibility argument; ignored.
473+
474+
Returns:
475+
Loaded plugin metadata, or None if no metadata file exists.
471476
"""
477+
del plugin_obj
472478
metadata = None
479+
metadata_label = "metadata.yaml"
480+
plugin_root = Path(plugin_path)
473481

474-
if not os.path.exists(plugin_path):
482+
if not plugin_root.exists():
475483
raise Exception("插件不存在。")
476484

477-
if os.path.exists(os.path.join(plugin_path, "metadata.yaml")):
478-
with open(
479-
os.path.join(plugin_path, "metadata.yaml"),
480-
encoding="utf-8",
481-
) as f:
485+
metadata_path = next(
486+
(
487+
plugin_root / filename
488+
for filename in PLUGIN_METADATA_FILENAMES
489+
if (plugin_root / filename).exists()
490+
),
491+
None,
492+
)
493+
if metadata_path:
494+
metadata_label = metadata_path.name
495+
with metadata_path.open(encoding="utf-8") as f:
482496
metadata = yaml.safe_load(f)
483-
elif plugin_obj and hasattr(plugin_obj, "info"):
484-
# 使用 info() 函数
485-
metadata = plugin_obj.info()
486497

487498
if isinstance(metadata, dict):
488499
if "desc" not in metadata and "description" in metadata:
489500
metadata["desc"] = metadata["description"]
490501

491-
if (
492-
"name" not in metadata
493-
or "desc" not in metadata
494-
or "version" not in metadata
495-
or "author" not in metadata
496-
):
497-
raise Exception(
498-
"插件元数据信息不完整。name, desc, version, author 是必须的字段。",
499-
)
502+
try:
503+
PluginUpdator.validate_plugin_metadata(metadata, metadata_label)
504+
except ValueError as exc:
505+
raise Exception(f"插件元数据校验失败:{exc!s}") from exc
500506
metadata = StarMetadata(
501507
name=metadata["name"],
502508
author=metadata["author"],
@@ -577,32 +583,42 @@ def _normalize_plugin_dir_name(plugin_name: str) -> str:
577583
def _validate_importable_name(plugin_name: str) -> None:
578584
if "/" in plugin_name or "\\" in plugin_name:
579585
raise ValueError(
580-
"metadata.yaml 中 name 含有路径分隔符,不可用于 importlib 加载。"
586+
"metadata 文件中 name 含有路径分隔符,不可用于 importlib 加载。"
581587
)
582588
if not plugin_name.isidentifier() or keyword.iskeyword(plugin_name):
583589
raise Exception(
584-
"metadata.yaml 中 name 不是合法的模块名称(应为合法 Python 标识符且非关键字)。"
590+
"metadata 文件中 name 不是合法的模块名称(应为合法 Python 标识符且非关键字)。"
585591
)
586592

587593
@staticmethod
588594
def _get_plugin_dir_name_from_metadata(plugin_path: str) -> str:
589-
metadata_path = os.path.join(plugin_path, "metadata.yaml")
590-
if not os.path.exists(metadata_path):
591-
raise Exception("未找到 metadata.yaml,无法获取插件目录名。")
595+
plugin_root = Path(plugin_path)
596+
metadata_path = next(
597+
(
598+
plugin_root / filename
599+
for filename in PLUGIN_METADATA_FILENAMES
600+
if (plugin_root / filename).exists()
601+
),
602+
None,
603+
)
604+
if metadata_path is None:
605+
raise Exception(
606+
"未找到 metadata.yaml 或 metadata.yml,无法获取插件目录名。"
607+
)
592608

593-
with open(metadata_path, encoding="utf-8") as f:
609+
with metadata_path.open(encoding="utf-8") as f:
594610
metadata = yaml.safe_load(f)
595611

596612
if not isinstance(metadata, dict):
597-
raise Exception("metadata.yaml 格式错误。")
613+
raise Exception(f"{metadata_path.name} 格式错误。")
598614

599615
plugin_name = metadata.get("name")
600616
if not isinstance(plugin_name, str) or not plugin_name.strip():
601-
raise Exception("metadata.yaml 中缺少 name 字段。")
617+
raise Exception(f"{metadata_path.name} 中缺少 name 字段。")
602618

603619
plugin_dir_name = PluginManager._normalize_plugin_dir_name(plugin_name)
604620
if not plugin_dir_name:
605-
raise Exception("metadata.yaml 中 name 字段内容非法。")
621+
raise Exception(f"{metadata_path.name} 中 name 字段内容非法。")
606622
PluginManager._validate_importable_name(plugin_dir_name)
607623
return plugin_dir_name
608624

astrbot/core/star/updator.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import os
22
import zipfile
33

4+
import yaml
5+
46
from astrbot.core import logger
57
from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path
68
from astrbot.core.utils.io import ensure_dir, remove_dir
79

810
from ..star.star import StarMetadata
911
from ..updator import RepoZipUpdator
1012

13+
PLUGIN_METADATA_FILENAMES = ("metadata.yaml", "metadata.yml")
14+
PLUGIN_METADATA_REQUIRED_FIELDS = ("name", "desc", "version", "author")
15+
1116

1217
class PluginUpdator(RepoZipUpdator):
1318
def __init__(self, repo_mirror: str = "", verify: str | bool | None = None) -> None:
@@ -58,6 +63,7 @@ async def update(
5863
elif repo_url:
5964
await self.download_from_repo_url(plugin_path, repo_url, proxy=proxy)
6065

66+
self.validate_plugin_archive(plugin_path + ".zip")
6167
try:
6268
remove_dir(plugin_path)
6369
except BaseException as e:
@@ -69,7 +75,135 @@ async def update(
6975

7076
return plugin_path
7177

78+
@classmethod
79+
def find_plugin_metadata_entry(cls, entries: list[str]) -> str | None:
80+
"""Find AstrBot plugin metadata in archive entries.
81+
82+
Args:
83+
entries: Zip archive member names.
84+
85+
Returns:
86+
The original archive entry name for plugin metadata, or None.
87+
"""
88+
update_dir = cls._resolve_archive_root_dir(entries)
89+
portable_update_dir = os.path.normpath(update_dir).replace("\\", "/")
90+
if portable_update_dir == ".":
91+
portable_update_dir = ""
92+
93+
entries_by_portable_path = {}
94+
for entry in entries:
95+
portable_entry = os.path.normpath(entry).replace("\\", "/")
96+
if portable_entry in ("", "."):
97+
continue
98+
entries_by_portable_path[portable_entry] = entry
99+
100+
metadata_candidates = (
101+
[
102+
f"{portable_update_dir}/{filename}"
103+
for filename in PLUGIN_METADATA_FILENAMES
104+
]
105+
if portable_update_dir
106+
else list(PLUGIN_METADATA_FILENAMES)
107+
)
108+
for candidate in metadata_candidates:
109+
if candidate in entries_by_portable_path:
110+
return entries_by_portable_path[candidate]
111+
return None
112+
113+
@staticmethod
114+
def validate_plugin_metadata(metadata: object, metadata_label: str) -> None:
115+
"""Validate AstrBot plugin metadata content.
116+
117+
Args:
118+
metadata: Parsed metadata YAML content.
119+
metadata_label: Metadata filename or archive entry for error messages.
120+
121+
Raises:
122+
ValueError: If metadata is malformed or misses required fields.
123+
"""
124+
if not isinstance(metadata, dict):
125+
raise ValueError(f"{metadata_label} 格式错误。")
126+
127+
normalized_metadata = dict(metadata)
128+
if "desc" not in normalized_metadata and "description" in normalized_metadata:
129+
normalized_metadata["desc"] = normalized_metadata["description"]
130+
131+
missing_fields = [
132+
field
133+
for field in PLUGIN_METADATA_REQUIRED_FIELDS
134+
if field not in normalized_metadata
135+
]
136+
if missing_fields:
137+
raise ValueError(
138+
f"{metadata_label} 中缺少必需字段: {', '.join(missing_fields)}。"
139+
)
140+
141+
invalid_fields = [
142+
field
143+
for field in PLUGIN_METADATA_REQUIRED_FIELDS
144+
if not isinstance(normalized_metadata[field], str)
145+
or not normalized_metadata[field].strip()
146+
]
147+
if invalid_fields:
148+
raise ValueError(
149+
f"{metadata_label} 中字段 {', '.join(invalid_fields)} 必须是非空字符串。"
150+
)
151+
152+
@classmethod
153+
def inspect_plugin_archive(cls, zip_path: str) -> dict[str, object]:
154+
"""Inspect plugin metadata in an AstrBot plugin archive.
155+
156+
Args:
157+
zip_path: Path to the plugin archive.
158+
159+
Returns:
160+
A dict containing the metadata entry and parsed metadata.
161+
162+
Raises:
163+
ValueError: If the archive is not a valid AstrBot plugin.
164+
"""
165+
try:
166+
with zipfile.ZipFile(zip_path, "r") as z:
167+
metadata_entry = cls.find_plugin_metadata_entry(z.namelist())
168+
if metadata_entry is None:
169+
raise ValueError(
170+
"压缩包不是合法的 AstrBot 插件:未找到 metadata.yaml 或 metadata.yml。"
171+
)
172+
173+
try:
174+
metadata_text = z.read(metadata_entry).decode("utf-8")
175+
metadata = yaml.safe_load(metadata_text)
176+
except UnicodeDecodeError as exc:
177+
raise ValueError(f"{metadata_entry} 必须使用 UTF-8 编码。") from exc
178+
except yaml.YAMLError as exc:
179+
raise ValueError(f"{metadata_entry} 格式错误。") from exc
180+
181+
cls.validate_plugin_metadata(metadata, metadata_entry)
182+
return {
183+
"metadata_entry": metadata_entry,
184+
"metadata": metadata,
185+
}
186+
except zipfile.BadZipFile as exc:
187+
raise ValueError("插件压缩包格式错误。") from exc
188+
189+
@classmethod
190+
def validate_plugin_archive(cls, zip_path: str) -> str:
191+
"""Validate that an archive contains a valid AstrBot plugin.
192+
193+
Args:
194+
zip_path: Path to the plugin archive.
195+
196+
Returns:
197+
The archive entry name of the plugin metadata file.
198+
199+
Raises:
200+
ValueError: If the archive is not a valid AstrBot plugin.
201+
"""
202+
inspection = cls.inspect_plugin_archive(zip_path)
203+
return str(inspection["metadata_entry"])
204+
72205
def unzip_file(self, zip_path: str, target_dir: str) -> None:
206+
self.validate_plugin_archive(zip_path)
73207
ensure_dir(target_dir)
74208
logger.info(f"Extracting archive: {zip_path}")
75209
with zipfile.ZipFile(zip_path, "r") as z:

astrbot/core/zip_updator.py

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,55 @@ def _truncate_response_body(body: str, max_len: int = 1000) -> str:
5454
return body
5555
return body[:max_len] + "...[truncated]"
5656

57+
async def fetch_github_default_branch(self, author: str, repo: str) -> str | None:
58+
"""Fetch the default branch for a GitHub repository.
59+
60+
Args:
61+
author: GitHub repository owner.
62+
repo: GitHub repository name.
63+
64+
Returns:
65+
The default branch name, or None if it cannot be resolved.
66+
"""
67+
url = f"https://api.github.com/repos/{author}/{repo}"
68+
try:
69+
async with self._create_httpx_client(timeout=10.0) as client:
70+
response = await client.get(url)
71+
response.raise_for_status()
72+
repo_info = response.json()
73+
except Exception as exc:
74+
logger.debug("获取 GitHub 默认分支失败 %s/%s: %s", author, repo, exc)
75+
return None
76+
77+
default_branch = str(repo_info.get("default_branch") or "").strip()
78+
return default_branch or None
79+
80+
async def resolve_github_source_branch(
81+
self,
82+
repo_url: str,
83+
) -> tuple[str, str, str]:
84+
"""Resolve the GitHub branch used for repository source downloads.
85+
86+
Args:
87+
repo_url: GitHub repository URL, optionally with a tree branch.
88+
89+
Returns:
90+
Repository owner, name, and resolved source branch.
91+
92+
Raises:
93+
ValueError: If the repository URL is invalid.
94+
"""
95+
author, repo, branch = self.parse_github_url(repo_url)
96+
if branch:
97+
return author, repo, branch
98+
99+
default_branch = await self.fetch_github_default_branch(author, repo)
100+
if default_branch:
101+
return author, repo, default_branch
102+
103+
logger.info("未能获取 %s/%s 的默认分支,将尝试 main 分支", author, repo)
104+
return author, repo, "main"
105+
57106
async def _download_file(
58107
self,
59108
url: str,
@@ -225,32 +274,13 @@ async def check_update(
225274
async def download_from_repo_url(
226275
self, target_path: str, repo_url: str, proxy=""
227276
) -> None:
228-
author, repo, branch = self.parse_github_url(repo_url)
277+
author, repo, branch = await self.resolve_github_source_branch(repo_url)
229278

230279
logger.info(f"正在下载更新 {repo} ...")
231-
232-
if branch:
233-
logger.info(f"正在从指定分支 {branch} 下载 {author}/{repo}")
234-
release_url = (
235-
f"https://github.com/{author}/{repo}/archive/refs/heads/{branch}.zip"
236-
)
237-
else:
238-
try:
239-
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
240-
releases = await self.fetch_release_info(url=release_url)
241-
except Exception as e:
242-
logger.warning(
243-
f"获取 {author}/{repo} 的 GitHub Releases 失败: {e},将尝试下载默认分支",
244-
)
245-
releases = []
246-
if not releases:
247-
# 如果没有最新版本,下载默认分支
248-
logger.info(f"正在从默认分支下载 {author}/{repo}")
249-
release_url = (
250-
f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
251-
)
252-
else:
253-
release_url = releases[0]["zipball_url"]
280+
logger.info(f"正在从分支 {branch} 下载 {author}/{repo}")
281+
release_url = (
282+
f"https://github.com/{author}/{repo}/archive/refs/heads/{branch}.zip"
283+
)
254284

255285
if proxy:
256286
proxy = proxy.rstrip("/")

0 commit comments

Comments
 (0)