Skip to content

Commit 6f6c36a

Browse files
committed
Only handle long-running message if it is sent by the worker designated to process the task
1 parent 09855ee commit 6f6c36a

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

distributed/scheduler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6064,9 +6064,16 @@ def handle_long_running(
60646064
We stop the task from being stolen in the future, and change task
60656065
duration accounting as if the task has stopped.
60666066
"""
6067+
if worker not in self.workers:
6068+
logger.debug(
6069+
"Received long-running signal from unknown worker %s. Ignoring.", worker
6070+
)
6071+
return
6072+
60676073
if key not in self.tasks:
60686074
logger.debug("Skipping long_running since key %s was already released", key)
60696075
return
6076+
60706077
ts = self.tasks[key]
60716078
steal = self.extensions.get("stealing")
60726079
if steal is not None:
@@ -6077,6 +6084,14 @@ def handle_long_running(
60776084
logger.debug("Received long-running signal from duplicate task. Ignoring.")
60786085
return
60796086

6087+
if ws.address != worker:
6088+
logger.debug(
6089+
"Received stale long-running signal from worker %s for task %s. Ignoring.",
6090+
worker,
6091+
ts,
6092+
)
6093+
return
6094+
60806095
if compute_duration is not None:
60816096
old_duration = ts.prefix.duration_average
60826097
if old_duration < 0:

distributed/tests/test_cancelled_state.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import logging
45

56
import pytest
67

@@ -14,6 +15,7 @@
1415
_LockedCommPool,
1516
assert_story,
1617
async_poll_for,
18+
captured_logger,
1719
freeze_batched_send,
1820
gen_cluster,
1921
inc,
@@ -1292,6 +1294,121 @@ def test_secede_cancelled_or_resumed_workerstate(
12921294
assert ts not in ws.long_running
12931295

12941296

1297+
@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], timeout=2)
1298+
async def test_secede_racing_cancellation_and_scheduling_on_other_worker(c, s, a, b):
1299+
wsA = s.workers[a.address]
1300+
in_long_running = Event()
1301+
block_secede = Event()
1302+
seceded = Event()
1303+
block_long_running = Event()
1304+
handled_long_running = Event()
1305+
1306+
def f(ev1, block_secede, seceded, block_long_running):
1307+
in_long_running.set()
1308+
block_secede.wait()
1309+
distributed.secede()
1310+
seceded.set()
1311+
block_long_running.wait()
1312+
return 123
1313+
1314+
# Instrument long-running handler
1315+
original_handler = s.stream_handlers["long-running"]
1316+
1317+
async def instrumented_handle_long_running(*args, **kwargs):
1318+
try:
1319+
return original_handler(*args, **kwargs)
1320+
finally:
1321+
await handled_long_running.set()
1322+
1323+
s.stream_handlers["long-running"] = instrumented_handle_long_running
1324+
1325+
# Submit task and wait until it executes on a
1326+
x = c.submit(
1327+
f,
1328+
in_long_running,
1329+
block_secede,
1330+
seceded,
1331+
block_long_running,
1332+
key="x",
1333+
workers=[a.address],
1334+
)
1335+
await in_long_running.wait()
1336+
ts = a.state.tasks["x"]
1337+
assert ts.state == "executing"
1338+
assert wsA.processing
1339+
assert not wsA.long_running
1340+
1341+
with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
1342+
with freeze_batched_send(a.batched_stream):
1343+
# Let x secede (and later succeed) without informing the scheduler
1344+
await block_secede.set()
1345+
await wait_for_state("x", "long-running", a)
1346+
assert not a.state.executing
1347+
assert a.state.long_running
1348+
await block_long_running.set()
1349+
1350+
await wait_for_state("x", "memory", a)
1351+
1352+
# Cancel x while the scheduler does not know that it seceded
1353+
x.release()
1354+
await async_poll_for(lambda: not s.tasks, timeout=5)
1355+
assert not wsA.processing
1356+
assert not wsA.long_running
1357+
1358+
# Reset all events
1359+
await in_long_running.clear()
1360+
await block_secede.clear()
1361+
await seceded.clear()
1362+
await block_long_running.clear()
1363+
1364+
# Resubmit task and wait until it executes on b
1365+
x = c.submit(
1366+
f,
1367+
in_long_running,
1368+
block_secede,
1369+
seceded,
1370+
block_long_running,
1371+
key="x",
1372+
workers=[b.address],
1373+
)
1374+
await in_long_running.wait()
1375+
ts = b.state.tasks["x"]
1376+
wsB = s.workers[b.address]
1377+
assert ts.state == "executing"
1378+
assert wsB.processing
1379+
assert not wsB.long_running
1380+
1381+
# Unblock the stream from a to the scheduler and handle the long-running message
1382+
await handled_long_running.wait()
1383+
assert ts.state == "executing"
1384+
1385+
assert wsB.processing
1386+
assert wsB.task_prefix_count
1387+
assert not wsB.long_running
1388+
1389+
assert not wsA.processing
1390+
assert not wsA.task_prefix_count
1391+
assert not wsA.long_running
1392+
1393+
# Clear the handler and let x secede on b
1394+
await handled_long_running.clear()
1395+
1396+
await block_secede.set()
1397+
await wait_for_state("x", "long-running", b)
1398+
1399+
assert not b.state.executing
1400+
assert b.state.long_running
1401+
await handled_long_running.wait()
1402+
1403+
# Assert that the handler did not fail and no state was corrupted
1404+
logs = caplog.getvalue()
1405+
assert not logs
1406+
assert not wsB.task_prefix_count
1407+
1408+
await block_long_running.set()
1409+
assert await x.result() == 123
1410+
1411+
12951412
@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
12961413
async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
12971414
"""Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction

0 commit comments

Comments
 (0)