Skip to content

Commit 290ffe8

Browse files
authored
feat: implement advanced rate limiter with circuit breaker and metric… (#31)
This pull request introduces significant improvements to the rate limiting system, configuration, and utility imports for the track identification service. The main focus is on making rate limiting more robust and configurable, tightening default limits, and updating legacy support. Below are the most important changes grouped by theme: Rate Limiting System Overhaul: Introduced a new RateLimiter class in tracklistify/utils/rate_limiter.py that supports provider-specific limits, circuit breaker functionality, metrics collection, and asynchronous concurrency control. This replaces the previous simpler implementation and adds support for registering providers, alert callbacks, and global rate limiter access. Added a global singleton accessor get_global_rate_limiter and updated legacy support with get_simple_rate_limiter for backward compatibility. The legacy function get_rate_limiter was removed. Configuration Updates: Added global rate limiting configuration options (max_requests_per_minute, max_concurrent_requests) to the TrackIdentificationConfig class in tracklistify/config/base.py, with new defaults set to 25 requests per minute and 2 concurrent requests. Updated .env.example to reflect stricter default values: minimum confidence threshold increased to 0.8, max requests per minute reduced to 25, and max concurrent requests reduced to 2. [1] [2] Utility Imports and Refactoring: Updated tracklistify/utils/__init__.py to use get_simple_rate_limiter instead of the old get_rate_limiter, ensuring all imports match the new rate limiting implementation. These changes collectively enhance the reliability and configurability of rate limiting, improve error handling with circuit breakers, and ensure the codebase uses the latest utility interfaces.
1 parent fc37f22 commit 290ffe8

4 files changed

Lines changed: 259 additions & 19 deletions

File tree

.env.example

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
TRACKLISTIFY_OUTPUT_DIR=.tracklistify/output # Directory for output files
66
TRACKLISTIFY_CACHE_DIR=.tracklistify/cache # Directory for caching processed audio segments
77
TRACKLISTIFY_TEMP_DIR=.tracklistify/temp # Directory for temporary files
8-
TRACKLISTIFY_LOG_DIR=.tracklistify/log # Directory for log files
8+
TRACKLISTIFY_LOG_DIR=.tracklistify/log # Directory for log files
99
TRACKLISTIFY_VERBOSE=false # Enable verbose logging
1010
TRACKLISTIFY_DEBUG=false # Enable debug mode
1111

1212
# Track Identification Settings
1313
TRACKLISTIFY_SEGMENT_LENGTH=60 # Length of audio segments in seconds (10 to 300)
14-
TRACKLISTIFY_MIN_CONFIDENCE=0.0 # Minimum confidence threshold (0.0 to 1.0)
14+
TRACKLISTIFY_MIN_CONFIDENCE=0.8 # Minimum confidence threshold (0.0 to 1.0)
1515
TRACKLISTIFY_TIME_THRESHOLD=30.0 # Time threshold between tracks in seconds (0.0 to 300.0)
1616
TRACKLISTIFY_MAX_DUPLICATES=2 # Maximum number of duplicate tracks (0 to 10)
1717
TRACKLISTIFY_OVERLAP_DURATION=10 # Overlap duration between segments in seconds (0 to 30)
@@ -50,8 +50,8 @@ TRACKLISTIFY_RETRY_MAX_DELAY=30.0 # Maximum retry delay (1.0 t
5050

5151
# Rate Limiting Settings
5252
TRACKLISTIFY_RATE_LIMIT_ENABLED=true # Enable rate limiting
53-
TRACKLISTIFY_MAX_REQUESTS_PER_MINUTE=60 # Global maximum requests per minute (1 to 1000)
54-
TRACKLISTIFY_MAX_CONCURRENT_REQUESTS=10 # Global maximum concurrent requests (1 to 100)
53+
TRACKLISTIFY_MAX_REQUESTS_PER_MINUTE=25 # Global maximum requests per minute (1 to 1000)
54+
TRACKLISTIFY_MAX_CONCURRENT_REQUESTS=2 # Global maximum concurrent requests (1 to 100)
5555

5656
# Circuit Breaker Settings
5757
TRACKLISTIFY_CIRCUIT_BREAKER_ENABLED=true # Enable circuit breaker

tracklistify/config/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ class TrackIdentificationConfig(BaseConfig):
109109
primary_provider: str = field(default="shazam")
110110
fallback_enabled: bool = field(default=True)
111111
fallback_providers: List[str] = field(default_factory=list)
112+
113+
# Global rate limiting settings
114+
max_requests_per_minute: int = field(default=25)
115+
max_concurrent_requests: int = field(default=2)
116+
112117
cache_enabled: bool = field(default=True)
113118
cache_ttl: int = field(default=86400)
114119
cache_max_size: int = field(default=1000000)

tracklistify/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
"""
66

77
from .decorators import memoize
8+
from .identification import IdentificationManager
89
from .logger import get_logger, set_logger
9-
from .rate_limiter import get_rate_limiter
10+
from .rate_limiter import get_simple_rate_limiter
1011
from .validation import validate_input
11-
from .identification import IdentificationManager
1212

1313
__all__ = [
1414
"memoize",
1515
"get_logger",
1616
"set_logger",
17-
"get_rate_limiter",
17+
"get_simple_rate_limiter",
1818
"validate_input",
1919
"IdentificationManager",
2020
]

tracklistify/utils/rate_limiter.py

Lines changed: 247 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
"""
44

55
# Standard library imports
6+
import asyncio
67
import time
78
from dataclasses import dataclass, field
89
from enum import Enum
910
from threading import Lock, Semaphore
10-
from typing import List, Optional, Tuple
11+
from typing import Any, Callable, Dict, List, Optional, Tuple
1112

1213
# Local/package imports
14+
from ..config import get_config
15+
16+
# Rate limiting constants
17+
# 1ms threshold to detect actual rate limit
18+
RATE_LIMIT_DETECTION_THRESHOLD_SECONDS = 0.001
1319

1420

1521
class CircuitState(Enum):
@@ -37,21 +43,234 @@ class RateLimitMetrics:
3743
class ProviderLimits:
3844
"""Rate limits for a specific provider."""
3945

40-
max_requests_per_minute: int = 60
41-
max_concurrent_requests: int = 10
42-
tokens: int = field(default_factory=lambda: 60)
46+
max_requests_per_minute: int = 25 # Default fallback (matches Shazam default)
47+
max_concurrent_requests: int = 1 # Default fallback (matches Shazam default)
48+
tokens: int = field(init=False)
4349
last_update: float = field(default_factory=time.time)
44-
semaphore: Semaphore = field(default_factory=lambda: Semaphore(10))
50+
semaphore: asyncio.Semaphore = field(init=False)
4551
lock: Lock = field(default_factory=Lock)
4652
metrics: RateLimitMetrics = field(default_factory=RateLimitMetrics)
4753
circuit_state: CircuitState = field(default=CircuitState.CLOSED)
4854
circuit_open_time: Optional[float] = None
4955
consecutive_failures: int = 0
5056

57+
def __post_init__(self):
58+
self.tokens = self.max_requests_per_minute
59+
self.semaphore = asyncio.Semaphore(self.max_concurrent_requests)
60+
5161

52-
@dataclass
5362
class RateLimiter:
54-
"""Rate limiter implementation."""
63+
"""Advanced rate limiter with provider management, circuit breaker, metrics."""
64+
65+
def __init__(self, config=None):
66+
self._provider_limits: Dict[Any, ProviderLimits] = {}
67+
self._alert_callbacks: List[Callable[[str], None]] = []
68+
# Use provided config or get global config
69+
self._config = config or get_config()
70+
71+
def register_provider(
72+
self,
73+
provider: Any,
74+
max_requests_per_minute: int = None,
75+
max_concurrent_requests: int = None,
76+
):
77+
"""Register a provider with specific rate limits."""
78+
# Use provided values, or fall back to config values based on provider
79+
if max_requests_per_minute is None or max_concurrent_requests is None:
80+
provider_str = str(provider).lower()
81+
82+
# Get provider-specific limits from config
83+
if provider_str == "shazam":
84+
rpm = max_requests_per_minute or getattr(
85+
self._config, "shazam_max_rpm", 25
86+
)
87+
concurrent = max_concurrent_requests or getattr(
88+
self._config, "shazam_max_concurrent", 1
89+
)
90+
elif provider_str == "acrcloud":
91+
rpm = max_requests_per_minute or getattr(
92+
self._config, "acrcloud_max_rpm", 30
93+
)
94+
concurrent = max_concurrent_requests or getattr(
95+
self._config, "acrcloud_max_concurrent", 5
96+
)
97+
elif provider_str == "spotify":
98+
rpm = max_requests_per_minute or getattr(
99+
self._config, "spotify_max_rpm", 120
100+
)
101+
concurrent = max_concurrent_requests or getattr(
102+
self._config, "spotify_max_concurrent", 20
103+
)
104+
else:
105+
# Fall back to global config or defaults
106+
rpm = max_requests_per_minute or getattr(
107+
self._config, "max_requests_per_minute", 25
108+
)
109+
concurrent = max_concurrent_requests or getattr(
110+
self._config, "max_concurrent_requests", 2
111+
)
112+
else:
113+
rpm = max_requests_per_minute
114+
concurrent = max_concurrent_requests
115+
116+
self._provider_limits[provider] = ProviderLimits(
117+
max_requests_per_minute=rpm,
118+
max_concurrent_requests=concurrent,
119+
)
120+
121+
def register_alert_callback(self, callback: Callable[[str], None]):
122+
"""Register a callback for rate limiting alerts."""
123+
self._alert_callbacks.append(callback)
124+
125+
def _send_alert(self, message: str):
126+
"""Send alert to all registered callbacks."""
127+
for callback in self._alert_callbacks:
128+
callback(message)
129+
130+
async def acquire(self, provider: Any, timeout: float = 30.0) -> bool:
131+
"""Acquire permission to make a request."""
132+
if provider not in self._provider_limits:
133+
self.register_provider(provider)
134+
135+
limits = self._provider_limits[provider]
136+
137+
# Check circuit breaker first (don't count rejected requests)
138+
circuit_breaker_enabled = getattr(self._config, "circuit_breaker_enabled", True)
139+
if circuit_breaker_enabled and limits.circuit_state == CircuitState.OPEN:
140+
circuit_reset_timeout = getattr(
141+
self._config, "circuit_breaker_reset_timeout", 60.0
142+
)
143+
if (
144+
limits.circuit_open_time
145+
and time.time() - limits.circuit_open_time > circuit_reset_timeout
146+
):
147+
limits.circuit_state = CircuitState.HALF_OPEN
148+
else:
149+
return False
150+
151+
# Try to acquire semaphore for concurrent requests
152+
try:
153+
# Always attempt to acquire the semaphore with a timeout
154+
start_time = time.time()
155+
await asyncio.wait_for(limits.semaphore.acquire(), timeout=timeout)
156+
wait_time = time.time() - start_time
157+
limits.metrics.total_wait_time += wait_time
158+
# Note: Don't record semaphore waits as rate limit windows
159+
# This is concurrency control, not rate limiting
160+
except asyncio.TimeoutError:
161+
return False
162+
163+
# At this point, we have semaphore access and will process the request
164+
limits.metrics.total_requests += 1
165+
166+
# Check if rate limiting is enabled
167+
rate_limit_enabled = getattr(self._config, "rate_limit_enabled", True)
168+
if not rate_limit_enabled:
169+
return True
170+
171+
# Check rate limiting tokens
172+
token_wait_start = time.time()
173+
174+
while time.time() - token_wait_start < timeout:
175+
with limits.lock:
176+
self._refill_tokens(limits)
177+
if limits.tokens > 0:
178+
limits.tokens -= 1
179+
# Record metrics only if we had to wait for tokens (rate limiting)
180+
wait_time = time.time() - token_wait_start
181+
if wait_time >= RATE_LIMIT_DETECTION_THRESHOLD_SECONDS:
182+
# Count successful requests that were rate-limited
183+
limits.metrics.rate_limited_requests += 1
184+
limits.metrics.last_rate_limit = time.time()
185+
limits.metrics.rate_limit_windows.append(
186+
(token_wait_start, time.time())
187+
)
188+
return True
189+
190+
# Wait a short time before checking again
191+
await asyncio.sleep(0.01)
192+
193+
# Timeout exceeded - this is a rate limiting failure
194+
# Record the window and update last_rate_limit but don't increment
195+
# rate_limited_requests (that's only for successful requests that waited)
196+
limits.metrics.last_rate_limit = time.time()
197+
limits.metrics.rate_limit_windows.append((token_wait_start, time.time()))
198+
199+
limits.semaphore.release()
200+
return False
201+
202+
def release(self, provider: Any):
203+
"""Release a concurrent request slot."""
204+
if provider in self._provider_limits:
205+
limits = self._provider_limits[provider]
206+
limits.semaphore.release()
207+
208+
def _refill_tokens(self, limits: ProviderLimits):
209+
"""Refill rate limiting tokens."""
210+
now = time.time()
211+
elapsed = now - limits.last_update
212+
if elapsed >= 1.0: # Refill every second
213+
tokens_to_add = int(elapsed * (limits.max_requests_per_minute / 60))
214+
if tokens_to_add > 0:
215+
limits.tokens = min(
216+
limits.max_requests_per_minute, limits.tokens + tokens_to_add
217+
)
218+
limits.last_update = now
219+
220+
def _update_circuit_breaker(self, provider: Any, success: bool):
221+
"""Update circuit breaker state based on request success."""
222+
if provider not in self._provider_limits:
223+
return
224+
225+
limits = self._provider_limits[provider]
226+
circuit_breaker_enabled = getattr(self._config, "circuit_breaker_enabled", True)
227+
228+
if not circuit_breaker_enabled:
229+
return
230+
231+
if success:
232+
limits.consecutive_failures = 0
233+
if limits.circuit_state == CircuitState.HALF_OPEN:
234+
limits.circuit_state = CircuitState.CLOSED
235+
else:
236+
limits.consecutive_failures += 1
237+
circuit_threshold = getattr(self._config, "circuit_breaker_threshold", 5)
238+
if (
239+
limits.consecutive_failures >= circuit_threshold
240+
and limits.circuit_state == CircuitState.CLOSED
241+
):
242+
limits.circuit_state = CircuitState.OPEN
243+
limits.circuit_open_time = time.time()
244+
limits.metrics.circuit_trips += 1
245+
limits.metrics.last_circuit_trip = time.time()
246+
self._send_alert(
247+
message=f"Circuit breaker opened for provider {provider} "
248+
f"after {limits.consecutive_failures} failures"
249+
)
250+
251+
def get_metrics(self, provider: Any) -> Dict[str, Any]:
252+
"""Get metrics for a provider."""
253+
if provider not in self._provider_limits:
254+
return {}
255+
256+
limits = self._provider_limits[provider]
257+
return {
258+
"total_requests": limits.metrics.total_requests,
259+
"rate_limited_requests": limits.metrics.rate_limited_requests,
260+
"total_wait_time": limits.metrics.total_wait_time,
261+
"last_rate_limit": limits.metrics.last_rate_limit,
262+
"rate_limit_windows": limits.metrics.rate_limit_windows,
263+
"circuit_trips": limits.metrics.circuit_trips,
264+
"last_circuit_trip": limits.metrics.last_circuit_trip,
265+
"circuit_state": limits.circuit_state.value,
266+
"current_tokens": limits.tokens,
267+
}
268+
269+
270+
# Legacy support for the simple RateLimiter
271+
@dataclass
272+
class SimpleLimiter:
273+
"""Simple rate limiter implementation."""
55274

56275
max_requests_per_minute: int
57276
max_concurrent_requests: int
@@ -84,20 +303,36 @@ def _refill(self):
84303

85304

86305
# Singleton instance
87-
_rate_limiter_instance = None
306+
_global_rate_limiter = None
88307

89308

90-
def get_rate_limiter(provider: str, config) -> RateLimiter:
91-
"""Get rate limiter for the specified provider."""
309+
def get_global_rate_limiter() -> RateLimiter:
310+
"""Get the global rate limiter instance."""
311+
global _global_rate_limiter
312+
if _global_rate_limiter is None:
313+
_global_rate_limiter = RateLimiter()
314+
return _global_rate_limiter
315+
316+
317+
def get_simple_rate_limiter(provider: str, config=None) -> SimpleLimiter:
318+
"""Get legacy rate limiter for the specified provider."""
319+
if config is None:
320+
config = get_config()
321+
92322
if provider == "shazam":
93-
return RateLimiter(
323+
return SimpleLimiter(
94324
max_requests_per_minute=config.shazam_max_rpm,
95325
max_concurrent_requests=config.shazam_max_concurrent,
96326
)
97327
elif provider == "acrcloud":
98-
return RateLimiter(
328+
return SimpleLimiter(
99329
max_requests_per_minute=config.acrcloud_max_rpm,
100330
max_concurrent_requests=config.acrcloud_max_concurrent,
101331
)
332+
elif provider == "spotify":
333+
return SimpleLimiter(
334+
max_requests_per_minute=getattr(config, "spotify_max_rpm", 120),
335+
max_concurrent_requests=getattr(config, "spotify_max_concurrent", 20),
336+
)
102337
else:
103338
raise ValueError(f"Unknown provider: {provider}")

0 commit comments

Comments
 (0)