Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions dramatiq/rate_limits/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def decr(self, key, amount, minimum, ttl): # pragma: no cover
"""
raise NotImplementedError

def format_key_variants(self, key, variants): # pragma: no cover
"""Build a list of related key names from a "base" key name and and a list of
distinct "variants" that can act as e.g. name suffixes.
Parameters:
key(str): The base key name.
variants(list[str]): Distinct values to incorporate into the resulting names.
Returns:
list[str]: The list of resulting key names.
"""
return [f"{key}@{variant}" for variant in variants]

def incr_and_sum(self, key, keys, amount, maximum, ttl): # pragma: no cover
"""Atomically increment a key unless the sum of keys is greater
than the given maximum.
Expand Down
78 changes: 25 additions & 53 deletions dramatiq/rate_limits/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@

from __future__ import annotations

from pathlib import Path

import redis

from ..backend import RateLimiterBackend

_SCRIPTS = {
path.stem: path.read_text()
for path in (Path(__file__).parent / "redis").glob("*.lua")
}


class RedisBackend(RateLimiterBackend):
"""A rate limiter backend for Redis_.
Expand All @@ -41,68 +48,33 @@ def __init__(self, *, client=None, url=None, **parameters):
parameters["connection_pool"] = redis.ConnectionPool.from_url(url)

self.client = client or redis.Redis(**parameters)
self.scripts = {
name: self.client.register_script(text) for name, text in _SCRIPTS.items()
}

def add(self, key, value, ttl):
return bool(self.client.set(key, value, px=ttl, nx=True))

def incr(self, key, amount, maximum, ttl):
with self.client.pipeline() as pipe:
while True:
try:
pipe.watch(key)
value = int(pipe.get(key) or b"0")
value += amount
if value > maximum:
return False

pipe.multi()
pipe.set(key, value, px=ttl)
pipe.execute()
return True
except redis.WatchError:
continue
incr_up_to = self.scripts["incr_up_to"]
return incr_up_to([key], [amount, maximum, ttl]) == 1

def decr(self, key, amount, minimum, ttl):
with self.client.pipeline() as pipe:
while True:
try:
pipe.watch(key)
value = int(pipe.get(key) or b"0")
value -= amount
if value < minimum:
return False

pipe.multi()
pipe.set(key, value, px=ttl)
pipe.execute()
return True
except redis.WatchError:
continue
decr_down_to = self.scripts["decr_down_to"]
return decr_down_to([key], [amount, minimum, ttl]) == 1

def format_key_variants(self, key, variants):
# NOTE the extra { } around the key - this is to make use of Hash tags [0]
# to make sure that multi-key commands execute on the keys in same hash slots.
# This helps avoid a ClusterCrossSlotError in case this redis is a Redis Cluster.
# [0]: https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#keys-hash-tags
return [f"{{{key}}}@{variant}" for variant in variants]

def incr_and_sum(self, key, keys, amount, maximum, ttl):
with self.client.pipeline() as pipe:
while True:
try:
# TODO: Drop non-callable keys in Dramatiq v2.
key_list = keys() if callable(keys) else keys
pipe.watch(key, *key_list)
value = int(pipe.get(key) or b"0")
value += amount
if value > maximum:
return False

# Fetch keys again to account for net/server latency.
values = pipe.mget(keys() if callable(keys) else keys)
total = amount + sum(int(n) for n in values if n)
if total > maximum:
return False

pipe.multi()
pipe.set(key, value, px=ttl)
pipe.execute()
return True
except redis.WatchError:
continue
# TODO: Drop non-callable keys in Dramatiq v2.
keys_list = keys() if callable(keys) else keys
incr_up_to_with_sum_check = self.scripts["incr_up_to_with_sum_check"]
return incr_up_to_with_sum_check([key, *keys_list], [amount, maximum, ttl]) == 1

def wait(self, key, timeout):
assert timeout is None or timeout >= 1000, "wait timeouts must be >= 1000"
Expand Down
19 changes: 19 additions & 0 deletions dramatiq/rate_limits/backends/redis/decr_down_to.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- decr_down_to(
-- keys=[key]
-- args=[amount, minimum, ttl]
-- )
--
-- Decrement `key` by `amount`, unless the resulting value is less than `minimum`.

local key = KEYS[1]
local amount = tonumber(ARGV[1])
local minimum = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])

local value = redis.call('GET', key) or 0
if value - amount < minimum then
return false
end

redis.call('SET', key, value - amount, 'PX', ttl)
return true
19 changes: 19 additions & 0 deletions dramatiq/rate_limits/backends/redis/incr_up_to.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- incr_up_to(
-- keys=[key]
-- args=[amount, maximum, ttl]
-- )
--
-- Increment `key` by `amount`, unless the resulting value is greater than `maximum`.

local key = KEYS[1]
local amount = tonumber(ARGV[1])
local maximum = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])

local value = redis.call('GET', key) or 0
if value + amount > maximum then
return false
end

redis.call('SET', key, value + amount, 'PX', ttl)
return true
43 changes: 43 additions & 0 deletions dramatiq/rate_limits/backends/redis/incr_up_to_with_sum_check.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
-- incr_up_to_with_sum_check(
-- keys=[key, *keys]
-- args=[amount, maximum, ttl]
-- )
--
-- Atomically increment `key` by `amount`, unless:
-- - the incremented value is greater than `maximum`, or
-- - the incremented sum of `keys` is greater than `maximum`.

-- split the key list into the first - `key` and the rest - `keys`
local key = KEYS[1]
local keys = {}
for i, k in ipairs(KEYS) do
if i > 1 then
keys[i - 1] = KEYS[i]
end
end

local amount = tonumber(ARGV[1])
local maximum = tonumber(ARGV[2])
local ttl = tonumber(ARGV[3])

-- check if `key` can be incremented, bail if not
local value = redis.call('GET', key) or 0
if value + amount > maximum then
return false
end

-- check if sum of `keys` can be incremented, bail if not
local values = redis.call('MGET', unpack(keys))
local sum = 0
for _, v in ipairs(values) do
if v then
sum = sum + tonumber(v)
end
end
if sum + amount > maximum then
return false
end

-- increment `key` if we got this far
redis.call('SET', key, value + amount, 'PX', ttl)
return true
2 changes: 1 addition & 1 deletion dramatiq/rate_limits/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, backend, key, *, limit=1, window=1):

def _get_keys(self):
timestamp = int(time.time())
return ["%s@%s" % (self.key, timestamp - i) for i in range(self.window)]
return self.backend.format_key_variants(self.key, [str(timestamp - i) for i in range(self.window)])

def _acquire(self):
keys = self._get_keys()
Expand Down