Skip to content

Commit c62a121

Browse files
chrisguidryclaude
andauthored
Fix memory leak in fakeredis Lua script execution (#300)
## Summary - Restructure Lua callback caching in the fakeredis monkeypatch to avoid creating new `functools.partial` objects on every `eval()` call - Add `collectgarbage()` after each script execution to clean up Lua tables - Fix flaky logging test that was broken by newer fakeredis DEBUG logging ## Problem The `memory://` backend was leaking memory when executing Lua scripts. Every `eval()` created new partial objects for callbacks that were held by the Lua runtime, preventing garbage collection. Related: PrefectHQ/prefect#18605 ## Testing Created a test harness running a perpetual task at ~10 evals/second: - **Before fix**: Memory grew several MB per minute - **After fix**: RSS stays flat at ~200KB above baseline after 1000+ evals The same fix has been submitted upstream: cunla/fakeredis-py#452 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 0062982 commit c62a121

File tree

3 files changed

+89
-25
lines changed

3 files changed

+89
-25
lines changed

src/docket/_redis.py

Lines changed: 78 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -396,17 +396,20 @@ def patched_eval(
396396
sha1 = hashlib.sha1(script).hexdigest().encode()
397397
self._server.script_cache[sha1] = script
398398

399-
# Cache LuaRuntime and set_globals function on the server
399+
# Cache LuaRuntime and all callbacks on the server
400400
if not hasattr(self._server, "_lua_runtime"):
401401
self._server._lua_runtime = LUA_MODULE.LuaRuntime(
402402
encoding=None, unpack_returned_tuples=True
403403
)
404+
lua_runtime = self._server._lua_runtime
404405
modules_import_str = "\n".join(
405406
[f"{module} = require('{module}')" for module in self.load_lua_modules]
406407
)
407-
self._server._lua_set_globals = self._server._lua_runtime.eval(
408+
409+
# Create set_globals for initial setup (sets callbacks once)
410+
set_globals_init = lua_runtime.eval(
408411
f"""
409-
function(keys, argv, redis_call, redis_pcall, redis_log, cjson_encode, cjson_decode, cjson_null)
412+
function(redis_call, redis_pcall, redis_log, cjson_encode, cjson_decode, cjson_null)
410413
redis = {{}}
411414
redis.call = redis_call
412415
redis.pcall = redis_pcall
@@ -423,40 +426,94 @@ def patched_eval(
423426
cjson.decode = cjson_decode
424427
cjson.null = cjson_null
425428
429+
KEYS = {{}}
430+
ARGV = {{}}
431+
{modules_import_str}
432+
end
433+
"""
434+
)
435+
436+
# Create set_keys_argv to update just KEYS/ARGV per call
437+
self._server._lua_set_keys_argv = lua_runtime.eval(
438+
"""
439+
function(keys, argv)
426440
KEYS = keys
427441
ARGV = argv
428-
{modules_import_str}
429442
end
430443
"""
431444
)
432-
# Capture expected globals once after first setup
433-
self._server._lua_set_globals(
434-
self._server._lua_runtime.table_from([]),
435-
self._server._lua_runtime.table_from([]),
445+
446+
# Capture expected globals before setting up callbacks
447+
set_globals_init(
436448
lambda *args: None,
437449
lambda *args: None,
438450
lambda *args: None,
439451
lambda *args: None,
440452
lambda *args: None,
441453
None,
442454
)
443-
self._server._lua_expected_globals = set(
444-
self._server._lua_runtime.globals().keys()
455+
self._server._lua_expected_globals = set(lua_runtime.globals().keys())
456+
expected_globals = self._server._lua_expected_globals
457+
458+
# Container to hold current socket - callbacks will look this up
459+
self._server._lua_current_socket = [None] # Use list for mutability
460+
461+
# Create wrapper callbacks that look up the current socket dynamically
462+
def make_redis_call_wrapper() -> typing.Callable[..., typing.Any]:
463+
server = self._server
464+
lr = lua_runtime
465+
eg = expected_globals
466+
467+
def wrapper(op: bytes, *args: typing.Any) -> typing.Any:
468+
socket = server._lua_current_socket[0]
469+
return socket._lua_redis_call(lr, eg, op, *args)
470+
471+
return wrapper
472+
473+
def make_redis_pcall_wrapper() -> typing.Callable[..., typing.Any]:
474+
server = self._server
475+
lr = lua_runtime
476+
eg = expected_globals
477+
478+
def wrapper(op: bytes, *args: typing.Any) -> typing.Any:
479+
socket = server._lua_current_socket[0]
480+
return socket._lua_redis_pcall(lr, eg, op, *args)
481+
482+
return wrapper
483+
484+
# Cache the callback wrappers and static partials
485+
self._server._lua_redis_call_wrapper = make_redis_call_wrapper()
486+
self._server._lua_redis_pcall_wrapper = make_redis_pcall_wrapper()
487+
self._server._lua_log_partial = functools.partial(
488+
_lua_redis_log, lua_runtime, expected_globals
489+
)
490+
self._server._lua_cjson_encode_partial = functools.partial(
491+
_lua_cjson_encode, lua_runtime, expected_globals
492+
)
493+
self._server._lua_cjson_decode_partial = functools.partial(
494+
_lua_cjson_decode, lua_runtime, expected_globals
495+
)
496+
497+
# Set up all callbacks once
498+
set_globals_init(
499+
self._server._lua_redis_call_wrapper,
500+
self._server._lua_redis_pcall_wrapper,
501+
self._server._lua_log_partial,
502+
self._server._lua_cjson_encode_partial,
503+
self._server._lua_cjson_decode_partial,
504+
_lua_cjson_null,
445505
)
446506

447507
lua_runtime = self._server._lua_runtime
448-
set_globals = self._server._lua_set_globals
449508
expected_globals = self._server._lua_expected_globals
450509

451-
set_globals(
510+
# Update the current socket so callbacks can find it
511+
self._server._lua_current_socket[0] = self
512+
513+
# Only update KEYS and ARGV per call (callbacks are already set)
514+
self._server._lua_set_keys_argv(
452515
lua_runtime.table_from(keys_and_args[:numkeys]),
453516
lua_runtime.table_from(keys_and_args[numkeys:]),
454-
functools.partial(self._lua_redis_call, lua_runtime, expected_globals),
455-
functools.partial(self._lua_redis_pcall, lua_runtime, expected_globals),
456-
functools.partial(_lua_redis_log, lua_runtime, expected_globals),
457-
functools.partial(_lua_cjson_encode, lua_runtime, expected_globals),
458-
functools.partial(_lua_cjson_decode, lua_runtime, expected_globals),
459-
_lua_cjson_null,
460517
)
461518

462519
try:
@@ -479,6 +536,9 @@ def patched_eval(
479536

480537
_check_for_lua_globals(lua_runtime, expected_globals)
481538

539+
# Clean up Lua tables (KEYS/ARGV) created for this script execution
540+
lua_runtime.execute("collectgarbage()")
541+
482542
return self._convert_lua_result(result, nested=False)
483543

484544
ScriptingCommandsMixin.eval = patched_eval

tests/fundamentals/test_logging.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,13 @@ async def the_task(
2828
with caplog.at_level(logging.INFO):
2929
await worker.run_until_finished()
3030

31-
assert "the_task('value-a', b=..., c='value-c', d=...)" in caplog.text
32-
assert "value-b" not in caplog.text
33-
assert "value-d" not in caplog.text
31+
# Filter to only docket logs (exclude fakeredis DEBUG logs which contain raw pickle data)
32+
docket_logs = "\n".join(
33+
r.message for r in caplog.records if r.name.startswith("docket")
34+
)
35+
assert "the_task('value-a', b=..., c='value-c', d=...)" in docket_logs
36+
assert "value-b" not in docket_logs
37+
assert "value-d" not in docket_logs
3438

3539

3640
async def test_tasks_can_opt_into_logging_collection_lengths(

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)