Skip to content

Commit 3667d5c

Browse files
authored
Merge pull request #2047 from BloggerBust/feat/retry-policy-all-models
feat(retry_policy): apply retry policy to all models
2 parents fa8c7b0 + 9668fe0 commit 3667d5c

28 files changed

Lines changed: 2179 additions & 452 deletions

deepeval/config/settings.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,25 @@
99
type coercion.
1010
"""
1111

12+
import logging
1213
import os
1314
import re
1415

1516
from dotenv import dotenv_values
1617
from pathlib import Path
1718
from pydantic import AnyUrl, SecretStr, field_validator, confloat
1819
from pydantic_settings import BaseSettings, SettingsConfigDict
19-
from typing import Any, Dict, Optional, NamedTuple
20+
from typing import Any, Dict, List, Optional, NamedTuple
2021

21-
from deepeval.config.utils import parse_bool
22+
from deepeval.config.utils import (
23+
parse_bool,
24+
coerce_to_list,
25+
dedupe_preserve_order,
26+
)
27+
from deepeval.constants import SUPPORTED_PROVIDER_SLUGS, slugify
2228

2329

30+
logger = logging.getLogger(__name__)
2431
_SAVE_RE = re.compile(r"^(?P<scheme>dotenv)(?::(?P<path>.+))?$")
2532

2633

@@ -264,6 +271,13 @@ class Settings(BaseSettings):
264271
LOCAL_EMBEDDING_MODEL_NAME: Optional[str] = None
265272
LOCAL_EMBEDDING_BASE_URL: Optional[AnyUrl] = None
266273

274+
#
275+
# Retry Policy
276+
#
277+
DEEPEVAL_SDK_RETRY_PROVIDERS: Optional[List[str]] = None
278+
DEEPEVAL_RETRY_BEFORE_LOG_LEVEL: Optional[int] = None # default -> INFO
279+
DEEPEVAL_RETRY_AFTER_LOG_LEVEL: Optional[int] = None # default -> ERROR
280+
267281
#
268282
# Telemetry and Debug
269283
#
@@ -283,6 +297,12 @@ class Settings(BaseSettings):
283297
CONFIDENT_SAMPLE_RATE: Optional[float] = 1.0
284298
OTEL_EXPORTER_OTLP_ENDPOINT: Optional[AnyUrl] = None
285299

300+
#
301+
# Network
302+
#
303+
MEDIA_IMAGE_CONNECT_TIMEOUT_SECONDS: float = 3.05
304+
MEDIA_IMAGE_READ_TIMEOUT_SECONDS: float = 10.0
305+
286306
##############
287307
# Validators #
288308
##############
@@ -401,6 +421,78 @@ def _normalize_upper(cls, v):
401421
return None
402422
return s.upper()
403423

424+
@field_validator("DEEPEVAL_SDK_RETRY_PROVIDERS", mode="before")
425+
@classmethod
426+
def _coerce_to_list(cls, v):
427+
# works with JSON list, comma/space/semicolon separated, or real lists
428+
return coerce_to_list(v, lower=True)
429+
430+
@field_validator("DEEPEVAL_SDK_RETRY_PROVIDERS", mode="after")
431+
@classmethod
432+
def _validate_sdk_provider_list(cls, v):
433+
if v is None:
434+
return None
435+
436+
normalized: list[str] = []
437+
star = False
438+
439+
for item in v:
440+
s = str(item).strip()
441+
if not s:
442+
continue
443+
if s == "*":
444+
star = True
445+
continue
446+
s = slugify(s)
447+
if s in SUPPORTED_PROVIDER_SLUGS:
448+
normalized.append(s)
449+
else:
450+
if cls.DEEPEVAL_VERBOSE_MODE:
451+
logger.warning("Unknown provider slug %r dropped", item)
452+
453+
if star:
454+
return ["*"]
455+
456+
# It is important to dedup after normalization to catch variants
457+
normalized = dedupe_preserve_order(normalized)
458+
return normalized or None
459+
460+
@field_validator(
461+
"DEEPEVAL_RETRY_BEFORE_LOG_LEVEL",
462+
"DEEPEVAL_RETRY_AFTER_LOG_LEVEL",
463+
mode="before",
464+
)
465+
@classmethod
466+
def _coerce_log_level(cls, v):
467+
if v is None:
468+
return None
469+
if isinstance(v, (int, float)):
470+
return int(v)
471+
472+
s = str(v).strip().upper()
473+
if not s:
474+
return None
475+
476+
import logging
477+
478+
# Accept standard names or numeric strings
479+
name_to_level = {
480+
"CRITICAL": logging.CRITICAL,
481+
"ERROR": logging.ERROR,
482+
"WARNING": logging.WARNING,
483+
"INFO": logging.INFO,
484+
"DEBUG": logging.DEBUG,
485+
"NOTSET": logging.NOTSET,
486+
}
487+
if s.isdigit() or (s.startswith("-") and s[1:].isdigit()):
488+
return int(s)
489+
if s in name_to_level:
490+
return name_to_level[s]
491+
raise ValueError(
492+
"Retry log level must be one of DEBUG, INFO, WARNING, ERROR, "
493+
"CRITICAL, NOTSET, or a numeric logging level."
494+
)
495+
404496
#######################
405497
# Persistence support #
406498
#######################

deepeval/config/utils.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
import json
12
import os
2-
from typing import Any, Optional
3+
import re
4+
5+
from typing import Any, Iterable, List, Optional
6+
37

48
_TRUTHY = frozenset({"1", "true", "t", "yes", "y", "on", "enable", "enabled"})
59
_FALSY = frozenset({"0", "false", "f", "no", "n", "off", "disable", "disabled"})
10+
_LIST_SEP_RE = re.compile(r"[,\s;]+")
611

712

813
def parse_bool(value: Any, default: bool = False) -> bool:
@@ -84,3 +89,51 @@ def set_env_bool(key: str, value: Optional[bool] = False) -> None:
8489
- Use `get_env_bool` to read back and parse the value safely.
8590
"""
8691
os.environ[key] = bool_to_env_str(bool(value))
92+
93+
94+
def coerce_to_list(
95+
v,
96+
*,
97+
lower: bool = False,
98+
allow_json: bool = True,
99+
sep_re: re.Pattern = _LIST_SEP_RE,
100+
) -> Optional[List[str]]:
101+
"""
102+
Coerce None / str / list / tuple / set into a clean List[str].
103+
- Accepts JSON arrays ("[...]"") or delimited strings (comma/space/semicolon).
104+
- Strips whitespace, drops empties, optionally lowercases.
105+
"""
106+
if v is None:
107+
return None
108+
if isinstance(v, (list, tuple, set)):
109+
items = list(v)
110+
else:
111+
s = str(v).strip()
112+
if not s:
113+
return None
114+
if allow_json and s.startswith("[") and s.endswith("]"):
115+
try:
116+
parsed = json.loads(s)
117+
items = parsed if isinstance(parsed, list) else [s]
118+
except Exception:
119+
items = sep_re.split(s)
120+
else:
121+
items = sep_re.split(s)
122+
123+
out: List[str] = []
124+
for item in items:
125+
s = str(item).strip()
126+
if not s:
127+
continue
128+
out.append(s.lower() if lower else s)
129+
return out or None
130+
131+
132+
def dedupe_preserve_order(items: Iterable[str]) -> List[str]:
133+
seen = set()
134+
out: List[str] = []
135+
for x in items:
136+
if x not in seen:
137+
seen.add(x)
138+
out.append(x)
139+
return out

