Skip to content

Commit 93a6152

Browse files
authored
feat: add temporary extra user content parts (#7976)
* feat: add temporary extra user content parts * fix: 3.10
1 parent fff9c8e commit 93a6152

6 files changed

Lines changed: 108 additions & 4 deletions

File tree

astrbot/core/agent/message.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
22
# License: Apache License 2.0
33

4-
from typing import Any, ClassVar, Literal, cast
4+
from typing import Any, ClassVar, Literal, TypeVar, cast
55

66
from pydantic import (
77
BaseModel,
@@ -13,13 +13,16 @@
1313
)
1414
from pydantic_core import core_schema
1515

16+
ContentPartT = TypeVar("ContentPartT", bound="ContentPart")
17+
1618

1719
class ContentPart(BaseModel):
1820
"""A part of the content in a message."""
1921

2022
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
2123

2224
type: Literal["text", "think", "image_url", "audio_url"]
25+
_no_save: bool = PrivateAttr(default=False)
2326

2427
def __init_subclass__(cls, **kwargs: Any) -> None:
2528
super().__init_subclass__(**kwargs)
@@ -50,7 +53,10 @@ def validate_content_part(value: Any) -> Any:
5053
if not isinstance(type_value, str):
5154
raise ValueError(f"Cannot validate {value} as ContentPart")
5255
target_class = cls.__content_part_registry[type_value]
53-
return target_class.model_validate(value)
56+
part = target_class.model_validate(value)
57+
if cast(dict[str, Any], value).get("_no_save"):
58+
part._no_save = True
59+
return part
5460

5561
raise ValueError(f"Cannot validate {value} as ContentPart")
5662

@@ -59,6 +65,17 @@ def validate_content_part(value: Any) -> Any:
5965
# for subclasses, use the default schema
6066
return handler(source_type)
6167

68+
def mark_as_temp(self: ContentPartT) -> ContentPartT:
69+
"""Mark this content part as provider-facing only, not persisted."""
70+
self._no_save = True
71+
return self
72+
73+
def model_dump_for_context(self) -> dict[str, Any]:
74+
data = self.model_dump()
75+
if self._no_save:
76+
data["_no_save"] = True
77+
return data
78+
6279

6380
class TextPart(ContentPart):
6481
"""
@@ -329,7 +346,14 @@ def dump_messages_with_checkpoints(messages: list[Message]) -> list[dict]:
329346
"""Dump runtime messages and reinsert bound checkpoint segments."""
330347
dumped: list[dict] = []
331348
for message in messages:
332-
dumped.append(message.model_dump())
349+
message_data = message.model_dump()
350+
if isinstance(message.content, list):
351+
message_data["content"] = [
352+
part.model_dump()
353+
for part in message.content
354+
if not getattr(part, "_no_save", False)
355+
]
356+
dumped.append(message_data)
333357
if message._checkpoint_after is not None:
334358
dumped.append(
335359
CheckpointMessageSegment(content=message._checkpoint_after).model_dump()

astrbot/core/provider/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ async def assemble_context(self) -> dict:
206206
# 2. 额外的内容块(系统提醒、指令等)
207207
if self.extra_user_content_parts:
208208
for part in self.extra_user_content_parts:
209-
content_blocks.append(part.model_dump())
209+
content_blocks.append(part.model_dump_for_context())
210210

211211
# 3. 图片内容
212212
if self.image_urls:

docs/en/dev/star/guides/listen-message-event.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,14 @@ async def my_custom_hook_1(self, event: AstrMessageEvent, req: ProviderRequest):
295295
> )
296296
> ```
297297
>
298+
> If the appended content should only affect the current LLM request and should not be persisted into conversation history, call `.mark_as_temp()` to mark it as temporary:
299+
>
300+
> ```python
301+
> req.extra_user_content_parts.append(
302+
> TextPart(text="<runtime_hint>This hint only applies to the current request.</runtime_hint>").mark_as_temp()
303+
> )
304+
> ```
305+
>
298306
> For long-term memory, knowledge bases, or external system queries that may be large or unnecessary for every round, do not put everything directly into the prompt. Prefer registering them as `llm_tool` functions so the model can call them when needed, or retrieve only a small relevant summary in your plugin and append that summary through `extra_user_content_parts`.
299307
300308
> You cannot use yield to send messages here. If you need to send, please use the `event.send()` method directly.

docs/zh/dev/star/guides/listen-message-event.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,14 @@ async def my_custom_hook_1(self, event: AstrMessageEvent, req: ProviderRequest):
314314
> )
315315
> ```
316316
>
317+
> 如果追加的内容只希望参与本轮 LLM 请求,不希望被持久化到会话历史中,可以调用 `.mark_as_temp()` 标记为临时内容(`>= v4.24.0`):
318+
>
319+
> ```python
320+
> req.extra_user_content_parts.append(
321+
> TextPart(text="<runtime_hint>这段提示只在本轮请求中生效。</runtime_hint>").mark_as_temp()
322+
> )
323+
> ```
324+
>
317325
> 对于长期记忆、知识库、外部系统查询等内容量较大或不一定每轮都需要的信息,不建议全部塞进提示词。可以优先注册为 `llm_tool`,让模型在需要时调用;也可以先在插件中检索出本轮真正相关的少量摘要,再放入 `extra_user_content_parts`
318326
319327
#### LLM 请求完成时

docs/zh/dev/star/plugin.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,14 @@ async def my_custom_hook_1(self, event: AstrMessageEvent, req: ProviderRequest):
548548
> )
549549
> ```
550550
>
551+
> 如果追加的内容只希望参与本轮 LLM 请求,不希望被持久化到会话历史中,可以调用 `.mark_as_temp()` 标记为临时内容:
552+
>
553+
> ```python
554+
> req.extra_user_content_parts.append(
555+
> TextPart(text="<runtime_hint>这段提示只在本轮请求中生效。</runtime_hint>").mark_as_temp()
556+
> )
557+
> ```
558+
>
551559
> 对于长期记忆、知识库、外部系统查询等内容量较大或不一定每轮都需要的信息,不建议全部塞进提示词。可以优先注册为 `llm_tool`,让模型在需要时调用;也可以先在插件中检索出本轮真正相关的少量摘要,再放入 `extra_user_content_parts`。
552560

553561
> 这里不能使用 yield 来发送消息。如需发送,请直接使用 `event.send()` 方法。

tests/test_conversation_checkpoint.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
CheckpointData,
55
CheckpointMessageSegment,
66
Message,
7+
TextPart,
78
bind_checkpoint_messages,
89
dump_messages_with_checkpoints,
910
get_checkpoint_id,
1011
strip_checkpoint_messages,
1112
)
13+
from astrbot.core.provider.entities import ProviderRequest
1214
from astrbot.core.provider.provider import Provider
1315
from astrbot.dashboard.routes.chat import ChatRoute
1416

@@ -81,6 +83,60 @@ def test_dump_checkpoint_messages_drops_checkpoint_when_message_is_dropped():
8183
]
8284

8385

86+
def test_dump_messages_filters_temp_content_parts():
87+
messages = [
88+
Message(
89+
role="user",
90+
content=[
91+
TextPart(text="persisted"),
92+
TextPart(text="temporary").mark_as_temp(),
93+
],
94+
),
95+
Message(role="assistant", content="ok"),
96+
]
97+
98+
assert dump_messages_with_checkpoints(messages) == [
99+
{"role": "user", "content": [{"type": "text", "text": "persisted"}]},
100+
{"role": "assistant", "content": "ok"},
101+
]
102+
103+
104+
def test_content_part_no_save_round_trip_from_dict():
105+
message = Message.model_validate(
106+
{
107+
"role": "user",
108+
"content": [
109+
{"type": "text", "text": "persisted"},
110+
{"type": "text", "text": "temporary", "_no_save": True},
111+
],
112+
}
113+
)
114+
115+
assert isinstance(message.content, list)
116+
assert message.content[0]._no_save is False
117+
assert message.content[1]._no_save is True
118+
assert dump_messages_with_checkpoints([message]) == [
119+
{"role": "user", "content": [{"type": "text", "text": "persisted"}]},
120+
]
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_provider_request_assemble_context_preserves_temp_content_part_marker():
125+
request = ProviderRequest(
126+
prompt="hello",
127+
extra_user_content_parts=[TextPart(text="temporary").mark_as_temp()],
128+
)
129+
130+
message = Message.model_validate(await request.assemble_context())
131+
132+
assert isinstance(message.content, list)
133+
assert message.content[1].text == "temporary"
134+
assert message.content[1]._no_save is True
135+
assert dump_messages_with_checkpoints([message]) == [
136+
{"role": "user", "content": [{"type": "text", "text": "hello"}]},
137+
]
138+
139+
84140
def test_provider_ensure_message_to_dicts_skips_checkpoints():
85141
messages = [
86142
Message(role="user", content="hello"),

0 commit comments

Comments
 (0)