1717
1818from __future__ import annotations
1919
20+ from pathlib import Path
21+
2022import redis
2123
2224from ..backend import RateLimiterBackend
2325
26+ _SCRIPTS = {path .stem : path .read_text () for path in (Path (__file__ ).parent / "redis" ).glob ("*.lua" )}
27+
2428
2529class RedisBackend (RateLimiterBackend ):
2630 """A rate limiter backend for Redis_.
@@ -42,68 +46,24 @@ def __init__(self, *, client=None, url=None, **parameters):
4246
4347 # TODO: Replace usages of StrictRedis (redis-py 2.x) with Redis in Dramatiq 2.0.
4448 self .client = client or redis .StrictRedis (** parameters )
49+ self .scripts = {name : self .client .register_script (text ) for name , text in _SCRIPTS .items ()}
4550
4651 def add (self , key , value , ttl ):
4752 return bool (self .client .set (key , value , px = ttl , nx = True ))
4853
4954 def incr (self , key , amount , maximum , ttl ):
50- with self .client .pipeline () as pipe :
51- while True :
52- try :
53- pipe .watch (key )
54- value = int (pipe .get (key ) or b"0" )
55- value += amount
56- if value > maximum :
57- return False
58-
59- pipe .multi ()
60- pipe .set (key , value , px = ttl )
61- pipe .execute ()
62- return True
63- except redis .WatchError :
64- continue
55+ incr_up_to = self .scripts ["incr_up_to" ]
56+ return incr_up_to ([key ], [amount , maximum , ttl ]) == 1
6557
6658 def decr (self , key , amount , minimum , ttl ):
67- with self .client .pipeline () as pipe :
68- while True :
69- try :
70- pipe .watch (key )
71- value = int (pipe .get (key ) or b"0" )
72- value -= amount
73- if value < minimum :
74- return False
75-
76- pipe .multi ()
77- pipe .set (key , value , px = ttl )
78- pipe .execute ()
79- return True
80- except redis .WatchError :
81- continue
59+ decr_down_to = self .scripts ["decr_down_to" ]
60+ return decr_down_to ([key ], [amount , minimum , ttl ]) == 1
8261
8362 def incr_and_sum (self , key , keys , amount , maximum , ttl ):
84- with self .client .pipeline () as pipe :
85- while True :
86- try :
87- # TODO: Drop non-callable keys in Dramatiq v2.
88- key_list = keys () if callable (keys ) else keys
89- pipe .watch (key , * key_list )
90- value = int (pipe .get (key ) or b"0" )
91- value += amount
92- if value > maximum :
93- return False
94-
95- # Fetch keys again to account for net/server latency.
96- values = pipe .mget (keys () if callable (keys ) else keys )
97- total = amount + sum (int (n ) for n in values if n )
98- if total > maximum :
99- return False
100-
101- pipe .multi ()
102- pipe .set (key , value , px = ttl )
103- pipe .execute ()
104- return True
105- except redis .WatchError :
106- continue
63+ # TODO: Drop non-callable keys in Dramatiq v2.
64+ keys_list = keys () if callable (keys ) else keys
65+ incr_up_to_with_sum_check = self .scripts ["incr_up_to_with_sum_check" ]
66+ return incr_up_to_with_sum_check ([key , * keys_list ], [amount , maximum , ttl ]) == 1
10767
10868 def wait (self , key , timeout ):
10969 assert timeout is None or timeout >= 1000 , "wait timeouts must be >= 1000"
0 commit comments