deepeval/constants.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from enum import Enum
2+
13
KEY_FILE: str = ".deepeval"
24
HIDDEN_DIR: str = ".deepeval"
35
PYTEST_RUN_TEST_NAME: str = "CONFIDENT_AI_RUN_TEST_NAME"
@@ -11,3 +13,28 @@
1113
CONFIDENT_TRACING_ENABLED = "CONFIDENT_TRACING_ENABLED"
1214
CONFIDENT_OPEN_BROWSER = "CONFIDENT_OPEN_BROWSER"
1315
CONFIDENT_TEST_CASE_BATCH_SIZE = "CONFIDENT_TEST_CASE_BATCH_SIZE"
16+
17+
18+
class ProviderSlug(str, Enum):
19+
OPENAI = "openai"
20+
AZURE = "azure"
21+
ANTHROPIC = "anthropic"
22+
BEDROCK = "bedrock"
23+
DEEPSEEK = "deepseek"
24+
GOOGLE = "google"
25+
GROK = "grok"
26+
KIMI = "kimi"
27+
LITELLM = "litellm"
28+
LOCAL = "local"
29+
OLLAMA = "ollama"
30+
31+
32+
def slugify(value: str | ProviderSlug) -> str:
33+
return (
34+
value.value
35+
if isinstance(value, ProviderSlug)
36+
else str(value).strip().lower()
37+
)
38+
39+
40+
SUPPORTED_PROVIDER_SLUGS = frozenset(s.value for s in ProviderSlug)

