1414
1515import httpx
1616from fastapi import HTTPException
17- from tenacity import retry , retry_if_exception_type , stop_after_attempt , wait_exponential
17+ from tenacity import retry , retry_if_exception , retry_if_exception_type , stop_after_attempt , wait_exponential
1818
1919logger = logging .getLogger (__name__ )
2020
@@ -53,6 +53,7 @@ def __init__(self, config: CircuitBreakerConfig):
5353 self .failure_count = 0
5454 self .last_failure_time : Optional [float ] = None
5555 self .next_retry_time : Optional [float ] = None
56+ self ._lock = asyncio .Lock ()
5657
5758 def __call__ (self , func : Callable ) -> Callable :
5859 """Decorator to apply circuit breaker to function."""
@@ -66,19 +67,20 @@ async def wrapper(*args, **kwargs):
6667 async def _call (self , func : Callable , * args , ** kwargs ) -> Any :
6768 """Execute function with circuit breaker logic."""
6869
69- # Check if circuit is open
70- if self .state == CircuitBreakerState .OPEN :
71- if self ._should_attempt_reset ():
72- self .state = CircuitBreakerState .HALF_OPEN
73- logger .info (f"Circuit breaker { func .__name__ } moved to HALF_OPEN" )
74- else :
75- if self .config .fallback_function :
76- logger .warning (f"Circuit breaker { func .__name__ } OPEN, using fallback" )
77- return await self .config .fallback_function (* args , ** kwargs )
70+ # Check if circuit is open (under lock to prevent race conditions)
71+ async with self ._lock :
72+ if self .state == CircuitBreakerState .OPEN :
73+ if self ._should_attempt_reset ():
74+ self .state = CircuitBreakerState .HALF_OPEN
75+ logger .info (f"Circuit breaker { func .__name__ } moved to HALF_OPEN" )
7876 else :
79- raise HTTPException (
80- status_code = 503 , detail = "Service temporarily unavailable - circuit breaker open"
81- )
77+ if self .config .fallback_function :
78+ logger .warning (f"Circuit breaker { func .__name__ } OPEN, using fallback" )
79+ return await self .config .fallback_function (* args , ** kwargs )
80+ else :
81+ raise HTTPException (
82+ status_code = 503 , detail = "Service temporarily unavailable - circuit breaker open"
83+ )
8284
8385 try :
8486 result = await func (* args , ** kwargs )
@@ -101,23 +103,24 @@ def _should_attempt_reset(self) -> bool:
101103
102104 async def _on_success (self , func_name : str ):
103105 """Handle successful execution."""
104- if self .state == CircuitBreakerState . HALF_OPEN :
105- self .state = CircuitBreakerState .CLOSED
106- logger . info ( f"Circuit breaker { func_name } reset to CLOSED" )
107-
108- self .failure_count = 0
109- self .last_failure_time = None
106+ async with self ._lock :
107+ if self .state == CircuitBreakerState .HALF_OPEN :
108+ self . state = CircuitBreakerState . CLOSED
109+ logger . info ( f"Circuit breaker { func_name } reset to CLOSED" )
110+ self .failure_count = 0
111+ self .last_failure_time = None
110112
111113 async def _on_failure (self , func_name : str , exception : Exception ):
112114 """Handle failed execution."""
113- self .failure_count += 1
114- self .last_failure_time = time .time ()
115-
116- if self .failure_count >= self .config .failure_threshold :
117- self .state = CircuitBreakerState .OPEN
118- logger .warning (
119- f"Circuit breaker { func_name } opened after { self .failure_count } failures. " f"Last error: { exception } "
120- )
115+ async with self ._lock :
116+ self .failure_count += 1
117+ self .last_failure_time = time .time ()
118+ if self .failure_count >= self .config .failure_threshold :
119+ self .state = CircuitBreakerState .OPEN
120+ logger .warning (
121+ f"Circuit breaker { func_name } opened after { self .failure_count } failures. "
122+ f"Last error: { exception } "
123+ )
121124
122125
123126class TimeoutManager :
@@ -155,13 +158,22 @@ async def wrapper(*args, **kwargs):
155158 return decorator
156159
157160
161+ def _is_retryable (exception ):
162+ """Only retry on transient errors (5xx, connection, timeout)."""
163+ if isinstance (exception , httpx .RequestError ):
164+ return True # Connection/timeout errors
165+ if isinstance (exception , httpx .HTTPStatusError ):
166+ return exception .response .status_code >= 500
167+ return False
168+
169+
158170class RetryConfig :
159171 """Retry configuration."""
160172
161173 WAZUH_API_RETRY = retry (
162174 stop = stop_after_attempt (3 ),
163175 wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
164- retry = retry_if_exception_type (( httpx . RequestError , httpx . HTTPStatusError ) ),
176+ retry = retry_if_exception ( _is_retryable ),
165177 reraise = True ,
166178 )
167179
0 commit comments