Skip to content

Commit 0de18c9

Browse files
committed
Update condition.py
1 parent: fb29222 · commit: 0de18c9

File tree

1 file changed

+110
-88
lines changed

1 file changed

+110
-88
lines changed

distributed/condition.py

Lines changed: 110 additions & 88 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,10 @@
33
import asyncio
44
import logging
55
import uuid
6-
from collections import defaultdict
76

8-
from distributed.utils import SyncMethodMixin, log_errors
7+
from dask.utils import parse_timedelta
8+
9+
from distributed.utils import SyncMethodMixin, TimeoutError, log_errors, wait_for
910
from distributed.worker import get_client
1011

1112
logger = logging.getLogger(__name__)
@@ -16,10 +17,9 @@ class ConditionExtension:
1617

1718
def __init__(self, scheduler):
1819
self.scheduler = scheduler
19-
# {condition_name: asyncio.Condition}
20-
self._conditions = {}
21-
# {condition_name: set of waiter_ids}
22-
self._waiters = defaultdict(set)
20+
self._locks = {} # {name: asyncio.Lock}
21+
self._lock_holders = {} # {name: client_id}
22+
self._waiters = {} # {name: {waiter_id: asyncio.Event}}
2323

2424
self.scheduler.handlers.update(
2525
{
@@ -30,152 +30,171 @@ def __init__(self, scheduler):
3030
}
3131
)
3232

33-
def _get_condition(self, name):
34-
if name not in self._conditions:
35-
self._conditions[name] = asyncio.Condition()
36-
return self._conditions[name]
33+
def _get_lock(self, name):
34+
if name not in self._locks:
35+
self._locks[name] = asyncio.Lock()
36+
return self._locks[name]
3737

3838
@log_errors
3939
async def acquire(self, name=None, id=None):
40-
"""Acquire the underlying lock"""
41-
condition = self._get_condition(name)
42-
await condition.acquire()
40+
lock = self._get_lock(name)
41+
await lock.acquire()
42+
self._lock_holders[name] = id
4343
return True
4444

4545
@log_errors
4646
async def release(self, name=None, id=None):
47-
"""Release the underlying lock"""
48-
if name not in self._conditions:
47+
if self._lock_holders.get(name) != id:
4948
return False
50-
condition = self._conditions[name]
51-
condition.release()
49+
50+
lock = self._locks[name]
51+
lock.release()
52+
del self._lock_holders[name]
53+
54+
# Cleanup if no waiters
55+
if name not in self._waiters or not self._waiters[name]:
56+
del self._locks[name]
57+
5258
return True
5359

5460
@log_errors
5561
async def wait(self, name=None, id=None, timeout=None):
56-
"""Wait on condition"""
57-
condition = self._get_condition(name)
58-
self._waiters[name].add(id)
62+
# Verify lock is held by this client
63+
if self._lock_holders.get(name) != id:
64+
raise RuntimeError("wait() called without holding the lock")
65+
66+
lock = self._locks[name]
67+
68+
# Create event for this waiter
69+
if name not in self._waiters:
70+
self._waiters[name] = {}
71+
event = asyncio.Event()
72+
self._waiters[name][id] = event
73+
74+
# Release lock
75+
lock.release()
76+
del self._lock_holders[name]
77+
78+
# Wait on event
79+
future = event.wait()
80+
if timeout is not None:
81+
future = wait_for(future, timeout)
5982

6083
try:
61-
if timeout:
62-
await asyncio.wait_for(condition.wait(), timeout=timeout)
63-
else:
64-
await condition.wait()
65-
return True
66-
except asyncio.TimeoutError:
67-
return False
68-
except asyncio.CancelledError:
69-
raise
84+
await future
85+
result = True
86+
except TimeoutError:
87+
result = False
7088
finally:
71-
self._waiters[name].discard(id)
89+
# Cleanup waiter
90+
self._waiters[name].pop(id, None)
7291
if not self._waiters[name]:
7392
del self._waiters[name]
7493

94+
# Reacquire lock
95+
await lock.acquire()
96+
self._lock_holders[name] = id
97+
98+
return result
99+
75100
@log_errors
76101
def notify(self, name=None, n=1):
77-
"""Notify n waiters"""
78-
if name not in self._conditions:
79-
return 0
80-
condition = self._conditions[name]
81-
condition.notify(n=n)
82-
return min(n, len(self._waiters.get(name, [])))
102+
if self._lock_holders.get(name) is None:
103+
raise RuntimeError("notify() called without holding the lock")
104+
105+
waiters = self._waiters.get(name, {})
106+
count = 0
107+
for event in list(waiters.values())[:n]:
108+
event.set()
109+
count += 1
110+
return count
83111