deepeval/models/embedding_models/azure_embedding_model.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
1-
from typing import List
1+
from typing import Dict, List
22
from openai import AzureOpenAI, AsyncAzureOpenAI
33
from deepeval.key_handler import (
44
EmbeddingKeyValues,
55
ModelKeyValues,
66
KEY_FILE_HANDLER,
77
)
88
from deepeval.models import DeepEvalBaseEmbeddingModel
9+
from deepeval.models.retry_policy import (
10+
create_retry_decorator,
11+
sdk_retries_for,
12+
)
13+
from deepeval.constants import ProviderSlug as PS
14+
15+
16+
retry_azure = create_retry_decorator(PS.AZURE)
917

1018

1119
class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
12-
def __init__(self):
20+
def __init__(self, **kwargs):
1321
self.azure_openai_api_key = KEY_FILE_HANDLER.fetch_data(
1422
ModelKeyValues.AZURE_OPENAI_API_KEY
1523
)
@@ -23,7 +31,9 @@ def __init__(self):
2331
ModelKeyValues.AZURE_OPENAI_ENDPOINT
2432
)
2533
self.model_name = self.azure_embedding_deployment
34+
self.kwargs = kwargs
2635

36+
@retry_azure
2737
def embed_text(self, text: str) -> List[float]:
2838
client = self.load_model(async_mode=False)
2939
response = client.embeddings.create(
@@ -32,6 +42,7 @@ def embed_text(self, text: str) -> List[float]:
3242
)
3343
return response.data[0].embedding
3444

45+
@retry_azure
3546
def embed_texts(self, texts: List[str]) -> List[List[float]]:
3647
client = self.load_model(async_mode=False)
3748
response = client.embeddings.create(
@@ -40,6 +51,7 @@ def embed_texts(self, texts: List[str]) -> List[List[float]]:
4051
)
4152
return [item.embedding for item in response.data]
4253

54+
@retry_azure
4355
async def a_embed_text(self, text: str) -> List[float]:
4456
client = self.load_model(async_mode=True)
4557
response = await client.embeddings.create(
@@ -48,6 +60,7 @@ async def a_embed_text(self, text: str) -> List[float]:
4860
)
4961
return response.data[0].embedding
5062

63+
@retry_azure
5164
async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
5265
client = self.load_model(async_mode=True)
5366
response = await client.embeddings.create(
@@ -61,15 +74,33 @@ def get_model_name(self) -> str:
6174

6275
def load_model(self, async_mode: bool = False):
6376
if not async_mode:
64-
return AzureOpenAI(
65-
api_key=self.azure_openai_api_key,
66-
api_version=self.openai_api_version,
67-
azure_endpoint=self.azure_endpoint,
68-
azure_deployment=self.azure_embedding_deployment,
69-
)
70-
return AsyncAzureOpenAI(
77+
return self._build_client(AzureOpenAI)
78+
return self._build_client(AsyncAzureOpenAI)
79+
80+
def _client_kwargs(self) -> Dict:
81+
"""
82+
If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
83+
If the user opts into SDK retries for 'azure' via DEEPEVAL_SDK_RETRY_PROVIDERS,
84+
leave their retry settings as is.
85+
"""
86+
kwargs = dict(self.kwargs or {})
87+
if not sdk_retries_for(PS.AZURE):
88+
kwargs["max_retries"] = 0
89+
return kwargs
90+
91+
def _build_client(self, cls):
92+
kw = dict(
7193
api_key=self.azure_openai_api_key,
7294
api_version=self.openai_api_version,
7395
azure_endpoint=self.azure_endpoint,
7496
azure_deployment=self.azure_embedding_deployment,
97+
**self._client_kwargs(),
7598
)
99+
try:
100+
return cls(**kw)
101+
except TypeError as e:
102+
# older OpenAI SDKs may not accept max_retries, in that case remove and retry once
103+
if "max_retries" in str(e):
104+
kw.pop("max_retries", None)
105+
return cls(**kw)
106+
raise

0 commit comments

Comments
 (0)