Skip to content

Commit db7a582

Browse files
committed
Update condition.py
1 parent b6b31a6 commit db7a582

File tree

1 file changed

+41
-100
lines changed

1 file changed

+41
-100
lines changed

distributed/condition.py

Lines changed: 41 additions & 100 deletions
Original file line number | Diff line number | Diff line change
@@ -3,82 +3,48 @@
33
import asyncio
44
import logging
55
import uuid
6+
from collections import defaultdict
67

78
from dask.utils import parse_timedelta
89

10+
from distributed.semaphore import Semaphore
911
from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for
1012
from distributed.worker import get_client
1113

1214
logger = logging.getLogger(__name__)
1315

1416

1517
class ConditionExtension:
16-
"""Scheduler extension for managing distributed Conditions"""
18+
"""Scheduler extension for managing Condition variable notifications
19+
20+
This extension only handles wait/notify coordination.
21+
The underlying lock is a Semaphore managed by SemaphoreExtension.
22+
"""
1723

1824
def __init__(self, scheduler):
1925
self.scheduler = scheduler
20-
self._locks = {} # {name: asyncio.Lock}
21-
self._lock_holders = {} # {name: client_id}
22-
self._waiters = {} # {name: {waiter_id: asyncio.Event}}
26+
# {condition_name: {waiter_id: asyncio.Event}}
27+
self._waiters = defaultdict(dict)
2328

2429
self.scheduler.handlers.update(
2530
{
2631
"condition_wait": self.wait,
2732
"condition_notify": self.notify,
28-
"condition_acquire": self.acquire,
29-
"condition_release": self.release,
3033
"condition_notify_all": self.notify_all,
3134
}
3235
)
3336

34-
def _get_lock(self, name):
35-
if name not in self._locks:
36-
self._locks[name] = asyncio.Lock()
37-
return self._locks[name]
38-
39-
@log_errors
40-
async def acquire(self, name=None, id=None):
41-
"""Acquire the underlying lock"""
42-
lock = self._get_lock(name)
43-
await lock.acquire()
44-
self._lock_holders[name] = id
45-
return True
46-
47-
@log_errors
48-
async def release(self, name=None, id=None):
49-
"""Release the underlying lock"""
50-
if self._lock_holders.get(name) != id:
51-
return False
52-
53-
lock = self._locks[name]
54-
lock.release()
55-
del self._lock_holders[name]
56-
57-
# Cleanup if no waiters
58-
if name not in self._waiters or not self._waiters[name]:
59-
del self._locks[name]
60-
61-
return True
62-
6337
@log_errors
6438
async def wait(self, name=None, id=None, timeout=None):
65-
"""Wait on condition"""
66-
# Verify lock is held by this client
67-
if self._lock_holders.get(name) != id:
68-
raise RuntimeError("wait() called without holding the lock")
69-
70-
lock = self._locks[name]
39+
"""Wait to be notified
7140
41+
Caller must already hold the lock (Semaphore lease).
42+
This only manages the wait/notify Events.
43+
"""
7244
# Create event for this waiter
73-
if name not in self._waiters:
74-
self._waiters[name] = {}
7545
event = asyncio.Event()
7646
self._waiters[name][id] = event
7747

78-
# Release lock
79-
lock.release()
80-
del self._lock_holders[name]
81-
8248
# Wait on event
8349
future = event.wait()
8450
if timeout is not None:
@@ -95,18 +61,11 @@ async def wait(self, name=None, id=None, timeout=None):
9561
if not self._waiters[name]:
9662
del self._waiters[name]
9763

98-
# Reacquire lock
99-
await lock.acquire()
100-
self._lock_holders[name] = id
101-
10264
return result
10365

