Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time

Expand All @@ -20,6 +21,14 @@
from models.base_llm import BaseLLM
from retry import retry

LOGGER = logging.getLogger(__name__)
RECOVERABLE_API_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504}
CONTENT_FILTER_ERROR_TYPES = {
"content_policy_violation_error",
"content_filter",
"prompt_rejected",
}

class APIBasedLLM(BaseLLM):
def __init__(self, **kwargs) -> None:
""" Initialize the APIBasedLLM class
Expand Down Expand Up @@ -62,6 +71,37 @@ def _load(self, model):

self.model = model

@staticmethod
def _extract_error_details(error):
"""Extract provider error details in a sdk-agnostic format."""
status_code = getattr(error, "status_code", None)
body = getattr(error, "body", None)
error_payload = body.get("error", {}) if isinstance(body, dict) else {}
error_type = error_payload.get("type")
message = error_payload.get("message") or str(error)
return status_code, error_type, message

@classmethod
def _is_recoverable_api_error(cls, error):
    """Decide whether a provider failure should only skip this sample.

    Transient HTTP statuses (timeouts, rate limits, 5xx) are recoverable,
    as are 400s whose error type indicates the prompt merely tripped a
    content filter; any other failure is treated as non-recoverable.
    """
    status_code, error_type, _ = cls._extract_error_details(error)

    if status_code == 400:
        # 400 is only tolerated when it is a content-filter rejection.
        return error_type in CONTENT_FILTER_ERROR_TYPES

    return status_code in RECOVERABLE_API_STATUS_CODES

def _build_error_response(self, error):
    """Build an empty completion annotated with provider error metadata.

    Returning this instead of raising lets the benchmark keep going while
    still recording why a single sample produced no output.
    """
    status_code, error_type, message = self._extract_error_details(error)
    empty = self._format_response("", 0, 0, 0, 0, 0)
    empty["error"] = {
        "status_code": status_code,
        # Fall back to the exception class name when the SDK gave no type.
        "type": error_type if error_type else type(error).__name__,
        "message": message,
    }
    return empty

@retry(tries=3, delay=4, max_delay=10)
def _infer(self, messages):
"""Call the OpenAI API to get the response
Expand Down Expand Up @@ -139,6 +179,17 @@ def _infer(self, messages):
throughput = 0

except Exception as e:
if self._is_recoverable_api_error(e):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation inverts the expected retry logic for transient vs. permanent errors.

  1. Transient Errors (e.g., 503, 429): These are included in RECOVERABLE_API_STATUS_CODES. Because they are caught inside the _infer method (which is decorated with @retry), the decorator will see a successful return value (the error response) and will not perform any retries. This reduces the robustness of the benchmark against temporary network or provider issues.
  2. Permanent Errors (e.g., 401, 404): These are not in the recoverable set, so they fall through to line 193 and raise a RuntimeError. The @retry decorator will catch this and retry the request 3 times, which is unnecessary and adds latency for errors that will never succeed on retry.

Consider refactoring to ensure transient errors are retried before being tolerated, and that permanent errors fail fast. A common pattern is to use a wrapper method for the toleration logic while keeping the retry logic on the actual API call.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 118baf1: transient 408/429/5xx failures are now retried before being downgraded to an empty per-sample response, while content-filter 400s are still tolerated immediately and invalid requests fail fast.

status_code, error_type, message = self._extract_error_details(e)
LOGGER.warning(
"Provider inference failed for one sample; returning empty response. "
"provider=%s status=%s type=%s message=%s",
self.provider,
status_code,
error_type,
message,
)
return self._build_error_response(e)
raise RuntimeError(f"Error during API inference: {e}")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When re-raising the exception as a RuntimeError, it is best practice to use from e to preserve the original exception's traceback. This is crucial for debugging the root cause of the failure.

Suggested change
raise RuntimeError(f"Error during API inference: {e}")
raise RuntimeError(f"Error during API inference: {e}") from e

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also handled in 118baf1: the fail-fast path now raises RuntimeError(... ) from e, and the unit test covers the non-retry invalid-request case.


response = self._format_response(
Expand All @@ -150,4 +201,4 @@ def _infer(self, messages):
throughput
)

