Skip to content

Commit 0c48baa

Browse files
committed
Rewrite Redis RateLimiterBackend with Lua scripts
1 parent 62b2e9b commit 0c48baa

File tree

4 files changed

+99
-53
lines changed

4 files changed

+99
-53
lines changed

dramatiq/rate_limits/backends/redis.py

Lines changed: 18 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@
1717

1818
from __future__ import annotations
1919

20+
from pathlib import Path
21+
2022
import redis
2123

2224
from ..backend import RateLimiterBackend
2325

26+
_SCRIPTS = {
27+
path.stem: path.read_text()
28+
for path in (Path(__file__).parent / "redis").glob("*.lua")
29+
}
30+
2431

2532
class RedisBackend(RateLimiterBackend):
2633
"""A rate limiter backend for Redis_.
@@ -41,68 +48,26 @@ def __init__(self, *, client=None, url=None, **parameters):
4148
parameters["connection_pool"] = redis.ConnectionPool.from_url(url)
4249

4350
self.client = client or redis.Redis(**parameters)
51+
self.scripts = {
52+
name: self.client.register_script(text) for name, text in _SCRIPTS.items()
53+
}
4454

4555
def add(self, key, value, ttl):
4656
return bool(self.client.set(key, value, px=ttl, nx=True))
4757

4858
def incr(self, key, amount, maximum, ttl):
49-
with self.client.pipeline() as pipe:
50-
while True:
51-
try:
52-
pipe.watch(key)
53-
value = int(pipe.get(key) or b"0")
54-
value += amount
55-
if value > maximum:
56-
return False
57-
58-
pipe.multi()
59-
pipe.set(key, value, px=ttl)
60-
pipe.execute()
61-
return True
62-
except redis.WatchError:
63-
continue
59+
incr_up_to = self.scripts["incr_up_to"]
60+
return incr_up_to([key], [amount, maximum, ttl]) == 1
6461

6562
def decr(self, key, amount, minimum, ttl):
66-
with self.client.pipeline() as pipe:
67-
while True:
68-
try:
69-
pipe.watch(key)
70-
value = int(pipe.get(key) or b"0")
71-
value -= amount
72-
if value < minimum:
73-
return False
74-
75-
pipe.multi()
76-
pipe.set(key, value, px=ttl)
77-
pipe.execute()
78-
return True
79-
except redis.WatchError:
80-
continue
63+
decr_down_to = self.scripts["decr_down_to"]
64+
return decr_down_to([key], [amount, minimum, ttl]) == 1
8165

8266
def incr_and_sum(self, key, keys, amount, maximum, ttl):
83-
with self.client.pipeline() as pipe:
84-
while True:
85-
try:
86-
# TODO: Drop non-callable keys in Dramatiq v2.
87-
key_list = keys() if callable(keys) else keys
88-
pipe.watch(key, *key_list)
89-
value = int(pipe.get(key) or b"0")
90-
value += amount
91-
if value > maximum:
92-
return False
93-
94-
# Fetch keys again to account for net/server latency.
95-
values = pipe.mget(keys() if callable(keys) else keys)
96-
total = amount + sum(int(n) for n in values if n)
97-
if total > maximum:
98-
return False
99-
100-
pipe.multi()
101-
pipe.set(key, value, px=ttl)
102-
pipe.execute()
103-
return True
104-
except redis.WatchError:
105-
continue
67+
# TODO: Drop non-callable keys in Dramatiq v2.
68+
keys_list = keys() if callable(keys) else keys
69+
incr_up_to_with_sum_check = self.scripts["incr_up_to_with_sum_check"]
70+
return incr_up_to_with_sum_check([key, *keys_list], [amount, maximum, ttl]) == 1
10671

10772
def wait(self, key, timeout):
10873
assert timeout is None or timeout >= 1000, "wait timeouts must be >= 1000"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- decr_down_to(
2+
-- keys=[key]
3+
-- args=[amount, minimum, ttl]
4+
-- )
5+
--
6+
-- Decrement `key` by `amount`, unless the resulting value is less than `minimum`.
7+
8+
local key = KEYS[1]
9+
local amount = tonumber(ARGV[1])
10+
local minimum = tonumber(ARGV[2])
11+
local ttl = tonumber(ARGV[3])
12+
13+
local value = redis.call('GET', key) or 0
14+
if value - amount < minimum then
15+
return false
16+
end
17+
18+
redis.call('SET', key, value - amount, 'PX', ttl)
19+
return true
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- incr_up_to(
2+
-- keys=[key]
3+
-- args=[amount, maximum, ttl]
4+
-- )
5+
--
6+
-- Increment `key` by `amount`, unless the resulting value is greater than `maximum`.
7+
8+
local key = KEYS[1]
9+
local amount = tonumber(ARGV[1])
10+
local maximum = tonumber(ARGV[2])
11+
local ttl = tonumber(ARGV[3])
12+
13+
local value = redis.call('GET', key) or 0
14+
if value + amount > maximum then
15+
return false
16+
end
17+
18+
redis.call('SET', key, value + amount, 'PX', ttl)
19+
return true
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
-- incr_up_to_with_sum_check(
2+
-- keys=[key, *keys]
3+
-- args=[amount, maximum, ttl]
4+
-- )
5+
--
6+
-- Atomically increment `key` by `amount`, unless:
7+
-- - the incremented value is greater than `maximum`, or
8+
-- - the incremented sum of `keys` is greater than `maximum`.
9+
10+
-- split the key list into the first - `key` and the rest - `keys`
11+
local key = KEYS[1]
12+
local keys = {}
13+
for i, k in ipairs(KEYS) do
14+
if i > 1 then
15+
keys[i - 1] = KEYS[i]
16+
end
17+
end
18+
19+
local amount = tonumber(ARGV[1])
20+
local maximum = tonumber(ARGV[2])
21+
local ttl = tonumber(ARGV[3])
22+
23+
-- check if `key` can be incremented, bail if not
24+
local value = redis.call('GET', key) or 0
25+
if value + amount > maximum then
26+
return false
27+
end
28+
29+
-- check if sum of `keys` can be incremented, bail if not
30+
local values = redis.call('MGET', unpack(keys))
31+
local sum = 0
32+
for _, v in ipairs(values) do
33+
if v then
34+
sum = sum + tonumber(v)
35+
end
36+
end
37+
if sum + amount > maximum then
38+
return false
39+
end
40+
41+
-- increment `key` if we got this far
42+
redis.call('SET', key, value + amount, 'PX', ttl)
43+
return true

0 commit comments

Comments
 (0)