
Commit ba4a9ce

Ensure actors set erred state properly in case of worker failure
1 parent 358402d commit ba4a9ce

File tree

3 files changed (+115, -7 lines):

distributed/actor.py
distributed/scheduler.py
distributed/tests/test_actor.py
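
Summary: when the worker hosting an actor is lost, actor accesses and dependent tasks now fail with `RuntimeError("Worker holding Actor was lost")` instead of `ValueError`, the scheduler gains an explicit `("memory", "erred")` transition for actor tasks, and `remove_worker` recommends actor tasks to `erred` rather than trying to recompute them. A minimal sketch of the user-visible behavior, assuming a running cluster; the scheduler address is hypothetical and `Counter` mirrors the class used in the test suite:

```python
# Sketch of the failure mode this commit addresses.
from distributed import Client

class Counter:
    def __init__(self):
        self.n = 0

    def increment(self):
        self.n += 1
        return self.n

client = Client("tcp://scheduler:8786")  # hypothetical scheduler address
counter = client.submit(Counter, actor=True).result()  # Actor handle

# ... the worker holding the actor dies here ...

try:
    counter.increment()  # attribute access checks the actor's status first
except RuntimeError as err:  # raised ValueError before this commit
    print(err)  # "Worker holding Actor was lost. Status: ..."
```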

distributed/actor.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -141,7 +141,7 @@ def __dir__(self):
 
     def __getattr__(self, key):
         if self._future and self._future.status not in ("finished", "pending"):
-            raise ValueError(
+            raise RuntimeError(
                 "Worker holding Actor was lost. Status: " + self._future.status
             )
         self._try_bind_worker_client()
```
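
The switch from `ValueError` to `RuntimeError` (mirrored in the scheduler and tests below) presumably reflects that a lost worker is an execution-environment failure rather than an invalid argument; callers that previously caught `ValueError` here will need to catch `RuntimeError` instead.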

distributed/scheduler.py

Lines changed: 46 additions & 4 deletions
```diff
@@ -2599,7 +2599,7 @@ def _transition_memory_released(self, key: Key, stimulus_id: str) -> RecsMsgs:
             if ts.who_wants:
                 ts.exception_blame = ts
                 ts.exception = Serialized(
-                    *serialize(ValueError("Worker holding Actor was lost"))
+                    *serialize(RuntimeError("Worker holding Actor was lost"))
                 )
             return {ts.key: "erred"}, {}, {}  # don't try to recreate
 
@@ -2652,7 +2652,7 @@ def _transition_released_erred(self, key: Key, stimulus_id: str) -> RecsMsgs:
 
         if self.validate:
             assert ts.exception_blame
-            assert not ts.who_has
+            assert not ts.who_has or ts.actor
             assert not ts.waiting_on
 
         failing_ts = ts.exception_blame
```
```diff
@@ -2772,8 +2772,8 @@ def _transition_processing_erred(
         self,
         key: Key,
         stimulus_id: str,
-        worker: str,
         *,
+        worker: str | None = None,
         cause: Key | None = None,
         exception: Serialized | None = None,
         traceback: Serialized | None = None,
```
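
Note the signature change: `worker` moves behind the `*` and becomes optional. Presumably this is so that error paths with no single responsible worker, such as the new `memory -> erred` transition below, can reuse the shared erred-handling code without supplying a worker address.
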
```diff
@@ -2988,6 +2988,45 @@ def _remove_key(self, key: Key) -> None:
         ts.exception_blame = ts.exception = ts.traceback = None
         self.task_metadata.pop(key, None)
 
+    def _transition_memory_erred(self, key: Key, stimulus_id: str) -> RecsMsgs:
+        ts = self.tasks[key]
+        if self.validate:
+            assert ts.actor
+        recommendations: Recs = {}
+        client_msgs: Msgs = {}
+        worker_msgs: Msgs = {}
+        # XXX factor this out?
+        worker_msg = {
+            "op": "free-keys",
+            "keys": [key],
+            "stimulus_id": stimulus_id,
+        }
+        for ws in ts.who_has or ():
+            worker_msgs[ws.address] = [worker_msg]
+        self.remove_all_replicas(ts)
+
+        for dts in ts.dependents:
+            if not dts.who_has:
+                dts.exception_blame = ts
+                recommendations[dts.key] = "erred"
+        exception = Serialized(
+            *serialize(RuntimeError("Worker holding Actor was lost"))
+        )
+        report_msg = {
+            "op": "task-erred",
+            "key": key,
+            "exception": exception,
+        }
+        for cs in ts.who_wants or ():
+            client_msgs[cs.client_key] = [report_msg]
+
+        ts.state = "erred"
+        return self._propagate_erred(
+            ts,
+            cause=ts.key,
+            exception=exception,
+        )
+
     def _transition_memory_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs:
         ts = self.tasks[key]
```
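
For orientation: scheduler transition methods return a `RecsMsgs` value, a triple of follow-up state recommendations plus messages addressed to clients and to workers, which the transition engine then applies and sends. A simplified sketch of that shape (illustrative aliases, not the real definitions in `distributed/scheduler.py`):

```python
from typing import Any

# Illustrative, simplified aliases mirroring the RecsMsgs contract used above.
Recs = dict[str, str]                    # task key -> recommended next state
Msgs = dict[str, list[dict[str, Any]]]   # client key / worker address -> messages
RecsMsgs = tuple[Recs, Msgs, Msgs]

def example_transition(key: str) -> RecsMsgs:
    # Recommend a dependent task to err, report to a client, and free the
    # key on a worker; the caller applies all three.
    recommendations: Recs = {f"{key}-dependent": "erred"}
    client_msgs: Msgs = {"client-1": [{"op": "task-erred", "key": key}]}
    worker_msgs: Msgs = {"tcp://10.0.0.1:1234": [{"op": "free-keys", "keys": [key]}]}
    return recommendations, client_msgs, worker_msgs
```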

```diff
@@ -3078,6 +3117,7 @@ def _transition_released_forgotten(self, key: Key, stimulus_id: str) -> RecsMsgs
         ("no-worker", "processing"): _transition_no_worker_processing,
         ("no-worker", "erred"): _transition_no_worker_erred,
         ("released", "forgotten"): _transition_released_forgotten,
+        ("memory", "erred"): _transition_memory_erred,
         ("memory", "forgotten"): _transition_memory_forgotten,
         ("erred", "released"): _transition_erred_released,
         ("memory", "released"): _transition_memory_released,
```
```diff
@@ -5521,7 +5561,9 @@ async def remove_worker(
 
         for ts in list(ws.has_what):
             self.remove_replica(ts, ws)
-            if not ts.who_has:
+            if ts in ws.actors:
+                recommendations[ts.key] = "erred"
+            elif not ts.who_has:
                 if ts.run_spec:
                     recompute_keys.add(ts.key)
                 recommendations[ts.key] = "released"
```

distributed/tests/test_actor.py

Lines changed: 68 additions & 2 deletions
```diff
@@ -13,6 +13,7 @@
     Actor,
     BaseActorFuture,
     Client,
+    Event,
     Future,
     Nanny,
     as_completed,
@@ -23,7 +24,7 @@
 from distributed.metrics import time
 from distributed.utils import LateLoopEvent
 from distributed.utils_test import cluster, double, gen_cluster, inc
-from distributed.worker import get_worker
+from distributed.worker import Worker, get_worker
 
 
 class Counter:
```
```diff
@@ -290,7 +291,7 @@ async def test_failed_worker(c, s, a, b):
 
     await a.close()
 
-    with pytest.raises(ValueError, match="Worker holding Actor was lost"):
+    with pytest.raises(RuntimeError, match="Worker holding Actor was lost"):
         await counter.increment()
```

```diff
@@ -824,3 +825,68 @@ def demo(self):
 
     actor = await c.submit(Actor, actor=True, workers=[a.address])
     assert await actor.demo() == a.address
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_actor_worker_host_leaves_gracefully(c, s, a):
+    # see also test_actor_worker_host_dies
+    async with Worker(s.address, nthreads=1) as b:
+
+        counter = await c.submit(
+            Counter, actor=True, workers=[b.address], allow_other_workers=True
+        )
+
+        enter_ev = Event()
+        wait_ev = Event()
+
+        def foo(counter, enter_ev, wait_ev):
+            enter_ev.set()
+            wait_ev.wait()
+
+        fut = c.submit(
+            foo,
+            counter,
+            enter_ev,
+            wait_ev,
+            workers=[a.address],
+            allow_other_workers=True,
+        )
+
+        await enter_ev.wait()
+    await wait_ev.set()
+    with pytest.raises(RuntimeError, match="Worker holding Actor was lost"):
+        await fut.result()
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_actor_worker_host_dies(c, s, a):
+    # see also test_actor_worker_host_leaves_gracefully
+    async with Worker(s.address, nthreads=1) as b:
+
+        counter = await c.submit(
+            Counter, actor=True, workers=[b.address], allow_other_workers=True
+        )
+
+        enter_ev = Event()
+        wait_ev = Event()
+
+        def foo(counter, enter_ev, wait_ev):
+            enter_ev.set()
+            wait_ev.wait()
+
+        fut = c.submit(
+            foo,
+            counter,
+            enter_ev,
+            wait_ev,
+            workers=[a.address],
+            allow_other_workers=True,
+        )
+
+        await enter_ev.wait()
+        # Simulate the worker going down
+        s.stream_comms[b.address].close()
+        await b.finished()
+        await wait_ev.set()
+        with pytest.raises(RuntimeError, match="Worker holding Actor was lost"):
+            await fut.result()
```
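
Together the two tests cover both failure modes: a clean shutdown (the actor's worker leaves by exiting the `async with Worker(...)` block) and an abrupt loss (the scheduler's stream comm to the worker is severed before the worker finishes). In both cases the `Event` pair pins the dependent task on the surviving worker while the actor's host goes away, and the task is expected to err with the same `RuntimeError`.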
