1717
1818from __future__ import annotations
1919
20+ from pathlib import Path
21+
2022import redis
2123
2224from ..backend import RateLimiterBackend
2325
26+ _SCRIPTS = {
27+ path .stem : path .read_text ()
28+ for path in (Path (__file__ ).parent / "redis" ).glob ("*.lua" )
29+ }
30+
2431
2532class RedisBackend (RateLimiterBackend ):
2633 """A rate limiter backend for Redis_.
@@ -42,68 +49,26 @@ def __init__(self, *, client=None, url=None, **parameters):
4249
4350 # TODO: Replace usages of StrictRedis (redis-py 2.x) with Redis in Dramatiq 2.0.
4451 self .client = client or redis .StrictRedis (** parameters )
52+ self .scripts = {
53+ name : self .client .register_script (text ) for name , text in _SCRIPTS .items ()
54+ }
4555
4656 def add (self , key , value , ttl ):
4757 return bool (self .client .set (key , value , px = ttl , nx = True ))
4858
4959 def incr (self , key , amount , maximum , ttl ):
50- with self .client .pipeline () as pipe :
51- while True :
52- try :
53- pipe .watch (key )
54- value = int (pipe .get (key ) or b"0" )
55- value += amount
56- if value > maximum :
57- return False
58-
59- pipe .multi ()
60- pipe .set (key , value , px = ttl )
61- pipe .execute ()
62- return True
63- except redis .WatchError :
64- continue
60+ incr_up_to = self .scripts ["incr_up_to" ]
61+ return incr_up_to ([key ], [amount , maximum , ttl ]) == 1
6562
6663 def decr (self , key , amount , minimum , ttl ):
67- with self .client .pipeline () as pipe :
68- while True :
69- try :
70- pipe .watch (key )
71- value = int (pipe .get (key ) or b"0" )
72- value -= amount
73- if value < minimum :
74- return False
75-
76- pipe .multi ()
77- pipe .set (key , value , px = ttl )
78- pipe .execute ()
79- return True
80- except redis .WatchError :
81- continue
64+ decr_down_to = self .scripts ["decr_down_to" ]
65+ return decr_down_to ([key ], [amount , minimum , ttl ]) == 1
8266
8367 def incr_and_sum (self , key , keys , amount , maximum , ttl ):
84- with self .client .pipeline () as pipe :
85- while True :
86- try :
87- # TODO: Drop non-callable keys in Dramatiq v2.
88- key_list = keys () if callable (keys ) else keys
89- pipe .watch (key , * key_list )
90- value = int (pipe .get (key ) or b"0" )
91- value += amount
92- if value > maximum :
93- return False
94-
95- # Fetch keys again to account for net/server latency.
96- values = pipe .mget (keys () if callable (keys ) else keys )
97- total = amount + sum (int (n ) for n in values if n )
98- if total > maximum :
99- return False
100-
101- pipe .multi ()
102- pipe .set (key , value , px = ttl )
103- pipe .execute ()
104- return True
105- except redis .WatchError :
106- continue
68+ # TODO: Drop non-callable keys in Dramatiq v2.
69+ keys_list = keys () if callable (keys ) else keys
70+ incr_up_to_with_sum_check = self .scripts ["incr_up_to_with_sum_check" ]
71+ return incr_up_to_with_sum_check ([key , * keys_list ], [amount , maximum , ttl ]) == 1
10772
10873 def wait (self , key , timeout ):
10974 assert timeout is None or timeout >= 1000 , "wait timeouts must be >= 1000"
0 commit comments