|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | # Standard library imports |
| 6 | +import asyncio |
6 | 7 | import time |
7 | 8 | from dataclasses import dataclass, field |
8 | 9 | from enum import Enum |
9 | 10 | from threading import Lock, Semaphore |
10 | | -from typing import List, Optional, Tuple |
| 11 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
11 | 12 |
|
12 | 13 | # Local/package imports |
| 14 | +from ..config import get_config |
| 15 | + |
| 16 | +# Rate limiting constants |
| 17 | +# 1ms threshold to detect actual rate limit |
| 18 | +RATE_LIMIT_DETECTION_THRESHOLD_SECONDS = 0.001 |
13 | 19 |
|
14 | 20 |
|
15 | 21 | class CircuitState(Enum): |
@@ -37,21 +43,234 @@ class RateLimitMetrics: |
37 | 43 | class ProviderLimits: |
38 | 44 | """Rate limits for a specific provider.""" |
39 | 45 |
|
40 | | - max_requests_per_minute: int = 60 |
41 | | - max_concurrent_requests: int = 10 |
42 | | - tokens: int = field(default_factory=lambda: 60) |
| 46 | + max_requests_per_minute: int = 25 # Default fallback (matches Shazam default) |
| 47 | + max_concurrent_requests: int = 1 # Default fallback (matches Shazam default) |
| 48 | + tokens: int = field(init=False) |
43 | 49 | last_update: float = field(default_factory=time.time) |
44 | | - semaphore: Semaphore = field(default_factory=lambda: Semaphore(10)) |
| 50 | + semaphore: asyncio.Semaphore = field(init=False) |
45 | 51 | lock: Lock = field(default_factory=Lock) |
46 | 52 | metrics: RateLimitMetrics = field(default_factory=RateLimitMetrics) |
47 | 53 | circuit_state: CircuitState = field(default=CircuitState.CLOSED) |
48 | 54 | circuit_open_time: Optional[float] = None |
49 | 55 | consecutive_failures: int = 0 |
50 | 56 |
|
| 57 | + def __post_init__(self): |
| 58 | + self.tokens = self.max_requests_per_minute |
| 59 | + self.semaphore = asyncio.Semaphore(self.max_concurrent_requests) |
| 60 | + |
51 | 61 |
|
52 | | -@dataclass |
53 | 62 | class RateLimiter: |
54 | | - """Rate limiter implementation.""" |
| 63 | + """Advanced rate limiter with provider management, circuit breaker, metrics.""" |
| 64 | + |
| 65 | + def __init__(self, config=None): |
| 66 | + self._provider_limits: Dict[Any, ProviderLimits] = {} |
| 67 | + self._alert_callbacks: List[Callable[[str], None]] = [] |
| 68 | + # Use provided config or get global config |
| 69 | + self._config = config or get_config() |
| 70 | + |
| 71 | + def register_provider( |
| 72 | + self, |
| 73 | + provider: Any, |
| 74 | + max_requests_per_minute: int = None, |
| 75 | + max_concurrent_requests: int = None, |
| 76 | + ): |
| 77 | + """Register a provider with specific rate limits.""" |
| 78 | + # Use provided values, or fall back to config values based on provider |
| 79 | + if max_requests_per_minute is None or max_concurrent_requests is None: |
| 80 | + provider_str = str(provider).lower() |
| 81 | + |
| 82 | + # Get provider-specific limits from config |
| 83 | + if provider_str == "shazam": |
| 84 | + rpm = max_requests_per_minute or getattr( |
| 85 | + self._config, "shazam_max_rpm", 25 |
| 86 | + ) |
| 87 | + concurrent = max_concurrent_requests or getattr( |
| 88 | + self._config, "shazam_max_concurrent", 1 |
| 89 | + ) |
| 90 | + elif provider_str == "acrcloud": |
| 91 | + rpm = max_requests_per_minute or getattr( |
| 92 | + self._config, "acrcloud_max_rpm", 30 |
| 93 | + ) |
| 94 | + concurrent = max_concurrent_requests or getattr( |
| 95 | + self._config, "acrcloud_max_concurrent", 5 |
| 96 | + ) |
| 97 | + elif provider_str == "spotify": |
| 98 | + rpm = max_requests_per_minute or getattr( |
| 99 | + self._config, "spotify_max_rpm", 120 |
| 100 | + ) |
| 101 | + concurrent = max_concurrent_requests or getattr( |
| 102 | + self._config, "spotify_max_concurrent", 20 |
| 103 | + ) |
| 104 | + else: |
| 105 | + # Fall back to global config or defaults |
| 106 | + rpm = max_requests_per_minute or getattr( |
| 107 | + self._config, "max_requests_per_minute", 25 |
| 108 | + ) |
| 109 | + concurrent = max_concurrent_requests or getattr( |
| 110 | + self._config, "max_concurrent_requests", 2 |
| 111 | + ) |
| 112 | + else: |
| 113 | + rpm = max_requests_per_minute |
| 114 | + concurrent = max_concurrent_requests |
| 115 | + |
| 116 | + self._provider_limits[provider] = ProviderLimits( |
| 117 | + max_requests_per_minute=rpm, |
| 118 | + max_concurrent_requests=concurrent, |
| 119 | + ) |
| 120 | + |
| 121 | + def register_alert_callback(self, callback: Callable[[str], None]): |
| 122 | + """Register a callback for rate limiting alerts.""" |
| 123 | + self._alert_callbacks.append(callback) |
| 124 | + |
| 125 | + def _send_alert(self, message: str): |
| 126 | + """Send alert to all registered callbacks.""" |
| 127 | + for callback in self._alert_callbacks: |
| 128 | + callback(message) |
| 129 | + |
| 130 | + async def acquire(self, provider: Any, timeout: float = 30.0) -> bool: |
| 131 | + """Acquire permission to make a request.""" |
| 132 | + if provider not in self._provider_limits: |
| 133 | + self.register_provider(provider) |
| 134 | + |
| 135 | + limits = self._provider_limits[provider] |
| 136 | + |
| 137 | + # Check circuit breaker first (don't count rejected requests) |
| 138 | + circuit_breaker_enabled = getattr(self._config, "circuit_breaker_enabled", True) |
| 139 | + if circuit_breaker_enabled and limits.circuit_state == CircuitState.OPEN: |
| 140 | + circuit_reset_timeout = getattr( |
| 141 | + self._config, "circuit_breaker_reset_timeout", 60.0 |
| 142 | + ) |
| 143 | + if ( |
| 144 | + limits.circuit_open_time |
| 145 | + and time.time() - limits.circuit_open_time > circuit_reset_timeout |
| 146 | + ): |
| 147 | + limits.circuit_state = CircuitState.HALF_OPEN |
| 148 | + else: |
| 149 | + return False |
| 150 | + |
| 151 | + # Try to acquire semaphore for concurrent requests |
| 152 | + try: |
| 153 | + # Always attempt to acquire the semaphore with a timeout |
| 154 | + start_time = time.time() |
| 155 | + await asyncio.wait_for(limits.semaphore.acquire(), timeout=timeout) |
| 156 | + wait_time = time.time() - start_time |
| 157 | + limits.metrics.total_wait_time += wait_time |
| 158 | + # Note: Don't record semaphore waits as rate limit windows |
| 159 | + # This is concurrency control, not rate limiting |
| 160 | + except asyncio.TimeoutError: |
| 161 | + return False |
| 162 | + |
| 163 | + # At this point, we have semaphore access and will process the request |
| 164 | + limits.metrics.total_requests += 1 |
| 165 | + |
| 166 | + # Check if rate limiting is enabled |
| 167 | + rate_limit_enabled = getattr(self._config, "rate_limit_enabled", True) |
| 168 | + if not rate_limit_enabled: |
| 169 | + return True |
| 170 | + |
| 171 | + # Check rate limiting tokens |
| 172 | + token_wait_start = time.time() |
| 173 | + |
| 174 | + while time.time() - token_wait_start < timeout: |
| 175 | + with limits.lock: |
| 176 | + self._refill_tokens(limits) |
| 177 | + if limits.tokens > 0: |
| 178 | + limits.tokens -= 1 |
| 179 | + # Record metrics only if we had to wait for tokens (rate limiting) |
| 180 | + wait_time = time.time() - token_wait_start |
| 181 | + if wait_time >= RATE_LIMIT_DETECTION_THRESHOLD_SECONDS: |
| 182 | + # Count successful requests that were rate-limited |
| 183 | + limits.metrics.rate_limited_requests += 1 |
| 184 | + limits.metrics.last_rate_limit = time.time() |
| 185 | + limits.metrics.rate_limit_windows.append( |
| 186 | + (token_wait_start, time.time()) |
| 187 | + ) |
| 188 | + return True |
| 189 | + |
| 190 | + # Wait a short time before checking again |
| 191 | + await asyncio.sleep(0.01) |
| 192 | + |
| 193 | + # Timeout exceeded - this is a rate limiting failure |
| 194 | + # Record the window and update last_rate_limit but don't increment |
| 195 | + # rate_limited_requests (that's only for successful requests that waited) |
| 196 | + limits.metrics.last_rate_limit = time.time() |
| 197 | + limits.metrics.rate_limit_windows.append((token_wait_start, time.time())) |
| 198 | + |
| 199 | + limits.semaphore.release() |
| 200 | + return False |
| 201 | + |
| 202 | + def release(self, provider: Any): |
| 203 | + """Release a concurrent request slot.""" |
| 204 | + if provider in self._provider_limits: |
| 205 | + limits = self._provider_limits[provider] |
| 206 | + limits.semaphore.release() |
| 207 | + |
| 208 | + def _refill_tokens(self, limits: ProviderLimits): |
| 209 | + """Refill rate limiting tokens.""" |
| 210 | + now = time.time() |
| 211 | + elapsed = now - limits.last_update |
| 212 | + if elapsed >= 1.0: # Refill every second |
| 213 | + tokens_to_add = int(elapsed * (limits.max_requests_per_minute / 60)) |
| 214 | + if tokens_to_add > 0: |
| 215 | + limits.tokens = min( |
| 216 | + limits.max_requests_per_minute, limits.tokens + tokens_to_add |
| 217 | + ) |
| 218 | + limits.last_update = now |
| 219 | + |
| 220 | + def _update_circuit_breaker(self, provider: Any, success: bool): |
| 221 | + """Update circuit breaker state based on request success.""" |
| 222 | + if provider not in self._provider_limits: |
| 223 | + return |
| 224 | + |
| 225 | + limits = self._provider_limits[provider] |
| 226 | + circuit_breaker_enabled = getattr(self._config, "circuit_breaker_enabled", True) |
| 227 | + |
| 228 | + if not circuit_breaker_enabled: |
| 229 | + return |
| 230 | + |
| 231 | + if success: |
| 232 | + limits.consecutive_failures = 0 |
| 233 | + if limits.circuit_state == CircuitState.HALF_OPEN: |
| 234 | + limits.circuit_state = CircuitState.CLOSED |
| 235 | + else: |
| 236 | + limits.consecutive_failures += 1 |
| 237 | + circuit_threshold = getattr(self._config, "circuit_breaker_threshold", 5) |
| 238 | + if ( |
| 239 | + limits.consecutive_failures >= circuit_threshold |
| 240 | + and limits.circuit_state == CircuitState.CLOSED |
| 241 | + ): |
| 242 | + limits.circuit_state = CircuitState.OPEN |
| 243 | + limits.circuit_open_time = time.time() |
| 244 | + limits.metrics.circuit_trips += 1 |
| 245 | + limits.metrics.last_circuit_trip = time.time() |
| 246 | + self._send_alert( |
| 247 | + message=f"Circuit breaker opened for provider {provider} " |
| 248 | + f"after {limits.consecutive_failures} failures" |
| 249 | + ) |
| 250 | + |
| 251 | + def get_metrics(self, provider: Any) -> Dict[str, Any]: |
| 252 | + """Get metrics for a provider.""" |
| 253 | + if provider not in self._provider_limits: |
| 254 | + return {} |
| 255 | + |
| 256 | + limits = self._provider_limits[provider] |
| 257 | + return { |
| 258 | + "total_requests": limits.metrics.total_requests, |
| 259 | + "rate_limited_requests": limits.metrics.rate_limited_requests, |
| 260 | + "total_wait_time": limits.metrics.total_wait_time, |
| 261 | + "last_rate_limit": limits.metrics.last_rate_limit, |
| 262 | + "rate_limit_windows": limits.metrics.rate_limit_windows, |
| 263 | + "circuit_trips": limits.metrics.circuit_trips, |
| 264 | + "last_circuit_trip": limits.metrics.last_circuit_trip, |
| 265 | + "circuit_state": limits.circuit_state.value, |
| 266 | + "current_tokens": limits.tokens, |
| 267 | + } |
| 268 | + |
| 269 | + |
| 270 | +# Legacy support for the simple RateLimiter |
| 271 | +@dataclass |
| 272 | +class SimpleLimiter: |
| 273 | + """Simple rate limiter implementation.""" |
55 | 274 |
|
56 | 275 | max_requests_per_minute: int |
57 | 276 | max_concurrent_requests: int |
@@ -84,20 +303,36 @@ def _refill(self): |
84 | 303 |
|
85 | 304 |
|
86 | 305 | # Singleton instance |
87 | | -_rate_limiter_instance = None |
| 306 | +_global_rate_limiter = None |
88 | 307 |
|
89 | 308 |
|
90 | | -def get_rate_limiter(provider: str, config) -> RateLimiter: |
91 | | - """Get rate limiter for the specified provider.""" |
| 309 | +def get_global_rate_limiter() -> RateLimiter: |
| 310 | + """Get the global rate limiter instance.""" |
| 311 | + global _global_rate_limiter |
| 312 | + if _global_rate_limiter is None: |
| 313 | + _global_rate_limiter = RateLimiter() |
| 314 | + return _global_rate_limiter |
| 315 | + |
| 316 | + |
| 317 | +def get_simple_rate_limiter(provider: str, config=None) -> SimpleLimiter: |
| 318 | + """Get legacy rate limiter for the specified provider.""" |
| 319 | + if config is None: |
| 320 | + config = get_config() |
| 321 | + |
92 | 322 | if provider == "shazam": |
93 | | - return RateLimiter( |
| 323 | + return SimpleLimiter( |
94 | 324 | max_requests_per_minute=config.shazam_max_rpm, |
95 | 325 | max_concurrent_requests=config.shazam_max_concurrent, |
96 | 326 | ) |
97 | 327 | elif provider == "acrcloud": |
98 | | - return RateLimiter( |
| 328 | + return SimpleLimiter( |
99 | 329 | max_requests_per_minute=config.acrcloud_max_rpm, |
100 | 330 | max_concurrent_requests=config.acrcloud_max_concurrent, |
101 | 331 | ) |
| 332 | + elif provider == "spotify": |
| 333 | + return SimpleLimiter( |
| 334 | + max_requests_per_minute=getattr(config, "spotify_max_rpm", 120), |
| 335 | + max_concurrent_requests=getattr(config, "spotify_max_concurrent", 20), |
| 336 | + ) |
102 | 337 | else: |
103 | 338 | raise ValueError(f"Unknown provider: {provider}") |
0 commit comments