 from __future__ import annotations
 
 import asyncio
+import logging
 
 import pytest
 
     _LockedCommPool,
     assert_story,
     async_poll_for,
+    captured_logger,
     freeze_batched_send,
     gen_cluster,
     inc,
@@ -903,13 +905,15 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
 
     instructions = ws.handle_stimulus(
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
-        ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy(
+            "x", run_id=0, resource_restrictions={"R": 1}, stimulus_id="s2"
+        ),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+            LongRunningMsg(key="x", run_id=0, compute_duration=None, stimulus_id="s2")
         ]
     assert ws.tasks["x"] is ts
     assert ts.state == prev_state
@@ -1087,15 +1091,17 @@ def test_workerstate_resumed_fetch_to_cancelled_to_executing(ws_with_running_tas
 
     instructions = ws.handle_stimulus(
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
-        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
         FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
-        ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
+        ComputeTaskEvent.dummy(
+            "x", run_id=1, resource_restrictions={"R": 1}, stimulus_id="s4"
+        ),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+            LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
         ]
     assert ws.tasks["x"].state == prev_state
 
@@ -1111,16 +1117,16 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
         # x is released for whatever reason (e.g. client cancellation)
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
         # x was computed somewhere else
-        ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s2"),
+        ComputeTaskEvent.dummy("y", run_id=0, who_has={"x": [ws2]}, stimulus_id="s2"),
         # x was lost / no known replicas, therefore y is cancelled
         FreeKeysEvent(keys=["y"], stimulus_id="s3"),
-        ComputeTaskEvent.dummy("x", stimulus_id="s4"),
+        ComputeTaskEvent.dummy("x", run_id=1, stimulus_id="s4"),
     )
     if prev_state == "executing":
         assert not instructions
     else:
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+            LongRunningMsg(key="x", run_id=1, compute_duration=None, stimulus_id="s4")
         ]
     assert len(ws.tasks) == 1
     assert ws.tasks["x"].state == prev_state
@@ -1254,12 +1260,14 @@ def test_secede_cancelled_or_resumed_workerstate(
     """
     ws2 = "127.0.0.1:2"
     ws.handle_stimulus(
-        ComputeTaskEvent.dummy("x", stimulus_id="s1"),
+        ComputeTaskEvent.dummy("x", run_id=0, stimulus_id="s1"),
         FreeKeysEvent(keys=["x"], stimulus_id="s2"),
     )
     if resume_to_fetch:
         ws.handle_stimulus(
-            ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
+            ComputeTaskEvent.dummy(
+                "y", run_id=1, who_has={"x": [ws2]}, stimulus_id="s3"
+            ),
         )
     ts = ws.tasks["x"]
     assert ts.previous == "executing"
@@ -1277,11 +1285,11 @@ def test_secede_cancelled_or_resumed_workerstate(
     if resume_to_executing:
         instructions = ws.handle_stimulus(
             FreeKeysEvent(keys=["y"], stimulus_id="s5"),
-            ComputeTaskEvent.dummy("x", stimulus_id="s6"),
+            ComputeTaskEvent.dummy("x", run_id=2, stimulus_id="s6"),
         )
         # Inform the scheduler of the SecedeEvent that happened in the past
         assert instructions == [
-            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
+            LongRunningMsg(key="x", run_id=2, compute_duration=None, stimulus_id="s6")
         ]
         assert ts.state == "long-running"
         assert ts not in ws.executing
@@ -1292,6 +1300,223 @@ def test_secede_cancelled_or_resumed_workerstate(
         assert ts not in ws.long_running
 
 
+@gen_cluster(client=True, nthreads=[("", 1), ("", 1)], timeout=2)
+async def test_secede_racing_cancellation_and_scheduling_on_other_worker(c, s, a, b):
+    """Regression test that ensures that we handle stale long-running messages correctly.
+
+    This test simulates a race condition where a task secedes on worker a, is then
+    cancelled, and is resubmitted to run on worker b. The long-running message created
+    on a only arrives at the scheduler after the task has started executing on b (but
+    before a secede event arrives from worker b). The scheduler should then ignore the
+    stale secede event from a.
+    """
+    wsA = s.workers[a.address]
+    before_secede = Event()
+    block_secede = Event()
+    block_long_running = Event()
+    handled_long_running = Event()
+
+    def f(before_secede, block_secede, block_long_running):
+        before_secede.set()
+        block_secede.wait()
+        distributed.secede()
+        block_long_running.wait()
+        return 123
+
+    # Instrument long-running handler
+    original_handler = s.stream_handlers["long-running"]
+
+    async def instrumented_handle_long_running(*args, **kwargs):
+        try:
+            return original_handler(*args, **kwargs)
+        finally:
+            await handled_long_running.set()
+
+    s.stream_handlers["long-running"] = instrumented_handle_long_running
+
+    # Submit task and wait until it executes on a
+    x = c.submit(
+        f,
+        before_secede,
+        block_secede,
+        block_long_running,
+        key="x",
+        workers=[a.address],
+    )
+    await before_secede.wait()
+
+    # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
+    with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
+        with freeze_batched_send(a.batched_stream):
+            # Let x secede (and later succeed) without informing the scheduler
+            await block_secede.set()
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+            await block_long_running.set()
+
+            await wait_for_state("x", "memory", a)
+
+            # Cancel x while the scheduler does not know that it seceded
+            x.release()
+            await async_poll_for(lambda: not s.tasks, timeout=5)
+            assert not wsA.processing
+            assert not wsA.long_running
+
+            # Reset all events
+            await before_secede.clear()
+            await block_secede.clear()
+            await block_long_running.clear()
+
+            # Resubmit task and wait until it executes on b
+            x = c.submit(
+                f,
+                before_secede,
+                block_secede,
+                block_long_running,
+                key="x",
+                workers=[b.address],
+            )
+            await before_secede.wait()
+            wsB = s.workers[b.address]
+            assert wsB.processing
+            assert not wsB.long_running
+
+        # Unblock the stream from a to the scheduler and handle the long-running message
+        await handled_long_running.wait()
+        ts = b.state.tasks["x"]
+        assert ts.state == "executing"
+
+        assert wsB.processing
+        assert wsB.task_prefix_count
+        assert not wsB.long_running
+
+        assert not wsA.processing
+        assert not wsA.task_prefix_count
+        assert not wsA.long_running
+
+        # Clear the handler and let x secede on b
+        await handled_long_running.clear()
+
+        await block_secede.set()
+        await wait_for_state("x", "long-running", b)
+
+        assert not b.state.executing
+        assert b.state.long_running
+        await handled_long_running.wait()
+
+        # Assert that the handler did not fail and no state was corrupted
+        logs = caplog.getvalue()
+        assert not logs
+        assert not wsB.task_prefix_count
+
+    await block_long_running.set()
+    assert await x.result() == 123
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
+async def test_secede_racing_resuming_on_same_worker(c, s, a):
+    """Regression test that ensures that we handle stale long-running messages correctly.
+
+    This test simulates a race condition where a task secedes on worker a, is then
+    cancelled, and is resumed on worker a. The first long-running message created on a
+    only arrives at the scheduler after the task has been resumed. The scheduler should
+    then ignore the stale first secede event from a and only handle the second one.
+    """
+    wsA = s.workers[a.address]
+    before_secede = Event()
+    block_secede = Event()
+    block_long_running = Event()
+    handled_long_running = Event()
+    block_long_running_handler = Event()
+
+    def f(before_secede, block_secede, block_long_running):
+        before_secede.set()
+        block_secede.wait()
+        distributed.secede()
+        block_long_running.wait()
+        return 123
+
+    # Instrument long-running handler
+    original_handler = s.stream_handlers["long-running"]
+    block_second_attempt = None
+
+    async def instrumented_handle_long_running(*args, **kwargs):
+        nonlocal block_second_attempt
+
+        if block_second_attempt is None:
+            block_second_attempt = True
+        elif block_second_attempt is True:
+            await block_long_running_handler.wait()
+            block_second_attempt = False
+        try:
+            return original_handler(*args, **kwargs)
+        finally:
+            await block_long_running_handler.clear()
+            await handled_long_running.set()
+
+    s.stream_handlers["long-running"] = instrumented_handle_long_running
+
+    # Submit task and wait until it executes on a
+    x = c.submit(
+        f,
+        before_secede,
+        block_secede,
+        block_long_running,
+        key="x",
+    )
+    await before_secede.wait()
+
+    # FIXME: Relying on logging is rather brittle. We should fail hard if stimulus handling fails.
+    with captured_logger("distributed.scheduler", logging.ERROR) as caplog:
+        with freeze_batched_send(a.batched_stream):
+            # Let x secede (and later succeed) without informing the scheduler
+            await block_secede.set()
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+
+            # Cancel x while the scheduler does not know that it seceded
+            x.release()
+            await async_poll_for(lambda: not s.tasks, timeout=5)
+            assert not wsA.processing
+            assert not wsA.long_running
+
+            # Resubmit task and wait until it is resumed on a
+            x = c.submit(
+                f,
+                before_secede,
+                block_secede,
+                block_long_running,
+                key="x",
+            )
+            await wait_for_state("x", "long-running", a)
+            assert not a.state.executing
+            assert a.state.long_running
+
+            assert wsA.processing
+            assert not wsA.long_running
+
+        # Unblock the stream from a to the scheduler and handle the stale long-running message
+        await handled_long_running.wait()
+
+        assert wsA.processing
+        assert wsA.task_prefix_count
+        assert not wsA.long_running
+
+        # Clear the handler and let the scheduler handle the second long-running message
+        await handled_long_running.clear()
+        await block_long_running_handler.set()
+        await handled_long_running.wait()
+
+        # Assert that the handler did not fail and no state was corrupted
+        logs = caplog.getvalue()
+        assert not logs
+        assert not wsA.task_prefix_count
+        assert wsA.processing
+        assert wsA.long_running
+
+    await block_long_running.set()
+    assert await x.result() == 123
+
+
 @gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
 async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
     """Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction
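
For context on what the added run_id arguments exercise: a long-running (secede) notification can reach the scheduler after the task attempt that produced it has already been cancelled and rescheduled. Tagging each compute assignment and each LongRunningMsg with a run ID lets the receiver discard notifications from superseded attempts. Below is a minimal, self-contained sketch of that pattern; MiniScheduler, TaskRun, and handle_long_running are hypothetical names for illustration only, not the actual distributed scheduler API.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class TaskRun:
    """Hypothetical per-task record: which attempt ("run") is currently scheduled."""

    key: str
    run_id: int
    state: str = "processing"


@dataclass
class MiniScheduler:
    """Toy stand-in for a component that receives long-running messages."""

    tasks: dict[str, TaskRun] = field(default_factory=dict)

    def handle_long_running(self, key: str, run_id: int) -> None:
        ts = self.tasks.get(key)
        # Drop messages for unknown tasks or for a superseded run: a seceded task
        # that was cancelled and resubmitted gets a new run_id, so a late message
        # from the old run no longer matches and is ignored.
        if ts is None or ts.run_id != run_id:
            return
        ts.state = "long-running"


sched = MiniScheduler(tasks={"x": TaskRun(key="x", run_id=1)})
sched.handle_long_running("x", run_id=0)  # stale message from run 0: ignored
assert sched.tasks["x"].state == "processing"
sched.handle_long_running("x", run_id=1)  # current run: applied
assert sched.tasks["x"].state == "long-running"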