Skip to content

Commit afe66a1

Browse files
feat: retry completions according to Retry-After header
1 parent c1847be commit afe66a1

4 files changed

Lines changed: 582 additions & 0 deletions

File tree

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
"""Reusable tenacity retry strategy that honors HTTP Retry-After headers.
2+
3+
Supports only durations in the header, not datetimes.
4+
5+
Uses duck typing to extract retry information from exceptions,
6+
to avoid importing exception types from the different chat SDKs.
7+
"""
8+
9+
import logging
10+
import random
11+
from typing import Callable, Mapping
12+
13+
from tenacity import (
14+
AsyncRetrying,
15+
RetryCallState,
16+
Retrying,
17+
stop_after_attempt,
18+
)
19+
20+
# HTTP status codes that mark a failed call as retryable; _is_retryable_exception
# returns True when an exception's extracted status code is in this set.
RETRYABLE_STATUS_CODES = {408, 429, 502, 503, 504}
21+
22+
23+
def _extract_retry_after_header(
24+
exception: Exception | BaseException,
25+
) -> float | None:
26+
def _parse_retry_after(header_value: str) -> float | None:
27+
try:
28+
seconds = float(header_value.strip())
29+
if seconds < 0:
30+
return None
31+
return seconds
32+
except (ValueError, AttributeError):
33+
return None
34+
35+
def _extract_from_headers(headers: Mapping[str, str]) -> float | None:
36+
retry_after = headers.get("retry-after") or headers.get("Retry-After")
37+
if retry_after:
38+
parsed = _parse_retry_after(retry_after)
39+
if parsed is not None:
40+
return parsed
41+
return None
42+
43+
current: Exception | BaseException | None = exception
44+
while current:
45+
if hasattr(current, "response"):
46+
response = current.response
47+
48+
# httpx, google.genai structure
49+
if hasattr(response, "headers"):
50+
result = _extract_from_headers(response.headers)
51+
if result is not None:
52+
return result
53+
54+
# botocore structure
55+
if isinstance(response, dict) and "ResponseMetadata" in response:
56+
headers = response.get("ResponseMetadata", {}).get("HTTPHeaders", {})
57+
result = _extract_from_headers(headers)
58+
if result is not None:
59+
return result
60+
61+
current = current.__cause__
62+
63+
return None
64+
65+
66+
def _extract_status_code(exception: Exception | BaseException) -> int | None:
67+
if hasattr(exception, "response"):
68+
response = exception.response
69+
# httpx, google.genai structure
70+
if hasattr(response, "status_code"):
71+
return response.status_code
72+
# botocore structure
73+
if isinstance(response, dict) and "ResponseMetadata" in response:
74+
return response.get("ResponseMetadata", {}).get("HTTPStatusCode")
75+
return None
76+
77+
78+
def _is_retryable_exception(
    exception: Exception | BaseException,
    retry_on_exceptions: tuple[type[Exception], ...],
) -> bool:
    """Decide whether a failed call should be retried.

    The exception and every exception on its ``__cause__`` chain are checked.
    The call is retryable when any of them:

    * is an instance of one of ``retry_on_exceptions``;
    * carries a status code listed in ``RETRYABLE_STATUS_CODES``;
    * carries a Retry-After header (the server asked for a retry).

    Args:
        exception: The exception raised by the failed call.
        retry_on_exceptions: Exception types that are always retried.

    Returns:
        ``True`` if the call should be retried, ``False`` otherwise.
    """
    # ids of exceptions already visited; protects against a cyclic __cause__
    # chain (e.g. produced by `raise exc from exc`).
    seen: set[int] = set()
    current: Exception | BaseException | None = exception
    while current is not None and id(current) not in seen:
        seen.add(id(current))

        if isinstance(current, retry_on_exceptions):
            # Exception that should always be retried
            return True

        status_code = _extract_status_code(current)
        if status_code is not None and status_code in RETRYABLE_STATUS_CODES:
            return True

        current = current.__cause__

    # _extract_retry_after_header walks the __cause__ chain itself, so a single
    # call from the head covers every exception; calling it once per loop
    # iteration would rescan the same tail repeatedly.
    return _extract_retry_after_header(exception) is not None
99+
100+
101+
def _create_retry_condition(
    retry_on_exceptions: tuple[type[Exception], ...],
) -> Callable[[RetryCallState], bool]:
    """Build the tenacity ``retry`` predicate.

    Args:
        retry_on_exceptions: Exception types that are always retried.

    Returns:
        A callable that receives the ``RetryCallState`` and returns ``True``
        only when the attempt ended in an exception that
        ``_is_retryable_exception`` accepts.
    """

    def retry_condition(retry_state: RetryCallState) -> bool:
        outcome = retry_state.outcome
        # No outcome recorded, or the attempt succeeded: nothing to retry.
        error = outcome.exception() if outcome is not None else None
        if error is None:
            return False
        return _is_retryable_exception(error, retry_on_exceptions)

    return retry_condition
