2727logger = logging .getLogger (__name__ )
2828
2929
def _is_too_many_requests_error(exception: BaseException) -> bool:
    """Report whether *exception* is an HTTP 429 (Too Many Requests) error.

    Returns:
        True only when *exception* is an ``httpx.HTTPStatusError`` whose
        response carries status code 429; False for every other exception.
    """
    # Anything that is not an HTTP status error has no response to inspect.
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    return exception.response.status_code == 429
def _is_retryable_server_error(exception: BaseException) -> bool:
    """Return True when *exception* is an HTTP error that a retry may fix.

    Only transient failures are considered retryable:

    * 429 (Too Many Requests) - the server asked the client to back off.
    * 5xx - the failure is on the server side and may clear up.

    Other statuses, notably deterministic client errors such as 400, 401,
    403, 404 or 422, return False: repeating the same request cannot change
    the outcome, so retrying them only burns the retry budget and delays the
    error reaching the caller's normal error handling.

    Args:
        exception: The exception raised by the wrapped call.

    Returns:
        True if the exception is an ``httpx.HTTPStatusError`` with a status
        code of 429 or in the 500-599 range, otherwise False.
    """
    # Non-HTTP failures (timeouts, validation errors, ...) carry no response
    # to inspect and are not handled by this predicate.
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    status_code = exception.response.status_code
    return status_code == 429 or 500 <= status_code <= 599
3635
3736
3837class APIClient :
@@ -59,10 +58,11 @@ def __init__(
5958 retry_decorator = self ._create_retry_decorator ()
6059 self ._standard_query_with_retry = retry_decorator (self ._standard_query )
6160 self ._streaming_query_with_retry = retry_decorator (self ._streaming_query )
61+ self ._rlsapi_infer_query_with_retry = retry_decorator (self ._rlsapi_infer_query )
6262
6363 def _create_retry_decorator (self ) -> Any :
6464 return retry (
65- retry = retry_if_exception (_is_too_many_requests_error ),
65+ retry = retry_if_exception (_is_retryable_server_error ),
6666 stop = stop_after_attempt (
6767 self .config .num_retries + 1
6868 ), # +1 to account for the initial attempt
@@ -131,7 +131,7 @@ def query(
131131 elif self .config .endpoint_type == "query" :
132132 response = self ._standard_query_with_retry (api_request )
133133 elif self .config .endpoint_type == "infer" :
134- response = self ._rlsapi_infer_query (api_request )
134+ response = self ._rlsapi_infer_query_with_retry (api_request )
135135
136136 if self .config .cache_enabled :
137137 self ._add_response_to_cache (api_request , response )
@@ -295,12 +295,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
295295 if "input_tokens" in data :
296296 response_data ["input_tokens" ] = data ["input_tokens" ]
297297 logger .debug (
298- f"{ RED } RLS API INPUT TOKENS: { response_data ["input_tokens" ]} { RESET } "
298+ f"RLS API INPUT TOKENS: { response_data ["input_tokens" ]} "
299299 )
300300 if "output_tokens" in data :
301301 response_data ["output_tokens" ] = data ["output_tokens" ]
302302 logger .debug (
303- f"{ RED } output_tokens: { response_data ["output_tokens" ]} { RESET } "
303+ f"output_tokens: { response_data ["output_tokens" ]} "
304304 )
305305 if "tool_calls" in data :
306306 response_data ["tool_calls" ] = data ["tool_calls" ]
@@ -366,6 +366,8 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
366366 logger .error (
367367 f"RLS API HTTP error - Status: { e .response .status_code } , Body: { e .response .text } "
368368 )
369+ if _is_retryable_server_error (e ):
370+ raise
369371 raise self ._handle_http_error (e ) from e
370372 except ValueError as e :
371373 raise self ._handle_validation_error (e ) from e
0 commit comments