Skip to content

Commit 1e65f3f

Browse files
✨ AttachmentSegment 支持传递 url,添加获取 attachments 的方法 (#69)
* 🐛 修复 #43 转发图片消息时附件不可发送 * ♻️ 使用 yarl.URL 解析附件地址
1 parent 693b753 commit 1e65f3f

5 files changed

Lines changed: 446 additions & 15 deletions

File tree

nonebot/adapters/discord/bot.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Union
1+
from http import HTTPStatus
2+
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
23
from typing_extensions import override
34

45
from nonebot.adapters import Bot as BaseBot
56

7+
from nonebot.drivers import Request
68
from nonebot.message import handle_event
9+
from yarl import URL
710

811
from .api import (
912
UNSET,
1013
AllowedMention,
1114
ApiClient,
15+
File,
1216
InteractionCallbackMessage,
1317
InteractionCallbackType,
1418
InteractionResponse,
@@ -28,6 +32,10 @@
2832
from .adapter import Adapter
2933

3034

35+
DISCORD_ATTACHMENT_HOSTS = {"cdn.discordapp.com", "media.discordapp.net"}
36+
AttachmentFetchOnError = Literal["raise", "skip"]
37+
38+
3139
async def _check_reply(bot: "Bot", event: MessageEvent) -> None:
3240
message_reference = event.message_reference
3341
if message_reference is UNSET:
@@ -175,6 +183,138 @@ async def handle_event(self, event: Event) -> None:
175183
_check_at_me(self, event)
176184
await handle_event(self, event)
177185

186+
async def fetch_attachments( # noqa: PLR0913
187+
self,
188+
message: Union[str, Message, MessageSegment],
189+
*,
190+
allowed_hosts: Optional[set[str]] = None,
191+
require_https: bool = True,
192+
timeout: Optional[float] = None,
193+
max_bytes: Optional[int] = None,
194+
prefer_proxy_url: bool = True,
195+
on_error: AttachmentFetchOnError = "raise",
196+
) -> Message:
197+
message = MessageSegment.text(message) if isinstance(message, str) else message
198+
message = message if isinstance(message, Message) else Message(message)
199+
new = message.clone()
200+
201+
if allowed_hosts is None:
202+
allowed_hosts = DISCORD_ATTACHMENT_HOSTS
203+
204+
attachment_segments = new["attachment"] or []
205+
for index, attachment in enumerate(attachment_segments):
206+
if attachment.data["file"] is not None:
207+
continue
208+
209+
url = self._pick_attachment_url(
210+
attachment,
211+
allowed_hosts=allowed_hosts,
212+
require_https=require_https,
213+
prefer_proxy_url=prefer_proxy_url,
214+
)
215+
if url is None:
216+
if on_error == "raise":
217+
msg = (
218+
f"Attachment segment at index {index} has no fetchable "
219+
"url/proxy_url"
220+
)
221+
raise ValueError(msg)
222+
continue
223+
224+
content = await self._fetch_attachment_content(
225+
url,
226+
timeout=timeout,
227+
max_bytes=max_bytes,
228+
)
229+
if content is None:
230+
if on_error == "raise":
231+
msg = (
232+
f"Failed to fetch attachment content for segment "
233+
f"at index {index} from URL {url}"
234+
)
235+
raise ValueError(msg)
236+
continue
237+
238+
attachment.data["file"] = File(
239+
filename=attachment.data["attachment"].filename,
240+
content=content,
241+
)
242+
243+
return new
244+
245+
@staticmethod
246+
def _pick_attachment_url(
247+
attachment: MessageSegment,
248+
*,
249+
allowed_hosts: set[str],
250+
require_https: bool,
251+
prefer_proxy_url: bool,
252+
) -> Optional[str]:
253+
urls = []
254+
if prefer_proxy_url:
255+
urls.extend(
256+
[
257+
attachment.data.get("proxy_url"),
258+
attachment.data.get("url"),
259+
]
260+
)
261+
else:
262+
urls.extend(
263+
[
264+
attachment.data.get("url"),
265+
attachment.data.get("proxy_url"),
266+
]
267+
)
268+
269+
for candidate in urls:
270+
if isinstance(candidate, str) and Bot._is_supported_attachment_url(
271+
candidate,
272+
allowed_hosts=allowed_hosts,
273+
require_https=require_https,
274+
):
275+
return candidate
276+
277+
return None
278+
279+
@staticmethod
280+
def _is_supported_attachment_url(
281+
url: str, *, allowed_hosts: set[str], require_https: bool
282+
) -> bool:
283+
parsed = URL(url)
284+
scheme_ok = parsed.scheme == "https" if require_https else bool(parsed.scheme)
285+
return (
286+
scheme_ok and isinstance(parsed.host, str) and parsed.host in allowed_hosts
287+
)
288+
289+
async def _fetch_attachment_content(
290+
self,
291+
url: str,
292+
*,
293+
timeout: Optional[float],
294+
max_bytes: Optional[int],
295+
) -> Optional[bytes]:
296+
try:
297+
request = Request(
298+
method="GET",
299+
url=url,
300+
timeout=timeout or self.adapter.discord_config.discord_api_timeout,
301+
proxy=self.adapter.discord_config.discord_proxy,
302+
)
303+
response = await self.adapter.request(request)
304+
if response.status_code != HTTPStatus.OK or not response.content:
305+
return None
306+
content = (
307+
response.content.encode()
308+
if isinstance(response.content, str)
309+
else response.content
310+
)
311+
if max_bytes is not None and len(content) > max_bytes:
312+
return None
313+
return content # noqa: TRY300
314+
except Exception as e:
315+
log("DEBUG", f"Failed to fetch attachment content from URL {url}: {e!r}", e)
316+
return None
317+
178318
async def send_to(
179319
self,
180320
channel_id: SnowflakeType,
@@ -183,6 +323,9 @@ async def send_to(
183323
nonce: Union[int, str, None] = None,
184324
allowed_mentions: Optional[AllowedMention] = None,
185325
) -> MessageGet:
326+
message = MessageSegment.text(message) if isinstance(message, str) else message
327+
message = message if isinstance(message, Message) else Message(message)
328+
message = message.sendable()
186329
message_data = parse_message(message)
187330

188331
return await self.create_message(
@@ -222,6 +365,8 @@ async def send(
222365
message model
223366
"""
224367
message = MessageSegment.text(message) if isinstance(message, str) else message
368+
message = message if isinstance(message, Message) else Message(message)
369+
message = message.sendable()
225370
if isinstance(event, InteractionCreateEvent):
226371
message_data = parse_message(message)
227372
response = InteractionResponse(
@@ -250,7 +395,6 @@ async def send(
250395
if not isinstance(event, MessageEvent) or not event.channel_id or not event.id:
251396
msg = "Event cannot be replied to!"
252397
raise RuntimeError(msg)
253-
message = message if isinstance(message, Message) else Message(message)
254398
if mention_sender or at_sender:
255399
message.insert(0, MessageSegment.mention_user(event.user_id))
256400
if reply_message:

nonebot/adapters/discord/message.py

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections.abc import Iterable
2+
from copy import deepcopy
23
from dataclasses import dataclass
34
import datetime
45
import re
@@ -61,6 +62,9 @@ def attachment(
6162
file: Union[str, File, AttachmentSend],
6263
description: Optional[str] = None,
6364
content: Optional[bytes] = None,
65+
*,
66+
url: Optional[str] = None,
67+
proxy_url: Optional[str] = None,
6468
) -> "AttachmentSegment":
6569
if isinstance(file, str):
6670
_filename = file
@@ -85,6 +89,8 @@ def attachment(
8589
filename=_filename, description=_description
8690
),
8791
"file": None,
92+
"url": url,
93+
"proxy_url": proxy_url,
8894
},
8995
)
9096
return AttachmentSegment(
@@ -98,6 +104,8 @@ def attachment(
98104
if isinstance(_filename, str)
99105
else None
100106
),
107+
"url": url,
108+
"proxy_url": proxy_url,
101109
},
102110
)
103111

@@ -475,6 +483,8 @@ def _validate(cls, value) -> Self: # noqa: ANN001
475483
class AttachmentData(TypedDict):
476484
attachment: AttachmentSend
477485
file: Optional[File]
486+
url: Optional[str]
487+
proxy_url: Optional[str]
478488

479489

480490
@dataclass
@@ -507,6 +517,17 @@ def _validate(cls, value) -> Self: # noqa: ANN001
507517
file, File
508518
):
509519
instance.data["file"] = type_validate_python(File, file)
520+
url = instance.data.get("url")
521+
if url is not None and not isinstance(url, str):
522+
msg = f"Expected str for AttachmentSegment.data['url'], got {type(url)}"
523+
raise TypeError(msg)
524+
proxy_url = instance.data.get("proxy_url")
525+
if proxy_url is not None and not isinstance(proxy_url, str):
526+
msg = (
527+
"Expected str for AttachmentSegment.data['proxy_url'], "
528+
f"got {type(proxy_url)}"
529+
)
530+
raise TypeError(msg)
510531
return instance
511532

512533

@@ -676,7 +697,9 @@ def from_guild_message(cls, message: MessageGet) -> "Message":
676697
if isinstance(attachment.description, str)
677698
else None
678699
),
679-
)
700+
),
701+
url=attachment.url,
702+
proxy_url=attachment.proxy_url,
680703
)
681704
for attachment in message.attachments
682705
)
@@ -706,6 +729,29 @@ def extract_content(self) -> str:
706729
)
707730
)
708731

732+
def clone(self) -> "Message":
733+
new = self.__class__()
734+
for segment in self:
735+
new.append(
736+
type_validate_python(
737+
MessageSegment,
738+
{
739+
"type": segment.type,
740+
"data": deepcopy(segment.data),
741+
},
742+
)
743+
)
744+
return new
745+
746+
def sendable(self) -> "Message":
747+
new = self.clone()
748+
attachments_segment = new["attachment"] or None
749+
if attachments_segment is not None:
750+
for index, attachment in enumerate(attachments_segment):
751+
if attachment.data["file"] is None:
752+
raise ValueError(_get_unsendable_attachment_msg(index, attachment))
753+
return new
754+
709755

710756
def parse_message(message: Union[Message, MessageSegment, str]) -> dict[str, Any]:
711757
message = MessageSegment.text(message) if isinstance(message, str) else message
@@ -733,17 +779,8 @@ def parse_message(message: Union[Message, MessageSegment, str]) -> dict[str, Any
733779
layout_type=poll.layout_type,
734780
)
735781

736-
attachments = None
737-
files = None
738-
if attachments_segment := (message["attachment"] or None):
739-
attachments = [
740-
attachment.data["attachment"] for attachment in attachments_segment
741-
]
742-
files = [
743-
attachment.data["file"]
744-
for attachment in attachments_segment
745-
if attachment.data["file"] is not None
746-
]
782+
attachments, files = extract_attachments(message)
783+
747784
return {
748785
k: v
749786
for k, v in {
@@ -758,3 +795,38 @@ def parse_message(message: Union[Message, MessageSegment, str]) -> dict[str, Any
758795
}.items()
759796
if v is not None
760797
}
798+
799+
800+
def extract_attachments(
801+
message: Message,
802+
) -> tuple[Optional[list[AttachmentSend]], Optional[list[File]]]:
803+
attachments_segment = message["attachment"] or None
804+
if not attachments_segment:
805+
return None, None
806+
807+
attachments_list: list[AttachmentSend] = []
808+
files_list: list[File] = []
809+
for index, attachment in enumerate(attachments_segment):
810+
file = attachment.data["file"]
811+
if file is None:
812+
raise ValueError(_get_unsendable_attachment_msg(index, attachment))
813+
attachments_list.append(attachment.data["attachment"])
814+
files_list.append(file)
815+
816+
attachments = attachments_list or None
817+
files = files_list or None
818+
return attachments, files
819+
820+
821+
def _get_unsendable_attachment_msg(index: int, attachment: MessageSegment) -> str:
822+
if attachment.data.get("url") or attachment.data.get("proxy_url"):
823+
return (
824+
f"Attachment segment at index {index} is not sendable because file "
825+
"content is missing; call "
826+
"`await bot.fetch_attachments(message)` first"
827+
)
828+
return (
829+
f"Attachment segment at index {index} is not sendable because file "
830+
"content is missing; provide `content=` in "
831+
"MessageSegment.attachment(...)"
832+
)

0 commit comments

Comments
 (0)