Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/lightspeed_evaluation/core/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
logger = logging.getLogger(__name__)


def _is_too_many_requests_error(exception: BaseException) -> bool:
    """Tell whether the exception is an HTTP 429 (Too Many Requests) error.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True only when ``exception`` is an ``httpx.HTTPStatusError`` whose
        response status code is 429; False for every other exception.
    """
    # Anything that is not an HTTP status error has no response to inspect.
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    status = exception.response.status_code
    return status == 429
def _is_retryable_server_error(exception: BaseException) -> bool:
    """Check if exception is an HTTP status error worth retrying.

    Only transient failures are considered retryable: 429 (rate limited)
    and 5xx (server-side errors). Other status errors, such as 4xx client
    errors (bad request, unauthorized, not found, ...), will not succeed on
    a repeated attempt, so they are reported as non-retryable and can be
    handled by the regular error path instead of consuming retry attempts.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True when ``exception`` is an ``httpx.HTTPStatusError`` with status
        429 or in the 500-599 range; False otherwise.
    """
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    status_code = exception.response.status_code
    # A plain "!= 200" check would match every status error, including
    # permanent client errors; restrict to rate limiting and server faults.
    return status_code == 429 or 500 <= status_code <= 599


class APIClient:
Expand All @@ -59,10 +58,11 @@ def __init__(
retry_decorator = self._create_retry_decorator()
self._standard_query_with_retry = retry_decorator(self._standard_query)
self._streaming_query_with_retry = retry_decorator(self._streaming_query)
self._rlsapi_infer_query_with_retry = retry_decorator(self._rlsapi_infer_query)

def _create_retry_decorator(self) -> Any:
return retry(
retry=retry_if_exception(_is_too_many_requests_error),
retry=retry_if_exception(_is_retryable_server_error),
stop=stop_after_attempt(
self.config.num_retries + 1
), # +1 to account for the initial attempt
Expand Down Expand Up @@ -131,7 +131,7 @@ def query(
elif self.config.endpoint_type == "query":
response = self._standard_query_with_retry(api_request)
elif self.config.endpoint_type == "infer":
response = self._rlsapi_infer_query(api_request)
response = self._rlsapi_infer_query_with_retry(api_request)

if self.config.cache_enabled:
self._add_response_to_cache(api_request, response)
Expand Down Expand Up @@ -295,12 +295,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
if "input_tokens" in data:
response_data["input_tokens"] = data["input_tokens"]
logger.debug(
f"{RED}RLS API INPUT TOKENS: {response_data["input_tokens"]}{RESET}"
f"RLS API INPUT TOKENS: {response_data["input_tokens"]}"
)
if "output_tokens" in data:
response_data["output_tokens"] = data["output_tokens"]
logger.debug(
f"{RED}output_tokens: {response_data["output_tokens"]}{RESET}"
f"output_tokens: {response_data["output_tokens"]}"
)
if "tool_calls" in data:
response_data["tool_calls"] = data["tool_calls"]
Expand Down Expand Up @@ -366,6 +366,8 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
logger.error(
f"RLS API HTTP error - Status: {e.response.status_code}, Body: {e.response.text}"
)
if _is_retryable_server_error(e):
raise
raise self._handle_http_error(e) from e
except ValueError as e:
raise self._handle_validation_error(e) from e
Expand Down
Loading