10466
@log_errors
10567
def notify(self, name=None, n=1):
10668
"""Notify n waiters"""
107-
if self._lock_holders.get(name) is None:
108-
raise RuntimeError("notify() called without holding the lock")
109-
11069
waiters = self._waiters.get(name, {})
11170
count = 0
11271
for event in list(waiters.values())[:n]:
@@ -117,9 +76,6 @@ def notify(self, name=None, n=1):
11776
@log_errors
11877
def notify_all(self, name=None):
11978
"""Notify all waiters"""
120-
if self._lock_holders.get(name) is None:
121-
raise RuntimeError("notify_all() called without holding the lock")
122-
12379
waiters = self._waiters.get(name, {})
12480
for event in waiters.values():
12581
event.set()
@@ -129,8 +85,7 @@ def notify_all(self, name=None):
12985
class Condition(SyncMethodMixin):
13086
"""Distributed Condition Variable
13187
132-
Mimics asyncio.Condition API. Allows coordination between
133-
distributed workers using wait/notify pattern.
88+
Combines a Semaphore (lock) with wait/notify coordination.
13489
13590
Parameters
13691
----------
@@ -144,19 +99,20 @@ class Condition(SyncMethodMixin):
14499
>>> from distributed import Condition
145100
>>> condition = Condition('my-condition')
146101
>>> async with condition:
147-
... await condition.wait() # Wait for notification
102+
... await condition.wait()
148103
149104
>>> # In another worker/client
150105
>>> condition = Condition('my-condition')
151106
>>> async with condition:
152-
... condition.notify() # Wake one waiter
107+
... condition.notify()
153108
"""
154109

155110
def __init__(self, name=None, client=None):
156-
self._client = client
157111
self.name = name or f"condition-{uuid.uuid4().hex}"
158112
self.id = uuid.uuid4().hex
159-
self._locked = False
113+
# Use Semaphore(max_leases=1) as the underlying lock
114+
self._lock = Semaphore(max_leases=1, name=f"{self.name}-lock")
115+
self._client = client
160116

161117
@property
162118
def client(self):
@@ -169,7 +125,7 @@ def client(self):
169125

170126
@property
171127
def loop(self):
172-
return self.client.loop if self.client else None
128+
return self._lock.loop
173129

174130
def _verify_running(self):
175131
if not self.client:
@@ -181,20 +137,12 @@ def _verify_running(self):
181137

182138
async def acquire(self):
183139
"""Acquire underlying lock"""
184-
self._verify_running()
185-
result = await self.client.scheduler.condition_acquire(
186-
name=self.name, id=self.id
187-
)
188-
self._locked = result
140+
result = await self._lock.acquire()
189141
return result
190142

191143
async def release(self):
192144
"""Release underlying lock"""
193-
if not self._locked:
194-
raise RuntimeError("Cannot release un-acquired lock")
195-
self._verify_running()
196-
await self.client.scheduler.condition_release(name=self.name, id=self.id)
197-
self._locked = False
145+
await self._lock.release()
198146

199147
async def wait(self, timeout=None):
200148
"""Wait until notified
@@ -212,45 +160,38 @@ async def wait(self, timeout=None):
212160
bool
213161
True if notified, False if timeout occurred
214162
"""
215-
if not self._locked:
163+
if not self._lock.locked():
216164
raise RuntimeError("wait() called without holding the lock")
217165

218166
self._verify_running()
219167
timeout = parse_timedelta(timeout)
220-
result = await self.client.scheduler.condition_wait(
221-
name=self.name, id=self.id, timeout=timeout
222-
)
223-
return result
224168

225-
def notify(self, n=1):
226-
"""Wake up one or more waiters
169+
# Release lock
170+
await self._lock.release()
227171

228-
Parameters
229-
----------
230-
n : int, optional
231-
Number of waiters to wake. Default is 1.
172+
# Wait for notification
173+
try:
174+
result = await self.client.scheduler.condition_wait(
175+
name=self.name, id=self.id, timeout=timeout
176+
)
177+
finally:
178+
# Reacquire lock
179+
await self._lock.acquire()
232180

233-
Returns
234-
-------
235-
int
236-
Number of waiters notified
237-
"""
238-
if not self._locked:
181+
return result
182+
183+
def notify(self, n=1):
184+
"""Wake up one or more waiters"""
185+
if not self._lock.locked():
239186
raise RuntimeError("Cannot notify without holding the lock")
240187
self._verify_running()
241188
return self.client.sync(
242189
self.client.scheduler.condition_notify, name=self.name, n=n
243190
)
244191

245192
def notify_all(self):
246-
"""Wake up all waiters
247-
248-
Returns
249-
-------
250-
int
251-
Number of waiters notified
252-
"""
253-
if not self._locked:
193+
"""Wake up all waiters"""
194+
if not self._lock.locked():
254195
raise RuntimeError("Cannot notify without holding the lock")
255196
self._verify_running()
256197
return self.client.sync(
@@ -259,7 +200,7 @@ def notify_all(self):
259200

260201
def locked(self):
261202
"""Return True if lock is held"""
262-
return self._locked
203+
return self._lock.locked()
263204

264205
async def __aenter__(self):
265206
await self.acquire()

0 commit comments

Comments
 (0)