|
3 | 3 | import uuid |
4 | 4 | from dataclasses import dataclass as std_dataclass |
5 | 5 | from dataclasses import field |
| 6 | +from typing import Awaitable, Callable, TypeVar |
6 | 7 |
|
7 | 8 | import aiohttp |
8 | 9 | from pydantic import Field |
@@ -58,6 +59,9 @@ class SearchResult: |
58 | 59 | favicon: str | None = None |
59 | 60 |
|
60 | 61 |
|
| 62 | +_T = TypeVar("_T") |
| 63 | + |
| 64 | + |
61 | 65 | @std_dataclass |
62 | 66 | class _KeyRotator: |
63 | 67 | setting_name: str |
@@ -96,6 +100,30 @@ async def ordered_keys(self, provider_settings: dict) -> list[str]: |
96 | 100 | self.index = (self.index + 1) % len(keys) |
97 | 101 | return keys[start:] + keys[:start] |
98 | 102 |
|
| 103 | + async def execute_with_fallback( |
| 104 | + self, |
| 105 | + provider_settings: dict, |
| 106 | + fn: Callable[[str], Awaitable[_T]], |
| 107 | + ) -> _T: |
| 108 | + """Run ``fn(key)`` against each key in rotation order, returning the |
| 109 | + first success and only raising once every key has failed. |
| 110 | +
|
| 111 | + Other web-search providers can reuse this to gain the same key |
| 112 | + failover behaviour. |
| 113 | + """ |
| 114 | + last_error: Exception | None = None |
| 115 | + for key in await self.ordered_keys(provider_settings): |
| 116 | + try: |
| 117 | + return await fn(key) |
| 118 | + except Exception as e: |
| 119 | + last_error = e |
| 120 | + logger.warning( |
| 121 | + f"{self.provider_name} key failed, trying the next one: {e}" |
| 122 | + ) |
| 123 | + raise last_error or RuntimeError( |
| 124 | + f"{self.provider_name} web search failed." |
| 125 | + ) |
| 126 | + |
99 | 127 |
|
100 | 128 | _TAVILY_KEY_ROTATOR = _KeyRotator("websearch_tavily_key", "Tavily") |
101 | 129 | _BOCHA_KEY_ROTATOR = _KeyRotator("websearch_bocha_key", "BoCha") |
@@ -170,40 +198,34 @@ async def _tavily_search( |
170 | 198 | provider_settings: dict, |
171 | 199 | payload: dict, |
172 | 200 | ) -> list[SearchResult]: |
173 | | - last_error: Exception = Exception( |
174 | | - "Error: Tavily API key is not configured in AstrBot." |
175 | | - ) |
176 | | - for tavily_key in await _TAVILY_KEY_ROTATOR.ordered_keys(provider_settings): |
| 201 | + async def _search(tavily_key: str) -> list[SearchResult]: |
177 | 202 | header = { |
178 | 203 | "Authorization": f"Bearer {tavily_key}", |
179 | 204 | "Content-Type": "application/json", |
180 | 205 | } |
181 | | - try: |
182 | | - async with aiohttp.ClientSession(trust_env=True) as session: |
183 | | - async with session.post( |
184 | | - "https://api.tavily.com/search", |
185 | | - json=payload, |
186 | | - headers=header, |
187 | | - ) as response: |
188 | | - if response.status != 200: |
189 | | - reason = await response.text() |
190 | | - raise Exception( |
191 | | - f"Tavily web search failed: {reason}, status: {response.status}", |
192 | | - ) |
193 | | - data = await response.json() |
194 | | - return [ |
195 | | - SearchResult( |
196 | | - title=item.get("title"), |
197 | | - url=item.get("url"), |
198 | | - snippet=item.get("content"), |
199 | | - favicon=item.get("favicon"), |
200 | | - ) |
201 | | - for item in data.get("results", []) |
202 | | - ] |
203 | | - except Exception as e: |
204 | | - last_error = e |
205 | | - logger.warning(f"Tavily key failed, trying the next one: {e}") |
206 | | - raise last_error |
| 206 | + async with aiohttp.ClientSession(trust_env=True) as session: |
| 207 | + async with session.post( |
| 208 | + "https://api.tavily.com/search", |
| 209 | + json=payload, |
| 210 | + headers=header, |
| 211 | + ) as response: |
| 212 | + if response.status != 200: |
| 213 | + reason = await response.text() |
| 214 | + raise Exception( |
| 215 | + f"Tavily web search failed: {reason}, status: {response.status}", |
| 216 | + ) |
| 217 | + data = await response.json() |
| 218 | + return [ |
| 219 | + SearchResult( |
| 220 | + title=item.get("title"), |
| 221 | + url=item.get("url"), |
| 222 | + snippet=item.get("content"), |
| 223 | + favicon=item.get("favicon"), |
| 224 | + ) |
| 225 | + for item in data.get("results", []) |
| 226 | + ] |
| 227 | + |
| 228 | + return await _TAVILY_KEY_ROTATOR.execute_with_fallback(provider_settings, _search) |
207 | 229 |
|
208 | 230 |
|
209 | 231 | async def _tavily_extract(provider_settings: dict, payload: dict) -> list[dict]: |
|
0 commit comments