From 0c48baa042bb198474033934fda788285827666e Mon Sep 17 00:00:00 2001 From: Mikhail Bulash Date: Wed, 6 Aug 2025 16:38:38 +0200 Subject: [PATCH 1/2] Rewrite Redis RateLimiterBackend with Lua scripts --- dramatiq/rate_limits/backends/redis.py | 71 +++++-------------- .../backends/redis/decr_down_to.lua | 19 +++++ .../rate_limits/backends/redis/incr_up_to.lua | 19 +++++ .../redis/incr_up_to_with_sum_check.lua | 43 +++++++++++ 4 files changed, 99 insertions(+), 53 deletions(-) create mode 100644 dramatiq/rate_limits/backends/redis/decr_down_to.lua create mode 100644 dramatiq/rate_limits/backends/redis/incr_up_to.lua create mode 100644 dramatiq/rate_limits/backends/redis/incr_up_to_with_sum_check.lua diff --git a/dramatiq/rate_limits/backends/redis.py b/dramatiq/rate_limits/backends/redis.py index aa64a95b..f2993124 100644 --- a/dramatiq/rate_limits/backends/redis.py +++ b/dramatiq/rate_limits/backends/redis.py @@ -17,10 +17,17 @@ from __future__ import annotations +from pathlib import Path + import redis from ..backend import RateLimiterBackend +_SCRIPTS = { + path.stem: path.read_text() + for path in (Path(__file__).parent / "redis").glob("*.lua") +} + class RedisBackend(RateLimiterBackend): """A rate limiter backend for Redis_. @@ -41,68 +48,26 @@ def __init__(self, *, client=None, url=None, **parameters): parameters["connection_pool"] = redis.ConnectionPool.from_url(url) self.client = client or redis.Redis(**parameters) + self.scripts = { + name: self.client.register_script(text) for name, text in _SCRIPTS.items() + } def add(self, key, value, ttl): return bool(self.client.set(key, value, px=ttl, nx=True)) def incr(self, key, amount, maximum, ttl): - with self.client.pipeline() as pipe: - while True: - try: - pipe.watch(key) - value = int(pipe.get(key) or b"0") - value += amount - if value > maximum: - return False - - pipe.multi() - pipe.set(key, value, px=ttl) - pipe.execute() - return True - except redis.WatchError: - continue + incr_up_to = self.scripts["incr_up_to"] + return incr_up_to([key], [amount, maximum, ttl]) == 1 def decr(self, key, amount, minimum, ttl): - with self.client.pipeline() as pipe: - while True: - try: - pipe.watch(key) - value = int(pipe.get(key) or b"0") - value -= amount - if value < minimum: - return False - - pipe.multi() - pipe.set(key, value, px=ttl) - pipe.execute() - return True - except redis.WatchError: - continue + decr_down_to = self.scripts["decr_down_to"] + return decr_down_to([key], [amount, minimum, ttl]) == 1 def incr_and_sum(self, key, keys, amount, maximum, ttl): - with self.client.pipeline() as pipe: - while True: - try: - # TODO: Drop non-callable keys in Dramatiq v2. - key_list = keys() if callable(keys) else keys - pipe.watch(key, *key_list) - value = int(pipe.get(key) or b"0") - value += amount - if value > maximum: - return False - - # Fetch keys again to account for net/server latency. - values = pipe.mget(keys() if callable(keys) else keys) - total = amount + sum(int(n) for n in values if n) - if total > maximum: - return False - - pipe.multi() - pipe.set(key, value, px=ttl) - pipe.execute() - return True - except redis.WatchError: - continue + # TODO: Drop non-callable keys in Dramatiq v2. + keys_list = keys() if callable(keys) else keys + incr_up_to_with_sum_check = self.scripts["incr_up_to_with_sum_check"] + return incr_up_to_with_sum_check([key, *keys_list], [amount, maximum, ttl]) == 1 def wait(self, key, timeout): assert timeout is None or timeout >= 1000, "wait timeouts must be >= 1000" diff --git a/dramatiq/rate_limits/backends/redis/decr_down_to.lua b/dramatiq/rate_limits/backends/redis/decr_down_to.lua new file mode 100644 index 00000000..768e25b6 --- /dev/null +++ b/dramatiq/rate_limits/backends/redis/decr_down_to.lua @@ -0,0 +1,19 @@ +-- decr_down_to( +-- keys=[key] +-- args=[amount, minimum, ttl] +-- ) +-- +-- Decrement `key` by `amount`, unless the resulting value is less than `minimum`. + +local key = KEYS[1] +local amount = tonumber(ARGV[1]) +local minimum = tonumber(ARGV[2]) +local ttl = tonumber(ARGV[3]) + +local value = redis.call('GET', key) or 0 +if value - amount < minimum then + return false +end + +redis.call('SET', key, value - amount, 'PX', ttl) +return true diff --git a/dramatiq/rate_limits/backends/redis/incr_up_to.lua b/dramatiq/rate_limits/backends/redis/incr_up_to.lua new file mode 100644 index 00000000..3984a316 --- /dev/null +++ b/dramatiq/rate_limits/backends/redis/incr_up_to.lua @@ -0,0 +1,19 @@ +-- incr_up_to( +-- keys=[key] +-- args=[amount, maximum, ttl] +-- ) +-- +-- Increment `key` by `amount`, unless the resulting value is greater than `maximum`. + +local key = KEYS[1] +local amount = tonumber(ARGV[1]) +local maximum = tonumber(ARGV[2]) +local ttl = tonumber(ARGV[3]) + +local value = redis.call('GET', key) or 0 +if value + amount > maximum then + return false +end + +redis.call('SET', key, value + amount, 'PX', ttl) +return true diff --git a/dramatiq/rate_limits/backends/redis/incr_up_to_with_sum_check.lua b/dramatiq/rate_limits/backends/redis/incr_up_to_with_sum_check.lua new file mode 100644 index 00000000..e4b0c8d6 --- /dev/null +++ b/dramatiq/rate_limits/backends/redis/incr_up_to_with_sum_check.lua @@ -0,0 +1,43 @@ +-- incr_up_to_with_sum_check( +-- keys=[key, *keys] +-- args=[amount, maximum, ttl] +-- ) +-- +-- Atomically increment `key` by `amount`, unless: +-- - the incremented value is greater than `maximum`, or +-- - the incremented sum of `keys` is greater than `maximum`. + +-- split the key list into the first - `key` and the rest - `keys` +local key = KEYS[1] +local keys = {} +for i, k in ipairs(KEYS) do + if i > 1 then + keys[i - 1] = KEYS[i] + end +end + +local amount = tonumber(ARGV[1]) +local maximum = tonumber(ARGV[2]) +local ttl = tonumber(ARGV[3]) + +-- check if `key` can be incremented, bail if not +local value = redis.call('GET', key) or 0 +if value + amount > maximum then + return false +end + +-- check if sum of `keys` can be incremented, bail if not +local values = redis.call('MGET', unpack(keys)) +local sum = 0 +for _, v in ipairs(values) do + if v then + sum = sum + tonumber(v) + end +end +if sum + amount > maximum then + return false +end + +-- increment `key` if we got this far +redis.call('SET', key, value + amount, 'PX', ttl) +return true From b7513ccdeb42c641cb014e837b40d09791db5037 Mon Sep 17 00:00:00 2001 From: Mikhail Bulash Date: Wed, 20 Aug 2025 15:58:34 +0200 Subject: [PATCH 2/2] delegate time bucket naming to the backend; ensure same-slot hashing --- dramatiq/rate_limits/backend.py | 13 +++++++++++++ dramatiq/rate_limits/backends/redis.py | 7 +++++++ dramatiq/rate_limits/window.py | 2 +- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/dramatiq/rate_limits/backend.py b/dramatiq/rate_limits/backend.py index 63b1fead..1c0fab45 100644 --- a/dramatiq/rate_limits/backend.py +++ b/dramatiq/rate_limits/backend.py @@ -65,6 +65,19 @@ def decr(self, key, amount, minimum, ttl): # pragma: no cover """ raise NotImplementedError + def format_key_variants(self, key, variants): # pragma: no cover + """Build a list of related key names from a "base" key name and and a list of + distinct "variants" that can act as e.g. name suffixes. + + Parameters: + key(str): The base key name. + variants(list[str]): Distinct values to incorporate into the resulting names. + + Returns: + list[str]: The list of resulting key names. + """ + return [f"{key}@{variant}" for variant in variants] + def incr_and_sum(self, key, keys, amount, maximum, ttl): # pragma: no cover """Atomically increment a key unless the sum of keys is greater than the given maximum. diff --git a/dramatiq/rate_limits/backends/redis.py b/dramatiq/rate_limits/backends/redis.py index f2993124..a325e7e9 100644 --- a/dramatiq/rate_limits/backends/redis.py +++ b/dramatiq/rate_limits/backends/redis.py @@ -63,6 +63,13 @@ def decr(self, key, amount, minimum, ttl): decr_down_to = self.scripts["decr_down_to"] return decr_down_to([key], [amount, minimum, ttl]) == 1 + def format_key_variants(self, key, variants): + # NOTE the extra { } around the key - this is to make use of Hash tags [0] + # to make sure that multi-key commands execute on the keys in same hash slots. + # This helps avoid a ClusterCrossSlotError in case this redis is a Redis Cluster. + # [0]: https://redis.io/docs/latest/operate/oss_and_stack/reference/cluster-spec/#keys-hash-tags + return [f"{{{key}}}@{variant}" for variant in variants] + def incr_and_sum(self, key, keys, amount, maximum, ttl): # TODO: Drop non-callable keys in Dramatiq v2. keys_list = keys() if callable(keys) else keys diff --git a/dramatiq/rate_limits/window.py b/dramatiq/rate_limits/window.py index af159d04..3d8df647 100644 --- a/dramatiq/rate_limits/window.py +++ b/dramatiq/rate_limits/window.py @@ -52,7 +52,7 @@ def __init__(self, backend, key, *, limit=1, window=1): def _get_keys(self): timestamp = int(time.time()) - return ["%s@%s" % (self.key, timestamp - i) for i in range(self.window)] + return self.backend.format_key_variants(self.key, [str(timestamp - i) for i in range(self.window)]) def _acquire(self): keys = self._get_keys()