@@ -17,9 +17,13 @@ class Lock:
1717 multiple clients play nicely together.
1818 """
1919
20- lua_release = None
21- lua_extend = None
22- lua_reacquire = None
20+ # This function is used to flag that the class `lua_*` functions are not set yet.
21+ # Using a function, rather than `None`, prevents type annotation issues.
22+ def __undefined (* _ , ** __ ) -> None : ...
23+
24+ lua_release = __undefined
25+ lua_extend = __undefined
26+ lua_reacquire = __undefined
2327
2428 # KEYS[1] - lock name
2529 # ARGV[1] - token
@@ -147,11 +151,11 @@ def __init__(
147151 def register_scripts (self ) -> None :
148152 cls = self .__class__
149153 client = self .valkey
150- if cls .lua_release is None :
154+ if cls .lua_release is cls . __undefined :
151155 cls .lua_release = client .register_script (cls .LUA_RELEASE_SCRIPT )
152- if cls .lua_extend is None :
156+ if cls .lua_extend is cls . __undefined :
153157 cls .lua_extend = client .register_script (cls .LUA_EXTEND_SCRIPT )
154- if cls .lua_reacquire is None :
158+ if cls .lua_reacquire is cls . __undefined :
155159 cls .lua_reacquire = client .register_script (cls .LUA_REACQUIRE_SCRIPT )
156160
157161 def __enter__ (self ) -> "Lock" :
@@ -175,7 +179,7 @@ def acquire(
175179 sleep : Optional [Number ] = None ,
176180 blocking : Optional [bool ] = None ,
177181 blocking_timeout : Optional [Number ] = None ,
178- token : Optional [ str ] = None ,
182+ token : str | bytes | None = None ,
179183 ):
180184 """
181185 Use Valkey to hold a shared, distributed lock named ``name``.
@@ -195,10 +199,10 @@ def acquire(
195199 if sleep is None :
196200 sleep = self .sleep
197201 if token is None :
198- token = uuid .uuid1 ().hex .encode ()
202+ encoded_token = uuid .uuid1 ().hex .encode ()
199203 else :
200204 encoder = self .valkey .get_encoder ()
201- token = encoder .encode (token )
205+ encoded_token = encoder .encode (token )
202206 if blocking is None :
203207 blocking = self .blocking
204208 if blocking_timeout is None :
@@ -207,8 +211,8 @@ def acquire(
207211 if blocking_timeout is not None :
208212 stop_trying_at = mod_time .monotonic () + blocking_timeout
209213 while True :
210- if self .do_acquire (token ):
211- self .local .token = token
214+ if self .do_acquire (encoded_token ):
215+ self .local .token = encoded_token
212216 return True
213217 if not blocking :
214218 return False
@@ -217,7 +221,7 @@ def acquire(
217221 return False
218222 mod_time .sleep (sleep )
219223
220- def do_acquire (self , token : str ) -> bool :
224+ def do_acquire (self , token : bytes ) -> bool :
221225 if self .timeout :
222226 # convert to milliseconds
223227 timeout = int (self .timeout * 1000 )
@@ -312,6 +316,10 @@ def reacquire(self) -> bool:
312316 return self .do_reacquire ()
313317
314318 def do_reacquire (self ) -> bool :
319+ # `do_reacquire()` will only be called if `self.timeout` is not `None`.
320+ # However, this `assert` is needed so that mypy understands the type.
321+ assert self .timeout is not None
322+
315323 timeout = int (self .timeout * 1000 )
316324 if not bool (
317325 self .lua_reacquire (
0 commit comments