22import os
33from typing import Optional
44
5+ from tenacity import AsyncRetrying , Retrying
56from uipath ._utils import resource_override
67from uipath .utils import EndpointManager
78
8- from .._utils ._retry_after_strategy import (
9- AsyncRetryAfterHeaderStrategy ,
10- RetryAfterHeaderStrategy ,
11- )
9+ from .retryers .bedrock import AsyncBedrockRetryer , BedrockRetryer
1210from .supported_models import BedrockModels
1311from .types import APIFlavor , LLMProvider
1412
@@ -43,18 +41,11 @@ def _check_bedrock_dependencies() -> None:
4341
4442import boto3
4543import botocore .config
46- import botocore .exceptions
4744from langchain_aws import (
4845 ChatBedrock ,
4946 ChatBedrockConverse ,
5047)
5148
52- _BEDROCK_RETRY_EXCEPTIONS = (
53- botocore .exceptions .ReadTimeoutError ,
54- botocore .exceptions .ConnectTimeoutError ,
55- botocore .exceptions .EndpointConnectionError ,
56- )
57-
5849
5950class AwsBedrockCompletionsPassthroughClient :
6051 @resource_override (
@@ -145,7 +136,8 @@ class UiPathChatBedrockConverse(ChatBedrockConverse):
145136 llm_provider : LLMProvider = LLMProvider .BEDROCK
146137 api_flavor : APIFlavor = APIFlavor .AWS_BEDROCK_CONVERSE
147138 model : str = "" # For tracing serialization
148- max_retries : int = 5
139+ retryer : Optional [Retrying ] = None
140+ aretryer : Optional [AsyncRetrying ] = None
149141
150142 def __init__ (
151143 self ,
@@ -155,6 +147,8 @@ def __init__(
155147 model_name : str = BedrockModels .anthropic_claude_haiku_4_5 ,
156148 agenthub_config : Optional [str ] = None ,
157149 byo_connection_id : Optional [str ] = None ,
150+ retryer : Optional [Retrying ] = None ,
151+ aretryer : Optional [AsyncRetrying ] = None ,
158152 ** kwargs ,
159153 ):
160154 org_id = org_id or os .getenv ("UIPATH_ORGANIZATION_ID" )
@@ -187,29 +181,24 @@ def __init__(
187181 kwargs ["model" ] = model_name
188182 super ().__init__ (** kwargs )
189183 self .model = model_name
184+ self .retryer = retryer
185+ self .aretryer = aretryer
190186
191187 def invoke (self , * args , ** kwargs ):
192- retryer = RetryAfterHeaderStrategy (
193- retry_on_exceptions = _BEDROCK_RETRY_EXCEPTIONS ,
194- max_retries = self .max_retries ,
195- logger = logger ,
196- )
188+ retryer = self .retryer or _get_default_retryer ()
197189 return retryer (super ().invoke , * args , ** kwargs )
198190
199191 async def ainvoke (self , * args , ** kwargs ):
200- retryer = AsyncRetryAfterHeaderStrategy (
201- retry_on_exceptions = _BEDROCK_RETRY_EXCEPTIONS ,
202- max_retries = self .max_retries ,
203- logger = logger ,
204- )
192+ retryer = self .aretryer or _get_default_async_retryer ()
205193 return await retryer (super ().ainvoke , * args , ** kwargs )
206194
207195
208196class UiPathChatBedrock (ChatBedrock ):
209197 llm_provider : LLMProvider = LLMProvider .BEDROCK
210198 api_flavor : APIFlavor = APIFlavor .AWS_BEDROCK_INVOKE
211199 model : str = "" # For tracing serialization
212- max_retries : int = 5
200+ retryer : Optional [Retrying ] = None
201+ aretryer : Optional [AsyncRetrying ] = None
213202
214203 def __init__ (
215204 self ,
@@ -219,6 +208,8 @@ def __init__(
219208 model_name : str = BedrockModels .anthropic_claude_haiku_4_5 ,
220209 agenthub_config : Optional [str ] = None ,
221210 byo_connection_id : Optional [str ] = None ,
211+ retryer : Optional [Retrying ] = None ,
212+ aretryer : Optional [AsyncRetrying ] = None ,
222213 ** kwargs ,
223214 ):
224215 org_id = org_id or os .getenv ("UIPATH_ORGANIZATION_ID" )
@@ -251,19 +242,21 @@ def __init__(
251242 kwargs ["model" ] = model_name
252243 super ().__init__ (** kwargs )
253244 self .model = model_name
245+ self .retryer = retryer
246+ self .aretryer = aretryer
254247
255248 def invoke (self , * args , ** kwargs ):
256- retryer = RetryAfterHeaderStrategy (
257- retry_on_exceptions = _BEDROCK_RETRY_EXCEPTIONS ,
258- max_retries = self .max_retries ,
259- logger = logger ,
260- )
249+ retryer = self .retryer or _get_default_retryer ()
261250 return retryer (super ().invoke , * args , ** kwargs )
262251
263252 async def ainvoke (self , * args , ** kwargs ):
264- retryer = AsyncRetryAfterHeaderStrategy (
265- retry_on_exceptions = _BEDROCK_RETRY_EXCEPTIONS ,
266- max_retries = self .max_retries ,
267- logger = logger ,
268- )
253+ retryer = self .aretryer or _get_default_async_retryer ()
269254 return await retryer (super ().ainvoke , * args , ** kwargs )
255+
256+
257+ def _get_default_retryer () -> BedrockRetryer :
258+ return BedrockRetryer (logger = logger )
259+
260+
261+ def _get_default_async_retryer () -> AsyncBedrockRetryer :
262+ return AsyncBedrockRetryer (logger = logger )
0 commit comments