113+
114+
115+
def _create_retry_after_wait_strategy(
    initial: float = 5.0,
    max_delay: float = 180.0,
    logger: logging.Logger | None = None,
) -> Callable[[RetryCallState], float]:
    """Create a wait strategy that honors the Retry-After header if present, falling back to exponential backoff.

    Args:
        initial: Base delay in seconds for the exponential backoff; also the
            wait returned when the retry state has no outcome yet.
        max_delay: Upper bound in seconds applied to both the Retry-After
            value and the exponential backoff.
        logger: Optional logger; when given, every computed wait is logged at
            INFO level.

    Returns:
        A callable that receives the ``RetryCallState`` and returns the number
        of seconds to wait before the next attempt.
    """

    def _exponential_backoff(attempt: int, initial: float) -> float:
        exponent = attempt - 1
        try:
            exponential = initial * (2.0**exponent)
        except OverflowError:
            # 2.0 ** n overflows for very large attempt numbers; the caller
            # caps the result at max_delay, so "infinitely long" is enough.
            exponential = float("inf")
        # Up to one second of jitter so concurrent clients do not retry in lockstep.
        jitter = random.uniform(0, 1.0)
        return exponential + jitter

    def wait_strategy(retry_state: RetryCallState) -> float:
        """Calculate wait time based on exception and retry state."""
        if retry_state.outcome is None:
            return initial

        exception = retry_state.outcome.exception()
        if exception is not None:
            retry_after = _extract_retry_after_header(exception)
            # `>= 0` is False for NaN, so a malformed value can never reach
            # min() (which would propagate NaN) and falls through to backoff.
            if retry_after is not None and retry_after >= 0:
                capped_wait = min(retry_after, max_delay)
                if logger:
                    logger.info(
                        f"Retrying after {retry_after:.1f}s"
                        f"{f' (capped to {capped_wait:.1f}s)' if capped_wait != retry_after else ''}"
                    )
                return capped_wait

        exponential_wait = _exponential_backoff(retry_state.attempt_number, initial)
        capped_wait = min(exponential_wait, max_delay)
        if logger:
            logger.info(
                f"Retrying with exponential backoff after {capped_wait:.1f}s (attempt #{retry_state.attempt_number})"
            )
        return capped_wait

    return wait_strategy
154+
155+
156+
class RetryAfterHeaderStrategy(Retrying):
    """Synchronous tenacity retryer that honors Retry-After headers.

    Failed calls are re-raised (``reraise=True``) once the attempt limit is
    reached or the failure is not retryable.

    Args:
        retry_on_exceptions: Exception types that should always be retried.
        max_retries: Attempt limit. It is passed unchanged to tenacity's
            ``stop_after_attempt``, so it counts every attempt, including the
            first call.
        initial: Initial delay for exponential backoff in seconds.
        max_delay: Maximum delay between retries in seconds.
        logger: Optional logger for retry events.
    """

    def __init__(
        self,
        retry_on_exceptions: tuple[type[Exception], ...] = (),
        max_retries: int = 5,
        initial: float = 5.0,
        max_delay: float = 120.0,
        logger: logging.Logger | None = None,
    ):
        # Build the three tenacity policies first, then hand them to the base class.
        wait_policy = _create_retry_after_wait_strategy(
            initial=initial,
            max_delay=max_delay,
            logger=logger,
        )
        retry_policy = _create_retry_condition(retry_on_exceptions)
        stop_policy = stop_after_attempt(max_retries)
        super().__init__(
            wait=wait_policy,
            retry=retry_policy,
            stop=stop_policy,
            reraise=True,
        )
185+
186+
187+
class AsyncRetryAfterHeaderStrategy(AsyncRetrying):
    """Asynchronous tenacity retryer that honors Retry-After headers.

    Failed calls are re-raised (``reraise=True``) once the attempt limit is
    reached or the failure is not retryable.

    Args:
        retry_on_exceptions: Exception types that should always be retried.
        max_retries: Attempt limit. It is passed unchanged to tenacity's
            ``stop_after_attempt``, so it counts every attempt, including the
            first call.
        initial: Initial delay for exponential backoff in seconds.
        max_delay: Maximum delay between retries in seconds.
        logger: Optional logger for retry events.
    """

    def __init__(
        self,
        retry_on_exceptions: tuple[type[Exception], ...] = (),
        max_retries: int = 5,
        initial: float = 5.0,
        max_delay: float = 120.0,
        logger: logging.Logger | None = None,
    ):
        # Build the three tenacity policies first, then hand them to the base class.
        wait_policy = _create_retry_after_wait_strategy(
            initial=initial,
            max_delay=max_delay,
            logger=logger,
        )
        retry_policy = _create_retry_condition(retry_on_exceptions)
        stop_policy = stop_after_attempt(max_retries)
        super().__init__(
            wait=wait_policy,
            retry=retry_policy,
            stop=stop_policy,
            reraise=True,
        )

src/uipath_langchain/chat/bedrock.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from uipath._utils import resource_override
66
from uipath.utils import EndpointManager
77

