Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time

from openai import OpenAI
from groq import Groq
from models.base_llm import BaseLLM
from retry import retry

LOGGER = logging.getLogger(__name__)
RETRYABLE_API_STATUS_CODES = {408, 409, 429, 500, 502, 503, 504}
CONTENT_FILTER_ERROR_TYPES = {
"content_policy_violation_error",
"content_filter",
"prompt_rejected",
}
MAX_API_RETRY_ATTEMPTS = 3
API_RETRY_DELAY_SECONDS = 4


class APIBasedLLM(BaseLLM):
def __init__(self, **kwargs) -> None:
Expand Down Expand Up @@ -62,9 +73,71 @@ def _load(self, model):

self.model = model

@retry(tries=3, delay=4, max_delay=10)
def _infer(self, messages):
"""Call the OpenAI API to get the response
@staticmethod
def _extract_error_details(error):
"""Extract provider error details in a sdk-agnostic format."""
status_code = getattr(error, "status_code", None)
body = getattr(error, "body", None)
error_payload = body.get("error", {}) if isinstance(body, dict) else {}
error_type = error_payload.get("type")
message = error_payload.get("message") or str(error)
return status_code, error_type, message

@classmethod
def _is_retryable_api_error(cls, error):
    """Return True when the provider error carries a transient HTTP status."""
    status, _unused_type, _unused_message = cls._extract_error_details(error)
    return status in RETRYABLE_API_STATUS_CODES

@classmethod
def _is_tolerated_api_error(cls, error):
    """Return True when the error should fail only the current sample.

    Tolerated errors are the retryable ones plus HTTP 400 responses whose
    type marks a content-filter rejection.
    """
    status, kind, _ = cls._extract_error_details(error)
    if cls._is_retryable_api_error(error):
        return True
    return status == 400 and kind in CONTENT_FILTER_ERROR_TYPES

def _build_error_response(self, error):
    """Return a zeroed response carrying the error details so the benchmark can continue."""
    status, kind, text = self._extract_error_details(error)
    # Empty completion, all metrics zeroed.
    response = self._format_response("", 0, 0, 0, 0, 0)
    # Fall back to the exception class name when the payload has no type.
    response["error"] = {
        "status_code": status,
        "type": kind if kind else error.__class__.__name__,
        "message": text,
    }
    return response

def _create_stream(self, messages):
"""Create a streaming completion request for the configured provider."""
if self.provider == "openai":
return self.client.chat.completions.create(
messages=messages,
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.repetition_penalty,
stream=True,
stream_options={"include_usage": True}
)
if self.provider == "groq":
return self.client.chat.completions.create(
messages=messages,
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
frequency_penalty=self.repetition_penalty,
stream=True
)

raise ValueError(f"Unsupported provider: {self.provider}")

def _infer_once(self, messages):
    """Call the provider API once, stream the completion, and format the response.

    Parameters
    ----------
    messages : list
        Chat messages in ``{"role": ..., "content": ...}`` form; the first
        message's content is used for approximate token counting.

    Returns
    -------
    dict
        The result of ``self._format_response`` with the generated text,
        token counts, time-to-first-token, mean inter-token latency, and
        throughput.
    """
    st = time.perf_counter()
    most_recent_timestamp = st
    time_to_first_token = 0.0
    inter_token_latencies = []  # per-token gaps after the first token
    generated_text = ""
    usage = None  # populated from the final OpenAI chunk when provided

    stream = self._create_stream(messages)

    for chunk in stream:
        timestamp = time.perf_counter()
        if time_to_first_token == 0.0:
            time_to_first_token = timestamp - st
        else:
            inter_token_latencies.append(timestamp - most_recent_timestamp)
        most_recent_timestamp = timestamp

        if chunk.choices:
            generated_text += chunk.choices[0].delta.content or ""
        if self.provider == "openai" and chunk.usage:
            usage = chunk.usage

    text = generated_text
    # Fix: the original read `usage` unconditionally for OpenAI, raising
    # UnboundLocalError whenever the stream never yielded a usage chunk.
    # Fall back to the whitespace approximation in that case.
    if self.provider == "openai" and usage is not None:
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
    else:
        prompt_tokens = len(messages[0]['content'].split())  # Approximate
        completion_tokens = len(text.split())  # Approximate

    if inter_token_latencies:
        internal_token_latency = sum(inter_token_latencies) / len(inter_token_latencies)
        throughput = 1 / internal_token_latency
    else:
        # Zero or one streamed token: no measurable inter-token latency.
        internal_token_latency = 0
        throughput = 0

    response = self._format_response(
        text,
        prompt_tokens,
        completion_tokens,
        time_to_first_token,
        internal_token_latency,
        throughput
    )

    return response

def _infer(self, messages):
    """Call the provider API, retrying transient errors before tolerating them.

    Retryable errors are attempted up to ``MAX_API_RETRY_ATTEMPTS`` times with
    a fixed delay; errors that are neither retryable nor tolerated are
    re-raised as ``RuntimeError``. Any remaining failure is logged and turned
    into an empty response so the run can continue.
    """
    last_attempt = MAX_API_RETRY_ATTEMPTS
    for attempt in range(1, last_attempt + 1):
        try:
            return self._infer_once(messages)
        except Exception as error:
            retryable = self._is_retryable_api_error(error)
            if retryable and attempt < last_attempt:
                time.sleep(API_RETRY_DELAY_SECONDS)
                continue
            if not retryable and not self._is_tolerated_api_error(error):
                raise RuntimeError(f"Error during API inference: {error}") from error

            status_code, error_type, message = self._extract_error_details(error)
            LOGGER.warning(
                "Provider inference failed for one sample; returning empty response. "
                "provider=%s status=%s type=%s message=%s",
                self.provider,
                status_code,
                error_type,
                message,
            )
            return self._build_error_response(error)
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest import mock

MODELS_DIR = Path(__file__).resolve().parent


def load_module(name, path):
    """Import the file at *path* as module *name* and register it in ``sys.modules``."""
    spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(spec)
    # Register before executing so intra-package imports can resolve it.
    sys.modules[name] = loaded
    spec.loader.exec_module(loaded)
    return loaded


# Register lightweight stand-ins so the module under test can be imported
# without the real third-party SDKs installed.
models_pkg = types.ModuleType("models")
models_pkg.__path__ = [str(MODELS_DIR)]
sys.modules.setdefault("models", models_pkg)

# Stub the OpenAI SDK; only the ``OpenAI`` symbol is referenced at import time.
openai_pkg = types.ModuleType("openai")
openai_pkg.OpenAI = object
sys.modules.setdefault("openai", openai_pkg)

# Stub the Groq SDK; only the ``Groq`` symbol is referenced at import time.
groq_pkg = types.ModuleType("groq")
groq_pkg.Groq = object
sys.modules.setdefault("groq", groq_pkg)

# Load the base class first, then the module under test, from the models dir.
load_module("models.base_llm", MODELS_DIR / "base_llm.py")
api_llm = load_module("models.api_llm", MODELS_DIR / "api_llm.py")
APIBasedLLM = api_llm.APIBasedLLM


def build_api_error(status_code, error_type, message):
    """Construct an exception shaped like an OpenAI SDK API error.

    The instance carries ``status_code`` and a ``body`` dict with a nested
    ``error`` payload, matching what ``_extract_error_details`` reads.
    """

    class FakeAPIError(Exception):
        # Mimic the SDK's module so the class looks like a real API error.
        __module__ = "openai._exceptions"

        def __init__(self):
            super().__init__(message)
            self.status_code = status_code
            self.body = {"error": {"type": error_type, "message": message}}

    return FakeAPIError()


class FakeCompletions:
    """Stand-in for ``client.chat.completions`` whose every call raises a scripted error."""

    def __init__(self, side_effects):
        self.side_effects = list(side_effects)
        self.calls = 0

    def create(self, **kwargs):
        """Raise the next scripted exception; the last one repeats once exhausted."""
        self.calls += 1
        index = min(self.calls, len(self.side_effects)) - 1
        raise self.side_effects[index]


class FakeChat:
    """Mirrors the ``client.chat`` attribute of the real SDK clients."""

    def __init__(self, side_effects):
        # Expose the scripted completions endpoint under the SDK's path.
        self.completions = FakeCompletions(side_effects)


class FakeClient:
    """Top-level fake SDK client exposing ``client.chat.completions.create``."""

    def __init__(self, side_effects):
        # Delegate down the chat -> completions attribute chain.
        self.chat = FakeChat(side_effects)


class APIBasedLLMTests(unittest.TestCase):
    """Behavioral tests for APIBasedLLM's error tolerance and retry logic."""

    @staticmethod
    def build_model(*side_effects):
        """Build an APIBasedLLM whose client raises *side_effects*, bypassing __init__."""
        # __new__ skips __init__ so no real SDK client is constructed.
        model = object.__new__(APIBasedLLM)
        model.provider = "openai"
        model.client = FakeClient(side_effects)
        model.model = "gpt-4o-mini"
        model.model_name = "gpt-4o-mini"
        model.temperature = 0.8
        model.max_tokens = 64
        model.top_p = 0.8
        model.repetition_penalty = 1.05
        model.use_cache = False
        model.model_loaded = True
        model.config = {}
        model.is_cache_loaded = True
        return model

    def test_content_policy_failures_return_empty_prediction(self):
        """A 400 content-policy error is tolerated without any retry."""
        model = self.build_model(
            build_api_error(
                400,
                "content_policy_violation_error",
                "Content validation failed.",
            )
        )

        # Patch sleep so an (unexpected) retry delay cannot slow the test.
        with mock.patch.object(api_llm.time, "sleep", return_value=None):
            response = model.inference({"query": "unsafe prompt"})

        self.assertEqual("", response["completion"])
        self.assertIsNone(response["prediction"])
        self.assertEqual(400, response["error"]["status_code"])
        self.assertEqual(
            "content_policy_violation_error",
            response["error"]["type"],
        )
        # Exactly one API call: content-policy errors must not be retried.
        self.assertEqual(1, model.client.chat.completions.calls)

    def test_retryable_failures_are_retried_before_returning_empty_prediction(self):
        """A 503 is retried up to the limit, then tolerated with an empty response."""
        model = self.build_model(
            build_api_error(503, "server_error", "Service temporarily unavailable."),
        )

        with mock.patch.object(api_llm.time, "sleep", return_value=None) as sleep_mock:
            response = model.inference({"query": "retry me"})

        self.assertEqual("", response["completion"])
        self.assertIsNone(response["prediction"])
        self.assertEqual(503, response["error"]["status_code"])
        # Three attempts total, with a sleep between consecutive attempts.
        self.assertEqual(3, model.client.chat.completions.calls)
        self.assertEqual(2, sleep_mock.call_count)

    def test_invalid_requests_still_fail_fast(self):
        """A non-retryable, non-tolerated 400 must raise immediately."""
        model = self.build_model(
            build_api_error(
                400,
                "invalid_request_error",
                "The configured model does not exist.",
            )
        )

        with self.assertRaises(RuntimeError):
            with mock.patch.object(api_llm.time, "sleep", return_value=None):
                model.inference({"query": "hello"})

        # Fail-fast: only the first call happened, no retries.
        self.assertEqual(1, model.client.chat.completions.calls)


if __name__ == "__main__":
    # Allow running this test module directly with the unittest runner.
    unittest.main()