diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index d57cd4c8..7edc0aff 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -24,6 +24,7 @@
     BadRequestError,
     NotFoundError,
     APIConnectionError,
+    APIError
 )
 from litellm.types.utils import Usage
 import instructor
@@ -58,8 +59,10 @@
 from tenacity import (
     AsyncRetrying,
     Retrying,
+    retry_if_exception_type,
     retry_if_not_exception_type,
     stop_after_attempt,
+    wait_fixed,
     wait_random_exponential,
 )
 from pydantic_core._pydantic_core import ValidationError
@@ -91,8 +94,12 @@
 # wait=wait_random_exponential(multiplier=1, max=60),
 # )
 
-# For now, disabling all instructor retries as of DIA-1910 to speed up inference runs greatly
-RETRY_POLICY = dict(stop=stop_after_attempt(1))
+# DIA-1910 disabled all retries - DIA-2083 introduces retries on APIError caused by connection error
+RETRY_POLICY = dict(
+    retry=retry_if_exception_type(APIError),
+    stop=stop_after_attempt(3),
+    wait=wait_fixed(1),
+)
 
 retries = Retrying(**RETRY_POLICY)
 async_retries = AsyncRetrying(**RETRY_POLICY)
diff --git a/adala/utils/llm_utils.py b/adala/utils/llm_utils.py
index 6329674a..b9383722 100644
--- a/adala/utils/llm_utils.py
+++ b/adala/utils/llm_utils.py
@@ -103,7 +103,10 @@ def handle_llm_exception(
     n_attempts = retries.stop.max_attempt_number
     if prompt_token_count is None:
         prompt_token_count = token_counter(model=model, messages=messages[:-1])
-    prompt_tokens = n_attempts * prompt_token_count
+    if type(e).__name__ in {"APIError", "AuthenticationError", "APIConnectionError"}:
+        prompt_tokens = 0
+    else:
+        prompt_tokens = n_attempts * prompt_token_count
     # TODO a pydantic validation error may be appended as the last message, don't know how to get the raw response in this case
     usage = Usage(
         prompt_tokens=prompt_tokens,
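
A minimal sketch of how the new RETRY_POLICY behaves, runnable outside Adala. The tenacity names (`Retrying`, `RetryError`, `retry_if_exception_type`, `stop_after_attempt`, `wait_fixed`) are the real APIs used in the diff; the `APIError` class below is a stand-in for litellm's `APIError` so the snippet is self-contained.

```python
from tenacity import (
    RetryError,
    Retrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_fixed,
)


class APIError(Exception):
    """Stand-in for litellm's APIError (raised e.g. on connection failures)."""


# Same policy as the diff: retry only on APIError, up to 3 attempts, 1s apart.
RETRY_POLICY = dict(
    retry=retry_if_exception_type(APIError),
    stop=stop_after_attempt(3),
    wait=wait_fixed(1),
)

attempts = 0
try:
    for attempt in Retrying(**RETRY_POLICY):
        with attempt:
            attempts += 1
            raise APIError("connection reset")  # simulate a flaky upstream
except RetryError:
    # All 3 attempts failed; tenacity wraps the last error in RetryError.
    pass

print(attempts)  # -> 3

# Any exception other than APIError (e.g. ValueError) would propagate
# immediately without retrying, per the retry_if_exception_type predicate.
```

The `llm_utils.py` change pairs with this policy: when the final exception is an `APIError`, `AuthenticationError`, or `APIConnectionError`, the request presumably never reached the model, so the failure is accounted as zero prompt tokens rather than `n_attempts * prompt_token_count`.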