1717
1818from __future__ import annotations
1919
20+ from pathlib import Path
21+
2022import redis
2123
2224from ..backend import RateLimiterBackend
2325
26+ _SCRIPTS = {
27+ path .stem : path .read_text ()
28+ for path in (Path (__file__ ).parent / "redis" ).glob ("*.lua" )
29+ }
30+
2431
2532class RedisBackend (RateLimiterBackend ):
2633 """A rate limiter backend for Redis_.
@@ -41,68 +48,26 @@ def __init__(self, *, client=None, url=None, **parameters):
4148 parameters ["connection_pool" ] = redis .ConnectionPool .from_url (url )
4249
4350 self .client = client or redis .Redis (** parameters )
51+ self .scripts = {
52+ name : self .client .register_script (text ) for name , text in _SCRIPTS .items ()
53+ }
4454
4555 def add (self , key , value , ttl ):
4656 return bool (self .client .set (key , value , px = ttl , nx = True ))
4757
4858 def incr (self , key , amount , maximum , ttl ):
49- with self .client .pipeline () as pipe :
50- while True :
51- try :
52- pipe .watch (key )
53- value = int (pipe .get (key ) or b"0" )
54- value += amount
55- if value > maximum :
56- return False
57-
58- pipe .multi ()
59- pipe .set (key , value , px = ttl )
60- pipe .execute ()
61- return True
62- except redis .WatchError :
63- continue
59+ incr_up_to = self .scripts ["incr_up_to" ]
60+ return incr_up_to ([key ], [amount , maximum , ttl ]) == 1
6461
6562 def decr (self , key , amount , minimum , ttl ):
66- with self .client .pipeline () as pipe :
67- while True :
68- try :
69- pipe .watch (key )
70- value = int (pipe .get (key ) or b"0" )
71- value -= amount
72- if value < minimum :
73- return False
74-
75- pipe .multi ()
76- pipe .set (key , value , px = ttl )
77- pipe .execute ()
78- return True
79- except redis .WatchError :
80- continue
63+ decr_down_to = self .scripts ["decr_down_to" ]
64+ return decr_down_to ([key ], [amount , minimum , ttl ]) == 1
8165
8266 def incr_and_sum (self , key , keys , amount , maximum , ttl ):
83- with self .client .pipeline () as pipe :
84- while True :
85- try :
86- # TODO: Drop non-callable keys in Dramatiq v2.
87- key_list = keys () if callable (keys ) else keys
88- pipe .watch (key , * key_list )
89- value = int (pipe .get (key ) or b"0" )
90- value += amount
91- if value > maximum :
92- return False
93-
94- # Fetch keys again to account for net/server latency.
95- values = pipe .mget (keys () if callable (keys ) else keys )
96- total = amount + sum (int (n ) for n in values if n )
97- if total > maximum :
98- return False
99-
100- pipe .multi ()
101- pipe .set (key , value , px = ttl )
102- pipe .execute ()
103- return True
104- except redis .WatchError :
105- continue
67+ # TODO: Drop non-callable keys in Dramatiq v2.
68+ keys_list = keys () if callable (keys ) else keys
69+ incr_up_to_with_sum_check = self .scripts ["incr_up_to_with_sum_check" ]
70+ return incr_up_to_with_sum_check ([key , * keys_list ], [amount , maximum , ttl ]) == 1
10671
10772 def wait (self , key , timeout ):
10873 assert timeout is None or timeout >= 1000 , "wait timeouts must be >= 1000"
0 commit comments