Skip to content

Commit 3ee5da6

Browse files
committed
fix: apply fallback chat models to background wakeups
1 parent 6067a70 commit 3ee5da6

2 files changed

Lines changed: 80 additions & 3 deletions

File tree

astrbot/core/astr_agent_tool_exec.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,11 +543,12 @@ async def _wake_main_agent_for_background_result(
543543
message_type=session.message_type,
544544
)
545545
cron_event.role = event.role
546+
cfg = ctx.get_config(umo=event.unified_msg_origin) or {}
547+
provider_settings = cfg.get("provider_settings") or {}
546548
config = MainAgentBuildConfig(
547549
tool_call_timeout=run_context.tool_call_timeout,
548-
streaming_response=ctx.get_config()
549-
.get("provider_settings", {})
550-
.get("stream", False),
550+
streaming_response=provider_settings.get("stream", False),
551+
provider_settings=provider_settings,
551552
)
552553

553554
req = ProviderRequest()

tests/unit/test_astr_agent_tool_exec.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from types import SimpleNamespace
2+
from unittest.mock import AsyncMock
23

34
import mcp
45
import pytest
@@ -19,6 +20,7 @@ class _DummyEvent:
1920
def __init__(self, message_components: list[object] | None = None) -> None:
2021
self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session"
2122
self.message_obj = SimpleNamespace(message=message_components or [])
23+
self.role = "member"
2224

2325
def get_extra(self, _key: str):
2426
return None
@@ -36,6 +38,15 @@ def _build_run_context(message_components: list[object] | None = None):
3638
return ContextWrapper(context=ctx)
3739

3840

41+
class _DoneRunner:
42+
async def step_until_done(self, _max_step):
43+
for item in ():
44+
yield item
45+
46+
def get_final_llm_resp(self):
47+
return SimpleNamespace(role="assistant", completion_text="done")
48+
49+
3950
def test_build_handoff_toolset_keeps_permission_guards_for_default_tools():
4051
mgr = FunctionToolManager()
4152
plugin_tool = FunctionTool(
@@ -354,6 +365,71 @@ async def _fake_tool_loop_agent(**kwargs):
354365
assert captured["tool_call_timeout"] == 120
355366

356367

368+
@pytest.mark.asyncio
369+
async def test_background_wakeup_passes_provider_settings_to_main_agent(
370+
monkeypatch: pytest.MonkeyPatch,
371+
):
372+
provider_settings = {
373+
"fallback_chat_models": ["fallback-provider"],
374+
"request_max_retries": 3,
375+
"stream": True,
376+
}
377+
captured: dict = {}
378+
379+
async def _fake_get_session_conv(**_kwargs):
380+
return SimpleNamespace(history="[]")
381+
382+
async def _fake_build_main_agent(**kwargs):
383+
captured.update(kwargs)
384+
return SimpleNamespace(agent_runner=_DoneRunner())
385+
386+
monkeypatch.setattr(
387+
"astrbot.core.astr_main_agent._get_session_conv",
388+
_fake_get_session_conv,
389+
)
390+
monkeypatch.setattr(
391+
"astrbot.core.astr_main_agent.build_main_agent",
392+
_fake_build_main_agent,
393+
)
394+
monkeypatch.setattr(
395+
"astrbot.core.astr_agent_tool_exec.persist_agent_history",
396+
AsyncMock(),
397+
)
398+
399+
send_tool = FunctionTool(
400+
name="send_message_to_user",
401+
description="send",
402+
parameters={"type": "object", "properties": {}},
403+
)
404+
context = SimpleNamespace(
405+
get_config=lambda **_kwargs: {"provider_settings": provider_settings},
406+
get_llm_tool_manager=lambda: SimpleNamespace(
407+
get_builtin_tool=lambda _tool_cls: send_tool
408+
),
409+
conversation_manager=SimpleNamespace(),
410+
)
411+
run_context = ContextWrapper(
412+
context=SimpleNamespace(event=_DummyEvent([]), context=context),
413+
tool_call_timeout=456,
414+
)
415+
416+
await FunctionToolExecutor._wake_main_agent_for_background_result(
417+
run_context,
418+
task_id="task-id",
419+
tool_name="long_tool",
420+
result_text="ok",
421+
tool_args={},
422+
note="task finished",
423+
summary_name="BackgroundTask",
424+
)
425+
426+
config = captured["config"]
427+
assert config.tool_call_timeout == 456
428+
assert config.streaming_response == provider_settings["stream"]
429+
assert config.provider_settings == provider_settings
430+
assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"]
431+
432+
357433
@pytest.mark.asyncio
358434
async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root(
359435
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)