return response
return response
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import importlib.util
import sys
import types
import unittest
from pathlib import Path

# Directory containing the model implementations exercised by these tests.
MODELS_DIR = Path(__file__).resolve().parent


def load_module(name, path):
    """Import the file at *path* as module *name* and register it.

    The module is inserted into ``sys.modules`` before execution so that
    imports performed while the module body runs can resolve it.
    """
    module_spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(module_spec)
    sys.modules[name] = loaded
    module_spec.loader.exec_module(loaded)
    return loaded


# Create a synthetic "models" package so the modules under test can be
# imported by their package-qualified names without installing the project.
models_pkg = types.ModuleType("models")
models_pkg.__path__ = [str(MODELS_DIR)]
sys.modules.setdefault("models", models_pkg)

# base_llm is loaded first because api_llm imports from models.base_llm.
load_module("models.base_llm", MODELS_DIR / "base_llm.py")
api_llm = load_module("models.api_llm", MODELS_DIR / "api_llm.py")
APIBasedLLM = api_llm.APIBasedLLM


def build_api_error(status_code, error_type, message):
    """Create an exception that mimics an OpenAI SDK API error.

    The returned instance carries the ``status_code`` and ``body``
    attributes that ``APIBasedLLM._extract_error_details`` reads.
    """

    class FakeAPIError(Exception):
        # Pretend the class originates from the OpenAI SDK.
        __module__ = "openai._exceptions"

        def __init__(self):
            super().__init__(message)
            self.status_code = status_code
            self.body = {"error": {"type": error_type, "message": message}}

    return FakeAPIError()


class FakeCompletions:
    """Stand-in for ``client.chat.completions`` whose every call fails."""

    def __init__(self, error):
        self._error = error

    def create(self, **kwargs):
        """Raise the configured error regardless of the arguments."""
        raise self._error


class FakeChat:
    """Mimics ``client.chat`` by exposing a failing ``completions`` attribute."""

    def __init__(self, error):
        # Delegate the failure behavior to the completions stub.
        self.completions = FakeCompletions(error)


class FakeClient:
    """Minimal OpenAI-client stand-in whose chat completions always fail."""

    def __init__(self, error):
        # Mirror the SDK's client.chat.completions.create access path.
        self.chat = FakeChat(error)


class APIBasedLLMTests(unittest.TestCase):
    """Exercise APIBasedLLM error tolerance without a real provider."""

    @staticmethod
    def build_model(error):
        """Construct an APIBasedLLM wired to a client that raises *error*.

        ``object.__new__`` bypasses ``__init__`` so no real API client or
        configuration is needed; the attributes below are the state the
        inference path reads.
        """
        model = object.__new__(APIBasedLLM)
        state = {
            "provider": "openai",
            "client": FakeClient(error),
            "model": "gpt-4o-mini",
            "model_name": "gpt-4o-mini",
            "temperature": 0.8,
            "max_tokens": 64,
            "top_p": 0.8,
            "repetition_penalty": 1.05,
            "use_cache": False,
            "model_loaded": True,
            "config": {},
            "is_cache_loaded": True,
        }
        for attr, value in state.items():
            setattr(model, attr, value)
        return model

    def test_content_policy_failures_return_empty_prediction(self):
        """A content-filter 400 is tolerated and yields an empty response."""
        error = build_api_error(
            400,
            "content_policy_violation_error",
            "Content validation failed.",
        )

        response = self.build_model(error).inference({"query": "unsafe prompt"})

        self.assertEqual("", response["completion"])
        self.assertIsNone(response["prediction"])
        self.assertEqual(400, response["error"]["status_code"])
        self.assertEqual(
            "content_policy_violation_error",
            response["error"]["type"],
        )

    def test_invalid_requests_still_fail_fast(self):
        """A non-recoverable 400 (invalid request) raises immediately."""
        error = build_api_error(
            400,
            "invalid_request_error",
            "The configured model does not exist.",
        )

        with self.assertRaises(RuntimeError):
            self.build_model(error).inference({"query": "hello"})


# Allow running this test file directly: `python <this_file>.py`.
if __name__ == "__main__":
    unittest.main()