@@ -396,17 +396,20 @@ def patched_eval(
396396 sha1 = hashlib .sha1 (script ).hexdigest ().encode ()
397397 self ._server .script_cache [sha1 ] = script
398398
399- # Cache LuaRuntime and set_globals function on the server
399+ # Cache LuaRuntime and all callbacks on the server
400400 if not hasattr (self ._server , "_lua_runtime" ):
401401 self ._server ._lua_runtime = LUA_MODULE .LuaRuntime (
402402 encoding = None , unpack_returned_tuples = True
403403 )
404+ lua_runtime = self ._server ._lua_runtime
404405 modules_import_str = "\n " .join (
405406 [f"{ module } = require('{ module } ')" for module in self .load_lua_modules ]
406407 )
407- self ._server ._lua_set_globals = self ._server ._lua_runtime .eval (
408+
409+ # Create set_globals for initial setup (sets callbacks once)
410+ set_globals_init = lua_runtime .eval (
408411 f"""
409- function(keys, argv, redis_call, redis_pcall, redis_log, cjson_encode, cjson_decode, cjson_null)
412+ function(redis_call, redis_pcall, redis_log, cjson_encode, cjson_decode, cjson_null)
410413 redis = {{}}
411414 redis.call = redis_call
412415 redis.pcall = redis_pcall
@@ -423,40 +426,94 @@ def patched_eval(
423426 cjson.decode = cjson_decode
424427 cjson.null = cjson_null
425428
429+ KEYS = {{}}
430+ ARGV = {{}}
431+ { modules_import_str }
432+ end
433+ """
434+ )
435+
436+ # Create set_keys_argv to update just KEYS/ARGV per call
437+ self ._server ._lua_set_keys_argv = lua_runtime .eval (
438+ """
439+ function(keys, argv)
426440 KEYS = keys
427441 ARGV = argv
428- { modules_import_str }
429442 end
430443 """
431444 )
432- # Capture expected globals once after first setup
433- self ._server ._lua_set_globals (
434- self ._server ._lua_runtime .table_from ([]),
435- self ._server ._lua_runtime .table_from ([]),
445+
446+ # Capture expected globals before setting up callbacks
447+ set_globals_init (
436448 lambda * args : None ,
437449 lambda * args : None ,
438450 lambda * args : None ,
439451 lambda * args : None ,
440452 lambda * args : None ,
441453 None ,
442454 )
443- self ._server ._lua_expected_globals = set (
444- self ._server ._lua_runtime .globals ().keys ()
455+ self ._server ._lua_expected_globals = set (lua_runtime .globals ().keys ())
456+ expected_globals = self ._server ._lua_expected_globals
457+
458+ # Container to hold current socket - callbacks will look this up
459+ self ._server ._lua_current_socket = [None ] # Use list for mutability
460+
461+ # Create wrapper callbacks that look up the current socket dynamically
462+ def make_redis_call_wrapper () -> typing .Callable [..., typing .Any ]:
463+ server = self ._server
464+ lr = lua_runtime
465+ eg = expected_globals
466+
467+ def wrapper (op : bytes , * args : typing .Any ) -> typing .Any :
468+ socket = server ._lua_current_socket [0 ]
469+ return socket ._lua_redis_call (lr , eg , op , * args )
470+
471+ return wrapper
472+
473+ def make_redis_pcall_wrapper () -> typing .Callable [..., typing .Any ]:
474+ server = self ._server
475+ lr = lua_runtime
476+ eg = expected_globals
477+
478+ def wrapper (op : bytes , * args : typing .Any ) -> typing .Any :
479+ socket = server ._lua_current_socket [0 ]
480+ return socket ._lua_redis_pcall (lr , eg , op , * args )
481+
482+ return wrapper
483+
484+ # Cache the callback wrappers and static partials
485+ self ._server ._lua_redis_call_wrapper = make_redis_call_wrapper ()
486+ self ._server ._lua_redis_pcall_wrapper = make_redis_pcall_wrapper ()
487+ self ._server ._lua_log_partial = functools .partial (
488+ _lua_redis_log , lua_runtime , expected_globals
489+ )
490+ self ._server ._lua_cjson_encode_partial = functools .partial (
491+ _lua_cjson_encode , lua_runtime , expected_globals
492+ )
493+ self ._server ._lua_cjson_decode_partial = functools .partial (
494+ _lua_cjson_decode , lua_runtime , expected_globals
495+ )
496+
497+ # Set up all callbacks once
498+ set_globals_init (
499+ self ._server ._lua_redis_call_wrapper ,
500+ self ._server ._lua_redis_pcall_wrapper ,
501+ self ._server ._lua_log_partial ,
502+ self ._server ._lua_cjson_encode_partial ,
503+ self ._server ._lua_cjson_decode_partial ,
504+ _lua_cjson_null ,
445505 )
446506
447507 lua_runtime = self ._server ._lua_runtime
448- set_globals = self ._server ._lua_set_globals
449508 expected_globals = self ._server ._lua_expected_globals
450509
451- set_globals (
510+ # Update the current socket so callbacks can find it
511+ self ._server ._lua_current_socket [0 ] = self
512+
513+ # Only update KEYS and ARGV per call (callbacks are already set)
514+ self ._server ._lua_set_keys_argv (
452515 lua_runtime .table_from (keys_and_args [:numkeys ]),
453516 lua_runtime .table_from (keys_and_args [numkeys :]),
454- functools .partial (self ._lua_redis_call , lua_runtime , expected_globals ),
455- functools .partial (self ._lua_redis_pcall , lua_runtime , expected_globals ),
456- functools .partial (_lua_redis_log , lua_runtime , expected_globals ),
457- functools .partial (_lua_cjson_encode , lua_runtime , expected_globals ),
458- functools .partial (_lua_cjson_decode , lua_runtime , expected_globals ),
459- _lua_cjson_null ,
460517 )
461518
462519 try :
@@ -479,6 +536,9 @@ def patched_eval(
479536
480537 _check_for_lua_globals (lua_runtime , expected_globals )
481538
539+ # Clean up Lua tables (KEYS/ARGV) created for this script execution
540+ lua_runtime .execute ("collectgarbage()" )
541+
482542 return self ._convert_lua_result (result , nested = False )
483543
484544 ScriptingCommandsMixin .eval = patched_eval
0 commit comments