-
Notifications
You must be signed in to change notification settings - Fork 114
fix: tolerate recoverable cloud API failures in APIBasedLLM #378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| import logging | ||||||
| import os | ||||||
| import time | ||||||
|
|
||||||
|
|
@@ -20,6 +21,14 @@ | |||||
| from models.base_llm import BaseLLM | ||||||
| from retry import retry | ||||||
|
|
||||||
| LOGGER = logging.getLogger(__name__) | ||||||
| RECOVERABLE_API_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504} | ||||||
| CONTENT_FILTER_ERROR_TYPES = { | ||||||
| "content_policy_violation_error", | ||||||
| "content_filter", | ||||||
| "prompt_rejected", | ||||||
| } | ||||||
|
|
||||||
| class APIBasedLLM(BaseLLM): | ||||||
| def __init__(self, **kwargs) -> None: | ||||||
| """ Initialize the APIBasedLLM class | ||||||
|
|
@@ -62,6 +71,37 @@ def _load(self, model): | |||||
|
|
||||||
| self.model = model | ||||||
|
|
||||||
| @staticmethod | ||||||
| def _extract_error_details(error): | ||||||
| """Extract provider error details in a sdk-agnostic format.""" | ||||||
| status_code = getattr(error, "status_code", None) | ||||||
| body = getattr(error, "body", None) | ||||||
| error_payload = body.get("error", {}) if isinstance(body, dict) else {} | ||||||
| error_type = error_payload.get("type") | ||||||
| message = error_payload.get("message") or str(error) | ||||||
| return status_code, error_type, message | ||||||
|
|
||||||
| @classmethod | ||||||
| def _is_recoverable_api_error(cls, error): | ||||||
| """Return whether the provider error should fail only this sample.""" | ||||||
| status_code, error_type, _ = cls._extract_error_details(error) | ||||||
|
|
||||||
| if status_code in RECOVERABLE_API_STATUS_CODES: | ||||||
| return True | ||||||
|
|
||||||
| return status_code == 400 and error_type in CONTENT_FILTER_ERROR_TYPES | ||||||
|
|
||||||
| def _build_error_response(self, error): | ||||||
| """Return an empty response so the benchmark can continue.""" | ||||||
| status_code, error_type, message = self._extract_error_details(error) | ||||||
| response = self._format_response("", 0, 0, 0, 0, 0) | ||||||
| response["error"] = { | ||||||
| "status_code": status_code, | ||||||
| "type": error_type or error.__class__.__name__, | ||||||
| "message": message, | ||||||
| } | ||||||
| return response | ||||||
|
|
||||||
| @retry(tries=3, delay=4, max_delay=10) | ||||||
| def _infer(self, messages): | ||||||
| """Call the OpenAI API to get the response | ||||||
|
|
@@ -139,6 +179,17 @@ def _infer(self, messages): | |||||
| throughput = 0 | ||||||
|
|
||||||
| except Exception as e: | ||||||
| if self._is_recoverable_api_error(e): | ||||||
| status_code, error_type, message = self._extract_error_details(e) | ||||||
| LOGGER.warning( | ||||||
| "Provider inference failed for one sample; returning empty response. " | ||||||
| "provider=%s status=%s type=%s message=%s", | ||||||
| self.provider, | ||||||
| status_code, | ||||||
| error_type, | ||||||
| message, | ||||||
| ) | ||||||
| return self._build_error_response(e) | ||||||
| raise RuntimeError(f"Error during API inference: {e}") | ||||||
|
||||||
| raise RuntimeError(f"Error during API inference: {e}") | |
| raise RuntimeError(f"Error during API inference: {e}") from e |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also handled in 118baf1: the fail-fast path now raises `RuntimeError(...) from e`, and the unit test covers the non-retry invalid-request case.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| import importlib.util | ||
| import sys | ||
| import types | ||
| import unittest | ||
| from pathlib import Path | ||
|
|
||
| MODELS_DIR = Path(__file__).resolve().parent | ||
|
|
||
|
|
||
| def load_module(name, path): | ||
| spec = importlib.util.spec_from_file_location(name, path) | ||
| module = importlib.util.module_from_spec(spec) | ||
| sys.modules[name] = module | ||
| spec.loader.exec_module(module) | ||
| return module | ||
|
|
||
|
|
||
| models_pkg = types.ModuleType("models") | ||
| models_pkg.__path__ = [str(MODELS_DIR)] | ||
| sys.modules.setdefault("models", models_pkg) | ||
|
|
||
| load_module("models.base_llm", MODELS_DIR / "base_llm.py") | ||
| api_llm = load_module("models.api_llm", MODELS_DIR / "api_llm.py") | ||
| APIBasedLLM = api_llm.APIBasedLLM | ||
|
|
||
|
|
||
| def build_api_error(status_code, error_type, message): | ||
| class FakeAPIError(Exception): | ||
| __module__ = "openai._exceptions" | ||
|
|
||
| def __init__(self): | ||
| super().__init__(message) | ||
| self.status_code = status_code | ||
| self.body = { | ||
| "error": { | ||
| "type": error_type, | ||
| "message": message, | ||
| } | ||
| } | ||
|
|
||
| return FakeAPIError() | ||
|
|
||
|
|
||
| class FakeCompletions: | ||
| def __init__(self, error): | ||
| self.error = error | ||
|
|
||
| def create(self, **kwargs): | ||
| raise self.error | ||
|
|
||
|
|
||
| class FakeChat: | ||
| def __init__(self, error): | ||
| self.completions = FakeCompletions(error) | ||
|
|
||
|
|
||
| class FakeClient: | ||
| def __init__(self, error): | ||
| self.chat = FakeChat(error) | ||
|
|
||
|
|
||
| class APIBasedLLMTests(unittest.TestCase): | ||
| @staticmethod | ||
| def build_model(error): | ||
| model = object.__new__(APIBasedLLM) | ||
| model.provider = "openai" | ||
| model.client = FakeClient(error) | ||
| model.model = "gpt-4o-mini" | ||
| model.model_name = "gpt-4o-mini" | ||
| model.temperature = 0.8 | ||
| model.max_tokens = 64 | ||
| model.top_p = 0.8 | ||
| model.repetition_penalty = 1.05 | ||
| model.use_cache = False | ||
| model.model_loaded = True | ||
| model.config = {} | ||
| model.is_cache_loaded = True | ||
| return model | ||
|
|
||
| def test_content_policy_failures_return_empty_prediction(self): | ||
| model = self.build_model( | ||
| build_api_error( | ||
| 400, | ||
| "content_policy_violation_error", | ||
| "Content validation failed.", | ||
| ) | ||
| ) | ||
|
|
||
| response = model.inference({"query": "unsafe prompt"}) | ||
|
|
||
| self.assertEqual("", response["completion"]) | ||
| self.assertIsNone(response["prediction"]) | ||
| self.assertEqual(400, response["error"]["status_code"]) | ||
| self.assertEqual( | ||
| "content_policy_violation_error", | ||
| response["error"]["type"], | ||
| ) | ||
|
|
||
| def test_invalid_requests_still_fail_fast(self): | ||
| model = self.build_model( | ||
| build_api_error( | ||
| 400, | ||
| "invalid_request_error", | ||
| "The configured model does not exist.", | ||
| ) | ||
| ) | ||
|
|
||
| with self.assertRaises(RuntimeError): | ||
| model.inference({"query": "hello"}) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation inverts the expected retry logic for transient vs. permanent errors.
Transient failures (408/429/5xx) are listed in `RECOVERABLE_API_STATUS_CODES`, but because they are caught inside the `_infer` method (which is decorated with `@retry`), the decorator sees a successful return value (the error response) and will not perform any retries. This reduces the robustness of the benchmark against temporary network or provider issues. Conversely, permanent errors are re-raised as `RuntimeError`, which the `@retry` decorator will catch and retry 3 times — unnecessary latency for errors that will never succeed on retry. Consider refactoring to ensure transient errors are retried before being tolerated, and that permanent errors fail fast. A common pattern is to use a wrapper method for the toleration logic while keeping the retry logic on the actual API call.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated in 118baf1: transient 408/429/5xx failures are now retried before being downgraded to an empty per-sample response, while content-filter 400s are still tolerated immediately and invalid requests fail fast.