11from dataclasses import dataclass
22from logging import Logger
33from pydantic import BaseModel
4- from typing import AsyncGenerator , Any , TypeVar , Callable , Awaitable
4+ from typing import AsyncGenerator , Any , Awaitable , TypeVar , Callable , Protocol
55import abc
66import aiohttp
77import asyncio
4343BodyGeneratorFunction = Callable [[], AsyncGenerator [bytes , None ]]
4444HeadersAndBodyGenerator = tuple [Headers , BodyGeneratorFunction ]
4545
46+ class ShouldRetryProtocol (Protocol ):
47+ """
48+ ShouldRetryProtocol defines a callback function signature for custom retry logic.
49+
50+ Implementations should return True if the HTTP request should be retried, or False
51+ if the request should fail immediately with an HTTPError. This allows connectors
52+ to implement custom retry strategies for server errors (5xx status codes) based
53+ on the response status, headers, body, and current attempt number.
54+
55+ Parameters:
56+ * `status` is the HTTP status code of the response
57+ * `headers` are the HTTP response headers as a dict
58+ * `body` is the response body as bytes
59+ * `attempt` is the current attempt number (starts at 1)
60+
61+ Example usage:
62+ def custom_retry(status: int, headers: Headers, body: bytes, attempt: int) -> bool:
63+ if status == 503 and attempt > 2:
64+ return False
65+ return status >= 500
66+ """
67+ def __call__ (
68+ self ,
69+ status : int ,
70+ headers : Headers ,
71+ body : bytes ,
72+ attempt : int
73+ ) -> bool : ...
74+
4675
4776class HTTPError (RuntimeError ):
4877 """
@@ -71,6 +100,7 @@ class HTTPSession(abc.ABC):
71100 * `params` are encoded as URL parameters of the query
72101 * `json` is a JSON-encoded request body (if set, `form` cannot be)
73102 * `form` is a form URL-encoded request body (if set, `json` cannot be)
103+ * `should_retry` determines whether 5xx errors are retried or bubbled up. If not provided, all 5xx errors are retried.
74104 """
75105
76106 async def request (
@@ -83,12 +113,13 @@ async def request(
83113 form : dict [str , Any ] | None = None ,
84114 _with_token : bool = True , # Unstable internal API.
85115 headers : dict [str , Any ] | None = None ,
116+ should_retry : ShouldRetryProtocol | None = None ,
86117 ) -> bytes :
87118 """Request a url and return its body as bytes"""
88119
89120 chunks : list [bytes ] = []
90121 _ , body_generator = await self ._request_stream (
91- log , url , method , params , json , form , _with_token , headers
122+ log , url , method , params , json , form , _with_token , headers , should_retry ,
92123 )
93124
94125 async for chunk in body_generator ():
@@ -111,11 +142,12 @@ async def request_lines(
111142 form : dict [str , Any ] | None = None ,
112143 delim : bytes = b"\n " ,
113144 headers : dict [str , Any ] | None = None ,
145+ should_retry : ShouldRetryProtocol | None = None ,
114146 ) -> tuple [Headers , BodyGeneratorFunction ]:
115147 """Request a url and return its response as streaming lines, as they arrive"""
116148
117149 resp_headers , body = await self ._request_stream (
118- log , url , method , params , json , form , True , headers
150+ log , url , method , params , json , form , True , headers , should_retry ,
119151 )
120152
121153 async def gen () -> AsyncGenerator [bytes , None ]:
@@ -141,10 +173,11 @@ async def request_stream(
141173 form : dict [str , Any ] | None = None ,
142174 _with_token : bool = True , # Unstable internal API.
143175 headers : dict [str , Any ] | None = None ,
176+ should_retry : ShouldRetryProtocol | None = None ,
144177 ) -> tuple [Headers , BodyGeneratorFunction ]:
145178 """Request a url and and return the raw response as a stream of bytes"""
146179
147- return await self ._request_stream (log , url , method , params , json , form , _with_token , headers )
180+ return await self ._request_stream (log , url , method , params , json , form , _with_token , headers , should_retry )
148181
149182
150183 @abc .abstractmethod
@@ -158,6 +191,7 @@ async def _request_stream(
158191 form : dict [str , Any ] | None ,
159192 _with_token : bool ,
160193 headers : dict [str , Any ] | None = None ,
194+ should_retry : ShouldRetryProtocol | None = None ,
161195 ) -> HeadersAndBodyGenerator : ...
162196
163197 # TODO(johnny): This is an unstable API.
@@ -484,14 +518,17 @@ async def _request_stream(
484518 form : dict [str , Any ] | None ,
485519 _with_token : bool ,
486520 headers : dict [str , Any ] | None = None ,
521+ should_retry : ShouldRetryProtocol | None = None ,
487522 ) -> HeadersAndBodyGenerator :
488523 if headers is None :
489524 headers = {}
525+ attempt = 0
490526
491527 while True :
492528 cur_delay = self .rate_limiter .delay
493529 await asyncio .sleep (cur_delay )
494530
531+ attempt += 1
495532 resp = await self ._retry_on_connection_error (
496533 log , url , method ,
497534 lambda : self ._establish_connection_and_get_response (
@@ -516,10 +553,20 @@ async def _request_stream(
516553
517554 elif resp .status >= 500 and resp .status < 600 :
518555 body = await resp .read ()
519- log .warning (
520- "server internal error (will retry)" ,
521- {"body" : body .decode ("utf-8" )},
522- )
556+
557+ if (
558+ should_retry is None
559+ or should_retry (resp .status , dict (resp .headers ), body , attempt )
560+ ):
561+ log .warning (
562+ "server internal error (will retry)" ,
563+ {"body" : body .decode ("utf-8" )},
564+ )
565+ else :
566+ raise HTTPError (
567+ f"Encountered HTTP error status { resp .status } .\n URL: { url } \n Response:\n { body .decode ('utf-8' )} " ,
568+ resp .status ,
569+ )
523570 elif resp .status >= 400 and resp .status < 500 :
524571 body = await resp .read ()
525572 raise HTTPError (
0 commit comments