84112
@log_errors
85113
def notify_all(self, name=None):
86-
"""Notify all waiters"""
87-
if name not in self._conditions:
88-
return 0
89-
condition = self._conditions[name]
90-
count = len(self._waiters.get(name, []))
91-
condition.notify_all()
92-
return count
114+
if self._lock_holders.get(name) is None:
115+
raise RuntimeError("notify_all() called without holding the lock")
116+
117+
waiters = self._waiters.get(name, {})
118+
for event in waiters.values():
119+
event.set()
120+
return len(waiters)
93121

94122

95123
class Condition(SyncMethodMixin):
96124
"""Distributed Condition Variable
97125
98-
Mimics asyncio.Condition API. Allows coordination between
99-
distributed workers using wait/notify pattern.
126+
Parameters
127+
----------
128+
name: str, optional
129+
Name of the condition. Same name = shared state.
130+
client: Client, optional
131+
Client for scheduler communication.
100132
101133
Examples
102134
--------
103-
>>> from distributed import Condition
104135
>>> condition = Condition('my-condition')
105136
>>> async with condition:
106-
... await condition.wait() # Wait for notification
107-
108-
>>> # In another worker/client
109-
>>> condition = Condition('my-condition')
110-
>>> async with condition:
111-
... condition.notify() # Wake one waiter
137+
... await condition.wait()
112138
"""
113139

114-
def __init__(self, name=None, scheduler_rpc=None, loop=None):
115-
self._scheduler = scheduler_rpc
116-
self._loop = loop
140+
def __init__(self, name=None, client=None):
141+
self._client = client
117142
self.name = name or f"condition-{uuid.uuid4().hex}"
118143
self.id = uuid.uuid4().hex
119144
self._locked = False
120145

121-
def _get_scheduler_rpc(self):
122-
if self._scheduler:
123-
return self._scheduler
124-
try:
125-
client = get_client()
126-
return client.scheduler
127-
except ValueError:
128-
from distributed.worker import get_worker
146+
@property
147+
def client(self):
148+
if not self._client:
149+
try:
150+
self._client = get_client()
151+
except ValueError:
152+
pass
153+
return self._client
129154

130-
worker = get_worker()
131-
return worker.scheduler
155+
def _verify_running(self):
156+
if not self.client:
157+
raise RuntimeError(f"{type(self)} object not properly initialized.")
132158

133159
async def acquire(self):
134-
"""Acquire underlying lock"""
135-
scheduler = self._get_scheduler_rpc()
136-
result = await scheduler.condition_acquire(name=self.name, id=self.id)
160+
self._verify_running()
161+
result = await self.client.scheduler.condition_acquire(
162+
name=self.name, id=self.id
163+
)
137164
self._locked = result
138165
return result
139166

140167
async def release(self):
141-
"""Release underlying lock"""
142168
if not self._locked:
143169
raise RuntimeError("Cannot release un-acquired lock")
144-
scheduler = self._get_scheduler_rpc()
145-
await scheduler.condition_release(name=self.name, id=self.id)
170+
self._verify_running()
171+
await self.client.scheduler.condition_release(name=self.name, id=self.id)
146172
self._locked = False
147173

148174
async def wait(self, timeout=None):
149-
"""Wait until notified
150-
151-
Must be called while lock is held. Releases lock and waits
152-
for notify(), then reacquires lock before returning.
153-
"""
154175
if not self._locked:
155-
raise RuntimeError("Cannot wait on un-acquired condition")
176+
raise RuntimeError("wait() called without holding the lock")
156177

157-
scheduler = self._get_scheduler_rpc()
158-
result = await scheduler.condition_wait(
178+
self._verify_running()
179+
timeout = parse_timedelta(timeout)
180+
result = await self.client.scheduler.condition_wait(
159181
name=self.name, id=self.id, timeout=timeout
160182
)
161183
return result
162184

163185
async def notify(self, n=1):
164-
"""Wake up one or more waiters"""
165186
if not self._locked:
166-
raise RuntimeError("Cannot notify on un-acquired condition")
167-
scheduler = self._get_scheduler_rpc()
168-
return await scheduler.condition_notify(name=self.name, n=n)
187+
raise RuntimeError("Cannot notify without holding the lock")
188+
self._verify_running()
189+
return await self.client.scheduler.condition_notify(name=self.name, n=n)
169190

170191
async def notify_all(self):
171-
"""Wake up all waiters"""
172192
if not self._locked:
173-
raise RuntimeError("Cannot notify on un-acquired condition")
174-
scheduler = self._get_scheduler_rpc()
175-
return await scheduler.condition_notify_all(name=self.name)
193+
raise RuntimeError("Cannot notify without holding the lock")
194+
self._verify_running()
195+
return await self.client.scheduler.condition_notify_all(name=self.name)
176196

177197
def locked(self):
178-
"""Return True if lock is held"""
179198
return self._locked
180199

181200
async def __aenter__(self):
@@ -193,3 +212,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
193212

194213
def __repr__(self):
195214
return f"<Condition: {self.name}>"
215+
216+
def __reduce__(self):
217+
return (Condition, (self.name,))

0 commit comments

Comments (0)