Skip to content

Commit d709a10

Browse files
committed
Cache occupancy in balancing
1 parent b86b714 commit d709a10

File tree

2 files changed

+127
-23
lines changed

2 files changed

+127
-23
lines changed

distributed/stealing.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,10 @@ def balance(self) -> None:
426426
log = []
427427
start = time()
428428

429+
# Pre-calculate all occupancies once, they don't change during balancing
430+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
431+
combined_occupancy = partial(self._combined_occupancy, occupancies=occupancies)
432+
429433
i = 0
430434
# Paused and closing workers must never become thieves
431435
potential_thieves = set(s.idle.values())
@@ -434,21 +438,19 @@ def balance(self) -> None:
434438
victim: WorkerState | None
435439
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
436440
if not potential_victims:
437-
potential_victims = topk(
438-
10, s.workers.values(), key=self._combined_occupancy
439-
)
441+
potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
440442
potential_victims = [
441443
ws
442444
for ws in potential_victims
443-
if self._combined_occupancy(ws) > 0.2
445+
if combined_occupancy(ws) > 0.2
444446
and self._combined_nprocessing(ws) > ws.nthreads
445447
and ws not in potential_thieves
446448
]
447449
if not potential_victims:
448450
return
449451
if len(potential_victims) < 20:
450452
potential_victims = sorted(
451-
potential_victims, key=self._combined_occupancy, reverse=True
453+
potential_victims, key=combined_occupancy, reverse=True
452454
)
453455
assert potential_victims
454456
assert potential_thieves
@@ -472,11 +474,15 @@ def balance(self) -> None:
472474
stealable.discard(ts)
473475
continue
474476
i += 1
475-
if not (thief := _get_thief(s, ts, potential_thieves)):
477+
if not (
478+
thief := self._get_thief(
479+
s, ts, potential_thieves, occupancies=occupancies
480+
)
481+
):
476482
continue
477483

478-
occ_thief = self._combined_occupancy(thief)
479-
occ_victim = self._combined_occupancy(victim)
484+
occ_thief = combined_occupancy(thief)
485+
occ_victim = combined_occupancy(victim)
480486
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
481487
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
482488
compute = self.scheduler._get_prefix_duration(ts.prefix)
@@ -501,7 +507,7 @@ def balance(self) -> None:
501507
self.metrics["request_count_total"][level] += 1
502508
self.metrics["request_cost_total"][level] += cost
503509

504-
occ_thief = self._combined_occupancy(thief)
510+
occ_thief = combined_occupancy(thief)
505511
nproc_thief = self._combined_nprocessing(thief)
506512

507513
# FIXME: In the worst case, the victim may have 3x the amount of work
@@ -515,7 +521,7 @@ def balance(self) -> None:
515521
# properly clean up, we would not need this
516522
stealable.discard(ts)
517523
self.scheduler.check_idle_saturated(
518-
victim, occ=self._combined_occupancy(victim)
524+
victim, occ=combined_occupancy(victim)
519525
)
520526

521527
if log:
@@ -525,8 +531,10 @@ def balance(self) -> None:
525531
if s.digests:
526532
s.digests["steal-duration"].add(stop - start)
527533

528-
def _combined_occupancy(self, ws: WorkerState) -> float:
529-
return ws.occupancy + self.in_flight_occupancy[ws]
534+
def _combined_occupancy(
535+
self, ws: WorkerState, *, occupancies: dict[WorkerState, float]
536+
) -> float:
537+
return occupancies[ws] + self.in_flight_occupancy[ws]
530538

531539
def _combined_nprocessing(self, ws: WorkerState) -> int:
532540
return len(ws.processing) + self.in_flight_tasks[ws]
@@ -552,18 +560,50 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
552560
out.append(t)
553561
return out
554562

