diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7fb847dccd..9b65c08ce6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -104,6 +104,7 @@ "default_provider_id": "", "fallback_chat_models": [], "request_max_retries": 5, + "provider_error_retries": 1, "default_image_caption_provider_id": "", "image_caption_prompt": "Please describe the image using Chinese.", "provider_pool": ["*"], # "*" 表示使用所有可用的提供者 @@ -2803,6 +2804,9 @@ "request_max_retries": { "type": "int", }, + "provider_error_retries": { + "type": "int", + }, "wake_prefix": { "type": "string", }, diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index a49003af17..666cde2b73 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -398,6 +398,25 @@ def __init__(self, provider_config, provider_settings) -> None: self.reasoning_key = "reasoning_content" + def _provider_error_retries(self) -> int: + """Return retry attempts for provider-level recovery paths. + + This outer retry loop handles payload mutation and key rotation (for + example context trimming, tool removal, image fallback, or switching to + another configured API key). Transport/status-code retries are handled + separately by ``request_max_retries`` in ``retry_provider_request``. + Keeping this value configurable prevents nested retry loops from + multiplying latency for proxy/aggregator providers that already perform + their own upstream retry and fallback. + """ + provider_settings = getattr(self, "provider_settings", {}) or {} + raw = provider_settings.get("provider_error_retries", 1) + try: + retries = int(raw) + except (TypeError, ValueError): + retries = 1 + return max(1, retries) + def _ollama_disable_thinking_enabled(self) -> bool: value = self.provider_config.get("ollama_disable_thinking", False) if isinstance(value, str): @@ -1188,7 +1207,7 @@ async def text_chat( payloads["tool_choice"] = tool_choice llm_response = None - max_retries = 10 + max_retries = self._provider_error_retries() available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) image_fallback_used = False @@ -1228,7 +1247,7 @@ async def text_chat( if success: break - if retry_cnt == max_retries - 1 or llm_response is None: + if llm_response is None: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") if last_exception is None: raise Exception("未知错误") @@ -1264,13 +1283,14 @@ async def text_chat_stream( if func_tool and not func_tool.empty(): payloads["tool_choice"] = tool_choice - max_retries = 10 + max_retries = self._provider_error_retries() available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) image_fallback_used = False last_exception = None retry_cnt = 0 + completed = False for retry_cnt in range(max_retries): try: self.client.api_key = chosen_key @@ -1280,6 +1300,7 @@ async def text_chat_stream( request_max_retries=request_max_retries, ): yield response + completed = True break except Exception as e: last_exception = e @@ -1305,7 +1326,7 @@ async def text_chat_stream( if success: break - if retry_cnt == max_retries - 1: + if not completed: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") if last_exception is None: raise Exception("未知错误") diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index b8262090e4..73332b1070 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -120,6 +120,24 @@ def fake_import(name, globals=None, locals=None, fromlist=(), level=0): assert captured["httpx_module"] is openai_source_module.httpx +def test_provider_error_retries_defaults_and_coerces_values(): + provider = ProviderOpenAIOfficial.__new__(ProviderOpenAIOfficial) + + assert provider._provider_error_retries() == 1 + + provider.provider_settings = {} + assert provider._provider_error_retries() == 1 + + provider.provider_settings = {"provider_error_retries": "3"} + assert provider._provider_error_retries() == 3 + + provider.provider_settings = {"provider_error_retries": 0} + assert provider._provider_error_retries() == 1 + + provider.provider_settings = {"provider_error_retries": "invalid"} + assert provider._provider_error_retries() == 1 + + @pytest.mark.asyncio async def test_get_models_retries_transient_request_error(monkeypatch): monkeypatch.setattr(request_retry, "REQUEST_RETRY_WAIT_MIN_S", 0)