Skip to content

Commit 58e1267

Browse files
committed
Rewrite Redis RateLimiterBackend with Lua scripts
1 parent 9037687 commit 58e1267

File tree

4 files changed

+94
-53
lines changed

4 files changed

+94
-53
lines changed

dramatiq/rate_limits/backends/redis.py

Lines changed: 13 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717

1818
from __future__ import annotations
1919

20+
from pathlib import Path
21+
2022
import redis
2123

2224
from ..backend import RateLimiterBackend
2325

26+
_SCRIPTS = {path.stem: path.read_text() for path in (Path(__file__).parent / "redis").glob("*.lua")}
27+
2428

2529
class RedisBackend(RateLimiterBackend):
2630
"""A rate limiter backend for Redis_.
@@ -42,68 +46,24 @@ def __init__(self, *, client=None, url=None, **parameters):
4246

4347
# TODO: Replace usages of StrictRedis (redis-py 2.x) with Redis in Dramatiq 2.0.
4448
self.client = client or redis.StrictRedis(**parameters)
49+
self.scripts = {name: self.client.register_script(text) for name, text in _SCRIPTS.items()}
4550

4651
def add(self, key, value, ttl):
4752
return bool(self.client.set(key, value, px=ttl, nx=True))
4853

4954
def incr(self, key, amount, maximum, ttl):
50-
with self.client.pipeline() as pipe:
51-
while True:
52-
try:
53-
pipe.watch(key)
54-
value = int(pipe.get(key) or b"0")
55-
value += amount
56-
if value > maximum:
57-
return False
58-
59-
pipe.multi()
60-
pipe.set(key, value, px=ttl)
61-
pipe.execute()
62-
return True
63-
except redis.WatchError:
64-
continue
55+
incr_up_to = self.scripts["incr_up_to"]
56+
return incr_up_to([key], [amount, maximum, ttl]) == 1
6557

6658
def decr(self, key, amount, minimum, ttl):
67-
with self.client.pipeline() as pipe:
68-
while True:
69-
try:
70-
pipe.watch(key)
71-
value = int(pipe.get(key) or b"0")
72-
value -= amount
73-
if value < minimum:
74-
return False
75-
76-
pipe.multi()
77-
pipe.set(key, value, px=ttl)
78-
pipe.execute()
79-
return True
80-
except redis.WatchError:
81-
continue
59+
decr_down_to = self.scripts["decr_down_to"]
60+
return decr_down_to([key], [amount, minimum, ttl]) == 1
8261

8362
def incr_and_sum(self, key, keys, amount, maximum, ttl):
84-
with self.client.pipeline() as pipe:
85-
while True:
86-
try:
87-
# TODO: Drop non-callable keys in Dramatiq v2.
88-
key_list = keys() if callable(keys) else keys
89-
pipe.watch(key, *key_list)
90-
value = int(pipe.get(key) or b"0")
91-
value += amount
92-
if value > maximum:
93-
return False
94-
95-
# Fetch keys again to account for net/server latency.
96-
values = pipe.mget(keys() if callable(keys) else keys)
97-
total = amount + sum(int(n) for n in values if n)
98-
if total > maximum:
99-
return False
100-
101-
pipe.multi()
102-
pipe.set(key, value, px=ttl)
103-
pipe.execute()
104-
return True
105-
except redis.WatchError:
106-
continue
63+
# TODO: Drop non-callable keys in Dramatiq v2.
64+
keys_list = keys() if callable(keys) else keys
65+
incr_up_to_with_sum_check = self.scripts["incr_up_to_with_sum_check"]
66+
return incr_up_to_with_sum_check([key, *keys_list], [amount, maximum, ttl]) == 1
10767

10868
def wait(self, key, timeout):
10969
assert timeout is None or timeout >= 1000, "wait timeouts must be >= 1000"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- decr_down_to(
2+
-- keys=[key]
3+
-- args=[amount, minimum, ttl]
4+
-- )
5+
--
6+
-- Decrement `key` by `amount`, unless the resulting value is less than `minimum`.
7+
8+
local key = KEYS[1]
9+
local amount = tonumber(ARGV[1])
10+
local minimum = tonumber(ARGV[2])
11+
local ttl = tonumber(ARGV[3])
12+
13+
local value = redis.call('GET', key) or 0
14+
if value - amount < minimum then
15+
return false
16+
end
17+
18+
redis.call('SET', key, value - amount, 'PX', ttl)
19+
return true
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- incr_up_to(
2+
-- keys=[key]
3+
-- args=[amount, maximum, ttl]
4+
-- )
5+
--
6+
-- Increment `key` by `amount`, unless the resulting value is greater than `maximum`.
7+
8+
local key = KEYS[1]
9+
local amount = tonumber(ARGV[1])
10+
local maximum = tonumber(ARGV[2])
11+
local ttl = tonumber(ARGV[3])
12+
13+
local value = redis.call('GET', key) or 0
14+
if value + amount > maximum then
15+
return false
16+
end
17+
18+
redis.call('SET', key, value + amount, 'PX', ttl)
19+
return true
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
-- incr_up_to_with_sum_check(
2+
-- keys=[key, *keys]
3+
-- args=[amount, maximum, ttl]
4+
-- )
5+
--
6+
-- Atomically increment `key` by `amount`, unless:
7+
-- - the incremented value is greater than `maximum`, or
8+
-- - the incremented sum of `keys` is greater than `maximum`.
9+
10+
-- split the key list into the first - `key` and the rest - `keys`
11+
local key = KEYS[1]
12+
local keys = {}
13+
for i, k in ipairs(KEYS) do
14+
if i > 1 then
15+
keys[i - 1] = KEYS[i]
16+
end
17+
end
18+
19+
local amount = tonumber(ARGV[1])
20+
local maximum = tonumber(ARGV[2])
21+
local ttl = tonumber(ARGV[3])
22+
23+
-- check if `key` can be incremented, bail if not
24+
local value = redis.call('GET', key) or 0
25+
if value + amount > maximum then
26+
return false
27+
end
28+
29+
-- check if sum of `keys` can be incremented, bail if not
30+
local values = redis.call('MGET', unpack(keys))
31+
local sum = 0
32+
for _, v in ipairs(values) do
33+
if v then
34+
sum = sum + tonumber(v)
35+
end
36+
end
37+
if sum + amount > maximum then
38+
return false
39+
end
40+
41+
-- increment `key` if we got this far
42+
redis.call('SET', key, value + amount, 'PX', ttl)
43+
return true

0 commit comments

Comments
 (0)