8+
from .._utils._retry_after_strategy import (
9+
AsyncRetryAfterHeaderStrategy,
10+
RetryAfterHeaderStrategy,
11+
)
812
from .supported_models import BedrockModels
913
from .types import APIFlavor, LLMProvider
1014

@@ -38,11 +42,19 @@ def _check_bedrock_dependencies() -> None:
3842
_check_bedrock_dependencies()
3943

4044
import boto3
45+
import botocore.config
46+
import botocore.exceptions
4147
from langchain_aws import (
4248
ChatBedrock,
4349
ChatBedrockConverse,
4450
)
4551

52+
# botocore timeout/connection exception types that the invoke/ainvoke overrides
# in this file pass to the retry strategies as `retry_on_exceptions`.
_BEDROCK_RETRY_EXCEPTIONS = (
53+
botocore.exceptions.ReadTimeoutError,
54+
botocore.exceptions.ConnectTimeoutError,
55+
botocore.exceptions.EndpointConnectionError,
56+
)
57+
4658

4759
class AwsBedrockCompletionsPassthroughClient:
4860
@resource_override(
@@ -90,6 +102,11 @@ def get_client(self):
90102
region_name="none",
91103
aws_access_key_id="none",
92104
aws_secret_access_key="none",
105+
config=botocore.config.Config(
106+
retries={
107+
"total_max_attempts": 1,
108+
}
109+
),
93110
)
94111
client.meta.events.register(
95112
"before-send.bedrock-runtime.*", self._modify_request
@@ -128,6 +145,7 @@ class UiPathChatBedrockConverse(ChatBedrockConverse):
128145
llm_provider: LLMProvider = LLMProvider.BEDROCK
129146
api_flavor: APIFlavor = APIFlavor.AWS_BEDROCK_CONVERSE
130147
model: str = "" # For tracing serialization
148+
max_retries: int = 5
131149

132150
def __init__(
133151
self,
@@ -170,11 +188,28 @@ def __init__(
170188
super().__init__(**kwargs)
171189
self.model = model_name
172190

191+
def invoke(self, *args, **kwargs):
    """Call the parent ``invoke`` through a ``RetryAfterHeaderStrategy``.

    The retryer is created per call from ``self.max_retries``, the module
    ``logger`` and ``_BEDROCK_RETRY_EXCEPTIONS``; all arguments are forwarded
    unchanged and the parent's result is returned.
    """
    call_with_retry = RetryAfterHeaderStrategy(
        logger=logger,
        max_retries=self.max_retries,
        retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
    )
    parent_invoke = super().invoke
    return call_with_retry(parent_invoke, *args, **kwargs)
198+
199+
async def ainvoke(self, *args, **kwargs):
    """Await the parent ``ainvoke`` through an ``AsyncRetryAfterHeaderStrategy``.

    The retryer is created per call from ``self.max_retries``, the module
    ``logger`` and ``_BEDROCK_RETRY_EXCEPTIONS``; all arguments are forwarded
    unchanged and the parent's result is returned.
    """
    call_with_retry = AsyncRetryAfterHeaderStrategy(
        logger=logger,
        max_retries=self.max_retries,
        retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
    )
    parent_ainvoke = super().ainvoke
    return await call_with_retry(parent_ainvoke, *args, **kwargs)
206+
173207

174208
class UiPathChatBedrock(ChatBedrock):
175209
llm_provider: LLMProvider = LLMProvider.BEDROCK
176210
api_flavor: APIFlavor = APIFlavor.AWS_BEDROCK_INVOKE
177211
model: str = "" # For tracing serialization
212+
max_retries: int = 5
178213

179214
def __init__(
180215
self,
@@ -216,3 +251,19 @@ def __init__(
216251
kwargs["model"] = model_name
217252
super().__init__(**kwargs)
218253
self.model = model_name
254+
255+
def invoke(self, *args, **kwargs):
    """Call the parent ``invoke`` through a ``RetryAfterHeaderStrategy``.

    The retryer is created per call from ``self.max_retries``, the module
    ``logger`` and ``_BEDROCK_RETRY_EXCEPTIONS``; all arguments are forwarded
    unchanged and the parent's result is returned.
    """
    call_with_retry = RetryAfterHeaderStrategy(
        logger=logger,
        max_retries=self.max_retries,
        retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
    )
    parent_invoke = super().invoke
    return call_with_retry(parent_invoke, *args, **kwargs)
262+
263+
async def ainvoke(self, *args, **kwargs):
    """Await the parent ``ainvoke`` through an ``AsyncRetryAfterHeaderStrategy``.

    The retryer is created per call from ``self.max_retries``, the module
    ``logger`` and ``_BEDROCK_RETRY_EXCEPTIONS``; all arguments are forwarded
    unchanged and the parent's result is returned.
    """
    call_with_retry = AsyncRetryAfterHeaderStrategy(
        logger=logger,
        max_retries=self.max_retries,
        retry_on_exceptions=_BEDROCK_RETRY_EXCEPTIONS,
    )
    parent_ainvoke = super().ainvoke
    return await call_with_retry(parent_ainvoke, *args, **kwargs)

0 commit comments

Comments
 (0)