563+
def stealing_objective(
    self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float]
) -> tuple[float, ...]:
    """Objective function to determine which worker should get the task

    Minimize expected start time. Ties are broken by the amount of data
    stored on the worker (``ws.nbytes``).

    Parameters
    ----------
    ts
        Task that is a candidate for stealing.
    ws
        Worker being scored as a potential recipient of ``ts``.
    occupancies
        Precomputed occupancy per worker; forwarded to
        ``_combined_occupancy`` so that in-flight requests are included.

    Returns
    -------
    tuple
        Sort key where smaller is better: ``(start_time, ws.nbytes)``, or
        ``(len(ws.actors), start_time, ws.nbytes)`` when ``ts.actor`` is
        set. ``start_time`` is the combined occupancy divided by
        ``ws.nthreads`` plus the communication cost of ``ts`` on ``ws``.

    Notes
    -----
    This method is a modified version of Scheduler.worker_objective that accounts
    for in-flight requests. It must be kept in sync for work-stealing to work correctly.

    See Also
    --------
    Scheduler.worker_objective
    """
    combined = self._combined_occupancy(ws, occupancies=occupancies)
    start_time = combined / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
    if ts.actor:
        # For actor tasks the number of actors on the worker is compared first.
        return (len(ws.actors), start_time, ws.nbytes)
    return (start_time, ws.nbytes)
587+
588+
def _get_thief(
589+
self,
590+
scheduler: SchedulerState,
591+
ts: TaskState,
592+
potential_thieves: set[WorkerState],
593+
*,
594+
occupancies: dict[WorkerState, float],
595+
) -> WorkerState | None:
596+
valid_workers = scheduler.valid_workers(ts)
597+
if valid_workers is not None:
598+
valid_thieves = potential_thieves & valid_workers
599+
if valid_thieves:
600+
potential_thieves = valid_thieves
601+
elif not ts.loose_restrictions:
602+
return None
603+
return min(
604+
potential_thieves,
605+
key=partial(self.stealing_objective, ts, occupancies=occupancies),
606+
)
567607

568608

569609
fast_tasks = {

distributed/tests/test_steal.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,70 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
19431943
assert len(events) == 0
19441944

19451945

1946+
@gen_cluster(
1947+
nthreads=[("", 1)],
1948+
client=True,
1949+
config={"distributed.scheduler.worker-saturation": "inf"},
1950+
)
1951+
async def test_stealing_objective_accounts_for_in_flight(c, s, a):
1952+
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
1953+
in_event = Event()
1954+
block_event = Event()
1955+
1956+
def block(i: int, in_event: Event, block_event: Event) -> int:
1957+
in_event.set()
1958+
block_event.wait()
1959+
return i
1960+
1961+
# Stop stealing for deterministic testing
1962+
extension = s.extensions["stealing"]
1963+
await extension.stop()
1964+
1965+
try:
1966+
futs = c.map(block, range(20), in_event=in_event, block_event=block_event)
1967+
await in_event.wait()
1968+
1969+
async with Worker(s.address, nthreads=1) as b:
1970+
try:
1971+
await async_poll_for(lambda: s.idle, timeout=5)
1972+
wsA = s.workers[a.address]
1973+
wsB = s.workers[b.address]
1974+
ts = next(iter(wsA.processing))
1975+
1976+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
1977+
# No in-flight requests, so both match
1978+
assert extension.stealing_objective(
1979+
ts, wsA, occupancies=occupancies
1980+
) == s.worker_objective(ts, wsA)
1981+
assert extension.stealing_objective(
1982+
ts, wsB, occupancies=occupancies
1983+
) == s.worker_objective(ts, wsB)
1984+
1985+
extension.balance()
1986+
assert extension.in_flight
1987+
# We move tasks from a to b
1988+
assert extension.stealing_objective(
1989+
ts, wsA, occupancies=occupancies
1990+
) < s.worker_objective(ts, wsA)
1991+
assert extension.stealing_objective(
1992+
ts, wsB, occupancies=occupancies
1993+
) > s.worker_objective(ts, wsB)
1994+
1995+
await async_poll_for(lambda: not extension.in_flight, timeout=5)
1996+
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
1997+
# No in-flight requests, so both match
1998+
assert extension.stealing_objective(
1999+
ts, wsA, occupancies=occupancies
2000+
) == s.worker_objective(ts, wsA)
2001+
assert extension.stealing_objective(
2002+
ts, wsB, occupancies=occupancies
2003+
) == s.worker_objective(ts, wsB)
2004+
finally:
2005+
await block_event.set()
2006+
finally:
2007+
await block_event.set()
2008+
2009+
19462010
@gen_cluster(
19472011
nthreads=[("", 1)],
19482012
client=True,

0 commit comments

Comments
 (0)