Skip to content

Commit 324b61b

Browse files
committed
refactor: extract reusable execute_with_fallback on _KeyRotator
Address review feedback: move the per-key failover loop into a reusable _KeyRotator.execute_with_fallback so other web-search providers can adopt the same behaviour, and simplify the error handling.
1 parent 2be5782 commit 324b61b

1 file changed

Lines changed: 52 additions & 30 deletions

File tree

astrbot/core/tools/web_search_tools.py

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uuid
44
from dataclasses import dataclass as std_dataclass
55
from dataclasses import field
6+
from typing import Awaitable, Callable, TypeVar
67

78
import aiohttp
89
from pydantic import Field
@@ -58,6 +59,9 @@ class SearchResult:
5859
favicon: str | None = None
5960

6061

62+
_T = TypeVar("_T")
63+
64+
6165
@std_dataclass
6266
class _KeyRotator:
6367
setting_name: str
@@ -96,6 +100,30 @@ async def ordered_keys(self, provider_settings: dict) -> list[str]:
96100
self.index = (self.index + 1) % len(keys)
97101
return keys[start:] + keys[:start]
98102

103+
async def execute_with_fallback(
104+
self,
105+
provider_settings: dict,
106+
fn: Callable[[str], Awaitable[_T]],
107+
) -> _T:
108+
"""Run ``fn(key)`` against each key in rotation order, returning the
109+
first success and only raising once every key has failed.
110+
111+
Other web-search providers can reuse this to gain the same key
112+
failover behaviour.
113+
"""
114+
last_error: Exception | None = None
115+
for key in await self.ordered_keys(provider_settings):
116+
try:
117+
return await fn(key)
118+
except Exception as e:
119+
last_error = e
120+
logger.warning(
121+
f"{self.provider_name} key failed, trying the next one: {e}"
122+
)
123+
raise last_error or RuntimeError(
124+
f"{self.provider_name} web search failed."
125+
)
126+
99127

100128
_TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily")
101129
_BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha")
@@ -170,40 +198,34 @@ async def _tavily_search(
170198
provider_settings: dict,
171199
payload: dict,
172200
) -> list[SearchResult]:
173-
last_error: Exception = Exception(
174-
"Error: Tavily API key is not configured in AstrBot."
175-
)
176-
for tavily_key in await _TAVILY_KEY_ROTATOR.ordered_keys(provider_settings):
201+
async def _search(tavily_key: str) -> list[SearchResult]:
177202
header = {
178203
"Authorization": f"Bearer {tavily_key}",
179204
"Content-Type": "application/json",
180205
}
181-
try:
182-
async with aiohttp.ClientSession(trust_env=True) as session:
183-
async with session.post(
184-
"https://api.tavily.com/search",
185-
json=payload,
186-
headers=header,
187-
) as response:
188-
if response.status != 200:
189-
reason = await response.text()
190-
raise Exception(
191-
f"Tavily web search failed: {reason}, status: {response.status}",
192-
)
193-
data = await response.json()
194-
return [
195-
SearchResult(
196-
title=item.get("title"),
197-
url=item.get("url"),
198-
snippet=item.get("content"),
199-
favicon=item.get("favicon"),
200-
)
201-
for item in data.get("results", [])
202-
]
203-
except Exception as e:
204-
last_error = e
205-
logger.warning(f"Tavily key failed, trying the next one: {e}")
206-
raise last_error
206+
async with aiohttp.ClientSession(trust_env=True) as session:
207+
async with session.post(
208+
"https://api.tavily.com/search",
209+
json=payload,
210+
headers=header,
211+
) as response:
212+
if response.status != 200:
213+
reason = await response.text()
214+
raise Exception(
215+
f"Tavily web search failed: {reason}, status: {response.status}",
216+
)
217+
data = await response.json()
218+
return [
219+
SearchResult(
220+
title=item.get("title"),
221+
url=item.get("url"),
222+
snippet=item.get("content"),
223+
favicon=item.get("favicon"),
224+
)
225+
for item in data.get("results", [])
226+
]
227+
228+
return await _TAVILY_KEY_ROTATOR.execute_with_fallback(provider_settings, _search)
207229

208230

209231
async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]:

0 commit comments

Comments
 (0)