Skip to content

Commit b8161f7

Browse files
feat: refactor retryers
1 parent b20c542 commit b8161f7

7 files changed

Lines changed: 454 additions & 326 deletions

File tree

src/uipath_langchain/_utils/_retry_after_strategy.py

Lines changed: 0 additions & 215 deletions
This file was deleted.

src/uipath_langchain/chat/bedrock.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import os
33
from typing import Optional
44

5+
from tenacity import AsyncRetrying, Retrying
56
from uipath._utils import resource_override
67
from uipath.utils import EndpointManager
78

8-
from .._utils._retry_after_strategy import (
9-
AsyncRetryAfterHeaderStrategy,
10-
RetryAfterHeaderStrategy,
11-
)
9+
from .retryers.bedrock import AsyncBedrockRetryer, BedrockRetryer
1210
from .supported_models import BedrockModels
1311
from .types import APIFlavor, LLMProvider
1412

@@ -43,18 +41,11 @@ def _check_bedrock_dependencies() -> None:
4341

4442
import boto3
4543
import botocore.config
46-
import botocore.exceptions
4744
from langchain_aws import (
4845
ChatBedrock,
4946
ChatBedrockConverse,
5047
)
5148

52-
_BEDROCK_RETRY_EXCEPTIONS = (
53-
botocore.exceptions.ReadTimeoutError,
54-
botocore.exceptions.ConnectTimeoutError,
55-
botocore.exceptions.EndpointConnectionError,
56-
)
57-
5849

5950
class AwsBedrockCompletionsPassthroughClient:
6051
@resource_override(
@@ -145,7 +136,8 @@ class UiPathChatBedrockConverse(ChatBedrockConverse):
145136
llm_provider: LLMProvider = LLMProvider.BEDROCK
146137
api_flavor: APIFlavor = APIFlavor.AWS_BEDROCK_CONVERSE
147138
model: str = "" # For tracing serialization
148-
max_retries: int = 5
139+
retryer: Optional[Retrying] = None
140+
aretryer: Optional[AsyncRetrying] = None
149141

150142
def __init__(
151143
self,
@@ -155,6 +147,8 @@ def __init__(
155147
model_name: str = BedrockModels.anthropic_claude_haiku_4_5,
156148
agenthub_config: Optional[str] = None,
157149
byo_connection_id: Optional[str] = None,
150+
retryer: Optional[Retrying] = None,
151+
aretryer: Optional[AsyncRetrying] = None,
158152
**kwargs,
159153
):
160154
org_id = org_id or os.getenv("UIPATH_ORGANIZATION_ID")
@@ -187,29 +181,24 @@ def __init__(
187181
kwargs["model"] = model_name
188182
super().__init__(**kwargs)
189183
self.model = model_name
184+
self.retryer = retryer
185+
self.aretryer = aretryer
190186

191187
def invoke(self, *args, **kwargs):
192-
retryer = RetryAfterHeaderStrategy(
193-
retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
194-
max_retries=self.max_retries,
195-
logger=logger,
196-
)
188+
retryer = self.retryer or _get_default_retryer()
197189
return retryer(super().invoke, *args, **kwargs)
198190

199191
async def ainvoke(self, *args, **kwargs):
200-
retryer = AsyncRetryAfterHeaderStrategy(
201-
retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
202-
max_retries=self.max_retries,
203-
logger=logger,
204-
)
192+
retryer = self.aretryer or _get_default_async_retryer()
205193
return await retryer(super().ainvoke, *args, **kwargs)
206194

207195

208196
class UiPathChatBedrock(ChatBedrock):
209197
llm_provider: LLMProvider = LLMProvider.BEDROCK
210198
api_flavor: APIFlavor = APIFlavor.AWS_BEDROCK_INVOKE
211199
model: str = "" # For tracing serialization
212-
max_retries: int = 5
200+
retryer: Optional[Retrying] = None
201+
aretryer: Optional[AsyncRetrying] = None
213202

214203
def __init__(
215204
self,
@@ -219,6 +208,8 @@ def __init__(
219208
model_name: str = BedrockModels.anthropic_claude_haiku_4_5,
220209
agenthub_config: Optional[str] = None,
221210
byo_connection_id: Optional[str] = None,
211+
retryer: Optional[Retrying] = None,
212+
aretryer: Optional[AsyncRetrying] = None,
222213
**kwargs,
223214
):
224215
org_id = org_id or os.getenv("UIPATH_ORGANIZATION_ID")
@@ -251,19 +242,21 @@ def __init__(
251242
kwargs["model"] = model_name
252243
super().__init__(**kwargs)
253244
self.model = model_name
245+
self.retryer = retryer
246+
self.aretryer = aretryer
254247

255248
def invoke(self, *args, **kwargs):
256-
retryer = RetryAfterHeaderStrategy(
257-
retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
258-
max_retries=self.max_retries,
259-
logger=logger,
260-
)
249+
retryer = self.retryer or _get_default_retryer()
261250
return retryer(super().invoke, *args, **kwargs)
262251

263252
async def ainvoke(self, *args, **kwargs):
264-
retryer = AsyncRetryAfterHeaderStrategy(
265-
retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
266-
max_retries=self.max_retries,
267-
logger=logger,
268-
)
253+
retryer = self.aretryer or _get_default_async_retryer()
269254
return await retryer(super().ainvoke, *args, **kwargs)
255+
256+
257+
def _get_default_retryer() -> BedrockRetryer:
258+
return BedrockRetryer(logger=logger)
259+
260+
261+
def _get_default_async_retryer() -> AsyncBedrockRetryer:
262+
return AsyncBedrockRetryer(logger=logger)

0 commit comments

Comments
 (0)