Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions astrbot/core/tools/web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
from dataclasses import dataclass as std_dataclass
from dataclasses import field
from typing import Awaitable, Callable, TypeVar

import aiohttp
from pydantic import Field
Expand Down Expand Up @@ -58,6 +59,9 @@ class SearchResult:
favicon: str | None = None


_T = TypeVar("_T")


@std_dataclass
class _KeyRotator:
setting_name: str
Expand All @@ -79,6 +83,47 @@ async def get(self, provider_settings: dict) -> str:
self.index = (self.index + 1) % len(keys)
return key

async def ordered_keys(self, provider_settings: dict) -> list[str]:
"""All configured keys, ordered from the current rotation position.

Lets a caller fall through to the next key when one fails (invalid,
out of quota, rate limited) instead of giving up on the first error,
while keeping the round-robin starting point consistent across calls.
"""
keys = provider_settings.get(self.setting_name, [])
if not keys:
raise ValueError(
f"Error: {self.provider_name} API key is not configured in AstrBot."
)
async with self.lock:
start = self.index % len(keys)
self.index = (self.index + 1) % len(keys)
return keys[start:] + keys[:start]
Comment on lines +86 to +101

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To avoid code duplication and make the fallback logic reusable across all search providers (such as BoCha, Brave, Firecrawl, Exa, etc.), we can encapsulate the fallback execution logic directly inside _KeyRotator as a helper method. This keeps the provider-specific search functions clean and focused only on their request/response handling.

    async def ordered_keys(self, provider_settings: dict) -> list[str]:
        """All configured keys, ordered from the current rotation position.

        Lets a caller fall through to the next key when one fails (invalid,
        out of quota, rate limited) instead of giving up on the first error,
        while keeping the round-robin starting point consistent across calls.
        """
        keys = provider_settings.get(self.setting_name, [])
        if not keys:
            raise ValueError(
                f"Error: {self.provider_name} API key is not configured in AstrBot."
            )
        async with self.lock:
            start = self.index % len(keys)
            self.index = (self.index + 1) % len(keys)
        return keys[start:] + keys[:start]

    async def execute_with_fallback(self, provider_settings: dict, func):
        """Execute a function with fallback to the next keys if it fails."""
        last_error = Exception(
            f"Error: {self.provider_name} API key is not configured in AstrBot."
        )
        for key in await self.ordered_keys(provider_settings):
            try:
                return await func(key)
            except Exception as e:
                last_error = e
                logger.warning(f"{self.provider_name} key failed, trying the next one: {e}")
        raise last_error
References
  1. When implementing similar functionality for different cases, refactor the logic into a shared helper function to avoid code duplication.


async def execute_with_fallback(
self,
provider_settings: dict,
fn: Callable[[str], Awaitable[_T]],
) -> _T:
"""Run ``fn(key)`` against each key in rotation order, returning the
first success and only raising once every key has failed.

Other web-search providers can reuse this to gain the same key
failover behaviour.
"""
last_error: Exception | None = None
for key in await self.ordered_keys(provider_settings):
try:
return await fn(key)
except Exception as e:
last_error = e
logger.warning(
f"{self.provider_name} key failed, trying the next one: {e}"
)
raise last_error or RuntimeError(
f"{self.provider_name} web search failed."
)


_TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
_BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
Expand Down Expand Up @@ -153,32 +198,34 @@ async def _tavily_search(
provider_settings: dict,
payload: dict,
) -> list[SearchResult]:
tavily_key = await _TAVILY_KEY_ROTATOR.get(provider_settings)
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
"https://api.tavily.com/search",
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
return [
SearchResult(
title=item.get("title"),
url=item.get("url"),
snippet=item.get("content"),
favicon=item.get("favicon"),
)
for item in data.get("results", [])
]
async def _search(tavily_key: str) -> list[SearchResult]:
header = {
"Authorization": f"Bearer {tavily_key}",
"Content-Type": "application/json",
}
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
"https://api.tavily.com/search",
json=payload,
headers=header,
) as response:
if response.status != 200:
reason = await response.text()
raise Exception(
f"Tavily web search failed: {reason}, status: {response.status}",
)
data = await response.json()
return [
SearchResult(
title=item.get("title"),
url=item.get("url"),
snippet=item.get("content"),
favicon=item.get("favicon"),
)
for item in data.get("results", [])
]

return await _TAVILY_KEY_ROTATOR.execute_with_fallback(provider_settings, _search)


async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]:
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_web_search_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,45 @@ def fake_client_session(*, trust_env):
{"websearch_exa_key": ["exa-key"]},
{"ids": ["https://example.com"]},
)


@pytest.mark.asyncio
async def test_tavily_search_falls_back_to_next_key_on_failure(monkeypatch):
responses = [
_FakeFirecrawlResponse(status=401, text_data="Unauthorized"),
_FakeFirecrawlResponse(
status=200,
json_data={
"results": [
{
"title": "AstrBot",
"url": "https://example.com",
"content": "AI Agent Assistant",
}
]
},
),
]
sessions = []

def fake_client_session(*, trust_env):
session = _FakeFirecrawlSession(responses.pop(0))
session.trust_env = trust_env
sessions.append(session)
return session

monkeypatch.setattr(tools.aiohttp, "ClientSession", fake_client_session)
tools._TAVILY_KEY_ROTATOR.index = 0

results = await tools._tavily_search(
{"websearch_tavily_key": ["bad-key", "good-key"]},
{"query": "AstrBot"},
)

assert sessions[0].posted["headers"]["Authorization"] == "Bearer bad-key"
assert sessions[1].posted["headers"]["Authorization"] == "Bearer good-key"
assert results == [
tools.SearchResult(
title="AstrBot", url="https://example.com", snippet="AI Agent Assistant"
)
]