33import asyncio
44import logging
55import uuid
6+ from collections import defaultdict
67
78from dask .utils import parse_timedelta
89
10+ from distributed .semaphore import Semaphore
911from distributed .utils import SyncMethodMixin , TimeoutError , log_errors , wait_for
1012from distributed .worker import get_client
1113
1214logger = logging .getLogger (__name__ )
1315
1416
1517class ConditionExtension :
16- """Scheduler extension for managing distributed Conditions"""
18+ """Scheduler extension for managing Condition variable notifications
19+
20+ This extension only handles wait/notify coordination.
21+ The underlying lock is a Semaphore managed by SemaphoreExtension.
22+ """
1723
1824 def __init__ (self , scheduler ):
1925 self .scheduler = scheduler
20- self ._locks = {} # {name: asyncio.Lock}
21- self ._lock_holders = {} # {name: client_id}
22- self ._waiters = {} # {name: {waiter_id: asyncio.Event}}
26+ # {condition_name: {waiter_id: asyncio.Event}}
27+ self ._waiters = defaultdict (dict )
2328
2429 self .scheduler .handlers .update (
2530 {
2631 "condition_wait" : self .wait ,
2732 "condition_notify" : self .notify ,
28- "condition_acquire" : self .acquire ,
29- "condition_release" : self .release ,
3033 "condition_notify_all" : self .notify_all ,
3134 }
3235 )
3336
34- def _get_lock (self , name ):
35- if name not in self ._locks :
36- self ._locks [name ] = asyncio .Lock ()
37- return self ._locks [name ]
38-
39- @log_errors
40- async def acquire (self , name = None , id = None ):
41- """Acquire the underlying lock"""
42- lock = self ._get_lock (name )
43- await lock .acquire ()
44- self ._lock_holders [name ] = id
45- return True
46-
47- @log_errors
48- async def release (self , name = None , id = None ):
49- """Release the underlying lock"""
50- if self ._lock_holders .get (name ) != id :
51- return False
52-
53- lock = self ._locks [name ]
54- lock .release ()
55- del self ._lock_holders [name ]
56-
57- # Cleanup if no waiters
58- if name not in self ._waiters or not self ._waiters [name ]:
59- del self ._locks [name ]
60-
61- return True
62-
6337 @log_errors
6438 async def wait (self , name = None , id = None , timeout = None ):
65- """Wait on condition"""
66- # Verify lock is held by this client
67- if self ._lock_holders .get (name ) != id :
68- raise RuntimeError ("wait() called without holding the lock" )
69-
70- lock = self ._locks [name ]
39+ """Wait to be notified
7140
41+ Caller must already hold the lock (Semaphore lease).
42+ This only manages the wait/notify Events.
43+ """
7244 # Create event for this waiter
73- if name not in self ._waiters :
74- self ._waiters [name ] = {}
7545 event = asyncio .Event ()
7646 self ._waiters [name ][id ] = event
7747
78- # Release lock
79- lock .release ()
80- del self ._lock_holders [name ]
81-
8248 # Wait on event
8349 future = event .wait ()
8450 if timeout is not None :
@@ -95,18 +61,11 @@ async def wait(self, name=None, id=None, timeout=None):
9561 if not self ._waiters [name ]:
9662 del self ._waiters [name ]
9763
98- # Reacquire lock
99- await lock .acquire ()
100- self ._lock_holders [name ] = id
101-
10264 return result
10365
10466 @log_errors
10567 def notify (self , name = None , n = 1 ):
10668 """Notify n waiters"""
107- if self ._lock_holders .get (name ) is None :
108- raise RuntimeError ("notify() called without holding the lock" )
109-
11069 waiters = self ._waiters .get (name , {})
11170 count = 0
11271 for event in list (waiters .values ())[:n ]:
@@ -117,9 +76,6 @@ def notify(self, name=None, n=1):
11776 @log_errors
11877 def notify_all (self , name = None ):
11978 """Notify all waiters"""
120- if self ._lock_holders .get (name ) is None :
121- raise RuntimeError ("notify_all() called without holding the lock" )
122-
12379 waiters = self ._waiters .get (name , {})
12480 for event in waiters .values ():
12581 event .set ()
@@ -129,8 +85,7 @@ def notify_all(self, name=None):
12985class Condition (SyncMethodMixin ):
13086 """Distributed Condition Variable
13187
132- Mimics asyncio.Condition API. Allows coordination between
133- distributed workers using wait/notify pattern.
88+ Combines a Semaphore (lock) with wait/notify coordination.
13489
13590 Parameters
13691 ----------
@@ -144,19 +99,20 @@ class Condition(SyncMethodMixin):
14499 >>> from distributed import Condition
145100 >>> condition = Condition('my-condition')
146101 >>> async with condition:
147- ... await condition.wait() # Wait for notification
102+ ... await condition.wait()
148103
149104 >>> # In another worker/client
150105 >>> condition = Condition('my-condition')
151106 >>> async with condition:
152- ... condition.notify() # Wake one waiter
107+ ... condition.notify()
153108 """
154109
155110 def __init__ (self , name = None , client = None ):
156- self ._client = client
157111 self .name = name or f"condition-{ uuid .uuid4 ().hex } "
158112 self .id = uuid .uuid4 ().hex
159- self ._locked = False
113+ # Use Semaphore(max_leases=1) as the underlying lock
114+ self ._lock = Semaphore (max_leases = 1 , name = f"{ self .name } -lock" )
115+ self ._client = client
160116
161117 @property
162118 def client (self ):
@@ -169,7 +125,7 @@ def client(self):
169125
170126 @property
171127 def loop (self ):
172- return self .client .loop if self . client else None
128+ return self ._lock .loop
173129
174130 def _verify_running (self ):
175131 if not self .client :
@@ -181,20 +137,12 @@ def _verify_running(self):
181137
182138 async def acquire (self ):
183139 """Acquire underlying lock"""
184- self ._verify_running ()
185- result = await self .client .scheduler .condition_acquire (
186- name = self .name , id = self .id
187- )
188- self ._locked = result
140+ result = await self ._lock .acquire ()
189141 return result
190142
191143 async def release (self ):
192144 """Release underlying lock"""
193- if not self ._locked :
194- raise RuntimeError ("Cannot release un-acquired lock" )
195- self ._verify_running ()
196- await self .client .scheduler .condition_release (name = self .name , id = self .id )
197- self ._locked = False
145+ await self ._lock .release ()
198146
199147 async def wait (self , timeout = None ):
200148 """Wait until notified
@@ -212,45 +160,38 @@ async def wait(self, timeout=None):
212160 bool
213161 True if notified, False if timeout occurred
214162 """
215- if not self ._locked :
163+ if not self ._lock . locked () :
216164 raise RuntimeError ("wait() called without holding the lock" )
217165
218166 self ._verify_running ()
219167 timeout = parse_timedelta (timeout )
220- result = await self .client .scheduler .condition_wait (
221- name = self .name , id = self .id , timeout = timeout
222- )
223- return result
224168
225- def notify ( self , n = 1 ):
226- """Wake up one or more waiters
169+ # Release lock
170+ await self . _lock . release ()
227171
228- Parameters
229- ----------
230- n : int, optional
231- Number of waiters to wake. Default is 1.
172+ # Wait for notification
173+ try :
174+ result = await self .client .scheduler .condition_wait (
175+ name = self .name , id = self .id , timeout = timeout
176+ )
177+ finally :
178+ # Reacquire lock
179+ await self ._lock .acquire ()
232180
233- Returns
234- -------
235- int
236- Number of waiters notified
237- """
238- if not self ._locked :
181+ return result
182+
183+ def notify (self , n = 1 ):
184+ """Wake up one or more waiters"""
185+ if not self ._lock .locked ():
239186 raise RuntimeError ("Cannot notify without holding the lock" )
240187 self ._verify_running ()
241188 return self .client .sync (
242189 self .client .scheduler .condition_notify , name = self .name , n = n
243190 )
244191
245192 def notify_all (self ):
246- """Wake up all waiters
247-
248- Returns
249- -------
250- int
251- Number of waiters notified
252- """
253- if not self ._locked :
193+ """Wake up all waiters"""
194+ if not self ._lock .locked ():
254195 raise RuntimeError ("Cannot notify without holding the lock" )
255196 self ._verify_running ()
256197 return self .client .sync (
@@ -259,7 +200,7 @@ def notify_all(self):
259200
260201 def locked (self ):
261202 """Return True if lock is held"""
262- return self ._locked
203+ return self ._lock . locked ()
263204
264205 async def __aenter__ (self ):
265206 await self .acquire ()
0 commit comments