33import asyncio
44import logging
55import uuid
6- from collections import defaultdict
76
8- from distributed .utils import SyncMethodMixin , log_errors
7+ from dask .utils import parse_timedelta
8+
9+ from distributed .utils import SyncMethodMixin , TimeoutError , log_errors , wait_for
910from distributed .worker import get_client
1011
1112logger = logging .getLogger (__name__ )
@@ -16,10 +17,9 @@ class ConditionExtension:
1617
1718 def __init__ (self , scheduler ):
1819 self .scheduler = scheduler
19- # {condition_name: asyncio.Condition}
20- self ._conditions = {}
21- # {condition_name: set of waiter_ids}
22- self ._waiters = defaultdict (set )
20+ self ._locks = {} # {name: asyncio.Lock}
21+ self ._lock_holders = {} # {name: client_id}
22+ self ._waiters = {} # {name: {waiter_id: asyncio.Event}}
2323
2424 self .scheduler .handlers .update (
2525 {
@@ -30,152 +30,171 @@ def __init__(self, scheduler):
3030 }
3131 )
3232
33- def _get_condition (self , name ):
34- if name not in self ._conditions :
35- self ._conditions [name ] = asyncio .Condition ()
36- return self ._conditions [name ]
33+ def _get_lock (self , name ):
34+ if name not in self ._locks :
35+ self ._locks [name ] = asyncio .Lock ()
36+ return self ._locks [name ]
3737
3838 @log_errors
3939 async def acquire (self , name = None , id = None ):
40- """Acquire the underlying lock"""
41- condition = self . _get_condition ( name )
42- await condition . acquire ()
40+ lock = self . _get_lock ( name )
41+ await lock . acquire ( )
42+ self . _lock_holders [ name ] = id
4343 return True
4444
4545 @log_errors
4646 async def release (self , name = None , id = None ):
47- """Release the underlying lock"""
48- if name not in self ._conditions :
47+ if self ._lock_holders .get (name ) != id :
4948 return False
50- condition = self ._conditions [name ]
51- condition .release ()
49+
50+ lock = self ._locks [name ]
51+ lock .release ()
52+ del self ._lock_holders [name ]
53+
54+ # Cleanup if no waiters
55+ if name not in self ._waiters or not self ._waiters [name ]:
56+ del self ._locks [name ]
57+
5258 return True
5359
5460 @log_errors
5561 async def wait (self , name = None , id = None , timeout = None ):
56- """Wait on condition"""
57- condition = self ._get_condition (name )
58- self ._waiters [name ].add (id )
62+ # Verify lock is held by this client
63+ if self ._lock_holders .get (name ) != id :
64+ raise RuntimeError ("wait() called without holding the lock" )
65+
66+ lock = self ._locks [name ]
67+
68+ # Create event for this waiter
69+ if name not in self ._waiters :
70+ self ._waiters [name ] = {}
71+ event = asyncio .Event ()
72+ self ._waiters [name ][id ] = event
73+
74+ # Release lock
75+ lock .release ()
76+ del self ._lock_holders [name ]
77+
78+ # Wait on event
79+ future = event .wait ()
80+ if timeout is not None :
81+ future = wait_for (future , timeout )
5982
6083 try :
61- if timeout :
62- await asyncio .wait_for (condition .wait (), timeout = timeout )
63- else :
64- await condition .wait ()
65- return True
66- except asyncio .TimeoutError :
67- return False
68- except asyncio .CancelledError :
69- raise
84+ await future
85+ result = True
86+ except TimeoutError :
87+ result = False
7088 finally :
71- self ._waiters [name ].discard (id )
89+ # Cleanup waiter
90+ self ._waiters [name ].pop (id , None )
7291 if not self ._waiters [name ]:
7392 del self ._waiters [name ]
7493
94+ # Reacquire lock
95+ await lock .acquire ()
96+ self ._lock_holders [name ] = id
97+
98+ return result
99+
75100 @log_errors
76101 def notify (self , name = None , n = 1 ):
77- """Notify n waiters"""
78- if name not in self ._conditions :
79- return 0
80- condition = self ._conditions [name ]
81- condition .notify (n = n )
82- return min (n , len (self ._waiters .get (name , [])))
102+ if self ._lock_holders .get (name ) is None :
103+ raise RuntimeError ("notify() called without holding the lock" )
104+
105+ waiters = self ._waiters .get (name , {})
106+ count = 0
107+ for event in list (waiters .values ())[:n ]:
108+ event .set ()
109+ count += 1
110+ return count
83111
84112 @log_errors
85113 def notify_all (self , name = None ):
86- """Notify all waiters"""
87- if name not in self . _conditions :
88- return 0
89- condition = self ._conditions [ name ]
90- count = len ( self . _waiters . get ( name , []))
91- condition . notify_all ()
92- return count
114+ if self . _lock_holders . get ( name ) is None :
115+ raise RuntimeError ( "notify_all() called without holding the lock" )
116+
117+ waiters = self ._waiters . get ( name , {})
118+ for event in waiters . values ():
119+ event . set ()
120+ return len ( waiters )
93121
94122
95123class Condition (SyncMethodMixin ):
96124 """Distributed Condition Variable
97125
98- Mimics asyncio.Condition API. Allows coordination between
99- distributed workers using wait/notify pattern.
126+ Parameters
127+ ----------
128+ name: str, optional
129+ Name of the condition. Same name = shared state.
130+ client: Client, optional
131+ Client for scheduler communication.
100132
101133 Examples
102134 --------
103- >>> from distributed import Condition
104135 >>> condition = Condition('my-condition')
105136 >>> async with condition:
106- ... await condition.wait() # Wait for notification
107-
108- >>> # In another worker/client
109- >>> condition = Condition('my-condition')
110- >>> async with condition:
111- ... condition.notify() # Wake one waiter
137+ ... await condition.wait()
112138 """
113139
114- def __init__ (self , name = None , scheduler_rpc = None , loop = None ):
115- self ._scheduler = scheduler_rpc
116- self ._loop = loop
140+ def __init__ (self , name = None , client = None ):
141+ self ._client = client
117142 self .name = name or f"condition-{ uuid .uuid4 ().hex } "
118143 self .id = uuid .uuid4 ().hex
119144 self ._locked = False
120145
121- def _get_scheduler_rpc ( self ):
122- if self . _scheduler :
123- return self ._scheduler
124- try :
125- client = get_client ()
126- return client . scheduler
127- except ValueError :
128- from distributed . worker import get_worker
146+ @ property
147+ def client ( self ) :
148+ if not self ._client :
149+ try :
150+ self . _client = get_client ()
151+ except ValueError :
152+ pass
153+ return self . _client
129154
130- worker = get_worker ()
131- return worker .scheduler
155+ def _verify_running (self ):
156+ if not self .client :
157+ raise RuntimeError (f"{ type (self )} object not properly initialized." )
132158
133159 async def acquire (self ):
134- """Acquire underlying lock"""
135- scheduler = self ._get_scheduler_rpc ()
136- result = await scheduler .condition_acquire (name = self .name , id = self .id )
160+ self ._verify_running ()
161+ result = await self .client .scheduler .condition_acquire (
162+ name = self .name , id = self .id
163+ )
137164 self ._locked = result
138165 return result
139166
140167 async def release (self ):
141- """Release underlying lock"""
142168 if not self ._locked :
143169 raise RuntimeError ("Cannot release un-acquired lock" )
144- scheduler = self ._get_scheduler_rpc ()
145- await scheduler .condition_release (name = self .name , id = self .id )
170+ self ._verify_running ()
171+ await self . client . scheduler .condition_release (name = self .name , id = self .id )
146172 self ._locked = False
147173
148174 async def wait (self , timeout = None ):
149- """Wait until notified
150-
151- Must be called while lock is held. Releases lock and waits
152- for notify(), then reacquires lock before returning.
153- """
154175 if not self ._locked :
155- raise RuntimeError ("Cannot wait on un-acquired condition " )
176+ raise RuntimeError ("wait() called without holding the lock " )
156177
157- scheduler = self ._get_scheduler_rpc ()
158- result = await scheduler .condition_wait (
178+ self ._verify_running ()
179+ timeout = parse_timedelta (timeout )
180+ result = await self .client .scheduler .condition_wait (
159181 name = self .name , id = self .id , timeout = timeout
160182 )
161183 return result
162184
163185 async def notify (self , n = 1 ):
164- """Wake up one or more waiters"""
165186 if not self ._locked :
166- raise RuntimeError ("Cannot notify on un-acquired condition " )
167- scheduler = self ._get_scheduler_rpc ()
168- return await scheduler .condition_notify (name = self .name , n = n )
187+ raise RuntimeError ("Cannot notify without holding the lock " )
188+ self ._verify_running ()
189+ return await self . client . scheduler .condition_notify (name = self .name , n = n )
169190
170191 async def notify_all (self ):
171- """Wake up all waiters"""
172192 if not self ._locked :
173- raise RuntimeError ("Cannot notify on un-acquired condition " )
174- scheduler = self ._get_scheduler_rpc ()
175- return await scheduler .condition_notify_all (name = self .name )
193+ raise RuntimeError ("Cannot notify without holding the lock " )
194+ self ._verify_running ()
195+ return await self . client . scheduler .condition_notify_all (name = self .name )
176196
177197 def locked (self ):
178- """Return True if lock is held"""
179198 return self ._locked
180199
181200 async def __aenter__ (self ):
@@ -193,3 +212,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
193212
194213 def __repr__ (self ):
195214 return f"<Condition: { self .name } >"
215+
216+ def __reduce__ (self ):
217+ return (Condition , (self .name ,))
0 commit comments