1- from typing import List
1+ from typing import Dict , List
22from openai import AzureOpenAI , AsyncAzureOpenAI
33from deepeval .key_handler import (
44 EmbeddingKeyValues ,
55 ModelKeyValues ,
66 KEY_FILE_HANDLER ,
77)
88from deepeval .models import DeepEvalBaseEmbeddingModel
9+ from deepeval .models .retry_policy import (
10+ create_retry_decorator ,
11+ sdk_retries_for ,
12+ )
13+ from deepeval .constants import ProviderSlug as PS
14+
15+
16+ retry_azure = create_retry_decorator (PS .AZURE )
917
1018
1119class AzureOpenAIEmbeddingModel (DeepEvalBaseEmbeddingModel ):
12- def __init__ (self ):
20+ def __init__ (self , ** kwargs ):
1321 self .azure_openai_api_key = KEY_FILE_HANDLER .fetch_data (
1422 ModelKeyValues .AZURE_OPENAI_API_KEY
1523 )
@@ -23,7 +31,9 @@ def __init__(self):
2331 ModelKeyValues .AZURE_OPENAI_ENDPOINT
2432 )
2533 self .model_name = self .azure_embedding_deployment
34+ self .kwargs = kwargs
2635
36+ @retry_azure
2737 def embed_text (self , text : str ) -> List [float ]:
2838 client = self .load_model (async_mode = False )
2939 response = client .embeddings .create (
@@ -32,6 +42,7 @@ def embed_text(self, text: str) -> List[float]:
3242 )
3343 return response .data [0 ].embedding
3444
45+ @retry_azure
3546 def embed_texts (self , texts : List [str ]) -> List [List [float ]]:
3647 client = self .load_model (async_mode = False )
3748 response = client .embeddings .create (
@@ -40,6 +51,7 @@ def embed_texts(self, texts: List[str]) -> List[List[float]]:
4051 )
4152 return [item .embedding for item in response .data ]
4253
54+ @retry_azure
4355 async def a_embed_text (self , text : str ) -> List [float ]:
4456 client = self .load_model (async_mode = True )
4557 response = await client .embeddings .create (
@@ -48,6 +60,7 @@ async def a_embed_text(self, text: str) -> List[float]:
4860 )
4961 return response .data [0 ].embedding
5062
63+ @retry_azure
5164 async def a_embed_texts (self , texts : List [str ]) -> List [List [float ]]:
5265 client = self .load_model (async_mode = True )
5366 response = await client .embeddings .create (
@@ -61,15 +74,33 @@ def get_model_name(self) -> str:
6174
6275 def load_model (self , async_mode : bool = False ):
6376 if not async_mode :
64- return AzureOpenAI (
65- api_key = self .azure_openai_api_key ,
66- api_version = self .openai_api_version ,
67- azure_endpoint = self .azure_endpoint ,
68- azure_deployment = self .azure_embedding_deployment ,
69- )
70- return AsyncAzureOpenAI (
77+ return self ._build_client (AzureOpenAI )
78+ return self ._build_client (AsyncAzureOpenAI )
79+
80+ def _client_kwargs (self ) -> Dict :
81+ """
82+ If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
83+ If the user opts into SDK retries for 'azure' via DEEPEVAL_SDK_RETRY_PROVIDERS,
84+ leave their retry settings as is.
85+ """
86+ kwargs = dict (self .kwargs or {})
87+ if not sdk_retries_for (PS .AZURE ):
88+ kwargs ["max_retries" ] = 0
89+ return kwargs
90+
91+ def _build_client (self , cls ):
92+ kw = dict (
7193 api_key = self .azure_openai_api_key ,
7294 api_version = self .openai_api_version ,
7395 azure_endpoint = self .azure_endpoint ,
7496 azure_deployment = self .azure_embedding_deployment ,
97+ ** self ._client_kwargs (),
7598 )
99+ try :
100+ return cls (** kw )
101+ except TypeError as e :
102+ # older OpenAI SDKs may not accept max_retries, in that case remove and retry once
103+ if "max_retries" in str (e ):
104+ kw .pop ("max_retries" , None )
105+ return cls (** kw )
106+ raise
0 commit comments