Skip to content

Commit 1e1fdf9

Browse files
committed
Comments and test
1 parent 68bdf8b commit 1e1fdf9

File tree

3 files changed

+72
-3
lines changed

3 files changed

+72
-3
lines changed

distributed/scheduler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3283,6 +3283,10 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
32833283
"""Objective function to determine which worker should get the task
32843284
32853285
Minimize expected start time. If a tie then break with data storage.
3286+
3287+
See Also
3288+
--------
3289+
WorkStealing.stealing_objective
32863290
"""
32873291
stack_time = ws.occupancy / ws.nthreads
32883292
start_time = stack_time + self.get_comm_cost(ts, ws)

distributed/stealing.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,11 +529,24 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
529529
return out
530530

531531
def stealing_objective(
532-
self, scheduler: SchedulerState, ts: TaskState, ws: WorkerState
532+
self, ts: TaskState, ws: WorkerState
533533
) -> tuple[float, ...]:
534+
"""Objective function to determine which worker should get the task
535+
536+
Minimize expected start time. If a tie then break with data storage.
537+
538+
Notes
539+
-----
540+
This method is a modified version of Scheduler.worker_objective that accounts
541+
for in-flight requests. It must be kept in sync with that method for work-stealing to work correctly.
542+
543+
See Also
544+
--------
545+
Scheduler.worker_objective
546+
"""
534547
occupancy = self._combined_occupancy(
535548
ws
536-
) / ws.nthreads + scheduler.get_comm_cost(ts, ws)
549+
) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
537550
if ts.actor:
538551
return (len(ws.actors), occupancy, ws.nbytes)
539552
else:
@@ -553,7 +566,7 @@ def _get_thief(
553566
elif not ts.loose_restrictions:
554567
return None
555568
return min(
556-
potential_thieves, key=partial(self.stealing_objective, scheduler, ts)
569+
potential_thieves, key=partial(self.stealing_objective, ts)
557570
)
558571

559572

distributed/tests/test_steal.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from distributed.utils_test import (
3939
NO_AMM,
4040
BlockedGetData,
41+
async_poll_for,
4142
captured_logger,
4243
freeze_batched_send,
4344
gen_cluster,
@@ -1877,3 +1878,54 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
18771878
await c.gather(futs)
18781879
events = s.get_events("stealing")
18791880
assert len(events) == 0
1881+
1882+
1883+
@gen_cluster(
1884+
nthreads=[("", 1)],
1885+
client=True,
1886+
config={"distributed.scheduler.worker-saturation": "inf"},
1887+
)
1888+
async def test_stealing_objective_accounts_for_in_flight(c, s, a):
1889+
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests
1890+
"""
1891+
in_event = Event()
1892+
block_event = Event()
1893+
1894+
def block(i: int, in_event: Event, block_event: Event) -> int:
1895+
in_event.set()
1896+
block_event.wait()
1897+
return i
1898+
1899+
# Stop stealing for deterministic testing
1900+
extension = s.extensions["stealing"]
1901+
await extension.stop()
1902+
1903+
try:
1904+
futs = c.map(block, range(20), in_event=in_event, block_event=block_event)
1905+
await in_event.wait()
1906+
1907+
async with Worker(s.address, nthreads=1) as b:
1908+
try:
1909+
await async_poll_for(lambda: s.idle, timeout=5)
1910+
wsA = s.workers[a.address]
1911+
wsB = s.workers[b.address]
1912+
ts = next(iter(wsA.processing))
1913+
1914+
# No in-flight requests, so both match
1915+
assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA)
1916+
assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB)
1917+
1918+
extension.balance()
1919+
assert extension.in_flight
1920+
# We move tasks from a to b
1921+
assert extension.stealing_objective(ts, wsA) < s.worker_objective(ts, wsA)
1922+
assert extension.stealing_objective(ts, wsB) > s.worker_objective(ts, wsB)
1923+
1924+
await async_poll_for(lambda: not extension.in_flight, timeout=5)
1925+
# No in-flight requests, so both match
1926+
assert extension.stealing_objective(ts, wsA) == s.worker_objective(ts, wsA)
1927+
assert extension.stealing_objective(ts, wsB) == s.worker_objective(ts, wsB)
1928+
finally:
1929+
await block_event.set()
1930+
finally:
1931+
await block_event.set()

0 commit comments

Comments
 (0)