Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,10 @@ def balance(self) -> None:
log = []
start = time()

# Pre-calculate all occupancies once, they don't change during balancing
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
combined_occupancy = partial(self._combined_occupancy, occupancies=occupancies)

i = 0
# Paused and closing workers must never become thieves
potential_thieves = set(s.idle.values())
Expand All @@ -434,21 +438,19 @@ def balance(self) -> None:
victim: WorkerState | None
potential_victims: set[WorkerState] | list[WorkerState] = s.saturated
if not potential_victims:
potential_victims = topk(
10, s.workers.values(), key=self._combined_occupancy
)
potential_victims = topk(10, s.workers.values(), key=combined_occupancy)
potential_victims = [
ws
for ws in potential_victims
if self._combined_occupancy(ws) > 0.2
if combined_occupancy(ws) > 0.2
and self._combined_nprocessing(ws) > ws.nthreads
and ws not in potential_thieves
]
if not potential_victims:
return
if len(potential_victims) < 20:
potential_victims = sorted(
potential_victims, key=self._combined_occupancy, reverse=True
potential_victims, key=combined_occupancy, reverse=True
)
assert potential_victims
assert potential_thieves
Expand All @@ -472,11 +474,15 @@ def balance(self) -> None:
stealable.discard(ts)
continue
i += 1
if not (thief := self._get_thief(s, ts, potential_thieves)):
if not (
thief := self._get_thief(
s, ts, potential_thieves, occupancies=occupancies
)
):
continue

occ_thief = self._combined_occupancy(thief)
occ_victim = self._combined_occupancy(victim)
occ_thief = combined_occupancy(thief)
occ_victim = combined_occupancy(victim)
comm_cost_thief = self.scheduler.get_comm_cost(ts, thief)
comm_cost_victim = self.scheduler.get_comm_cost(ts, victim)
compute = self.scheduler._get_prefix_duration(ts.prefix)
Expand All @@ -501,7 +507,7 @@ def balance(self) -> None:
self.metrics["request_count_total"][level] += 1
self.metrics["request_cost_total"][level] += cost

occ_thief = self._combined_occupancy(thief)
occ_thief = combined_occupancy(thief)
nproc_thief = self._combined_nprocessing(thief)

# FIXME: In the worst case, the victim may have 3x the amount of work
Expand All @@ -515,7 +521,7 @@ def balance(self) -> None:
# properly clean up, we would not need this
stealable.discard(ts)
self.scheduler.check_idle_saturated(
victim, occ=self._combined_occupancy(victim)
victim, occ=combined_occupancy(victim)
)

if log:
Expand All @@ -525,8 +531,10 @@ def balance(self) -> None:
if s.digests:
s.digests["steal-duration"].add(stop - start)

def _combined_occupancy(self, ws: WorkerState) -> float:
return ws.occupancy + self.in_flight_occupancy[ws]
def _combined_occupancy(
self, ws: WorkerState, *, occupancies: dict[WorkerState, float]
) -> float:
return occupancies[ws] + self.in_flight_occupancy[ws]

def _combined_nprocessing(self, ws: WorkerState) -> int:
return len(ws.processing) + self.in_flight_tasks[ws]
Expand All @@ -552,7 +560,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
out.append(t)
return out

def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...]:
def stealing_objective(
self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float]
) -> tuple[float, ...]:
"""Objective function to determine which worker should get the task

Minimize expected start time. If a tie then break with data storage.
Expand All @@ -567,7 +577,8 @@ def stealing_objective(self, ts: TaskState, ws: WorkerState) -> tuple[float, ...
Scheduler.worker_objective
"""
occupancy = self._combined_occupancy(
ws
ws,
occupancies=occupancies,
) / ws.nthreads + self.scheduler.get_comm_cost(ts, ws)
if ts.actor:
return (len(ws.actors), occupancy, ws.nbytes)
Expand All @@ -579,6 +590,8 @@ def _get_thief(
scheduler: SchedulerState,
ts: TaskState,
potential_thieves: set[WorkerState],
*,
occupancies: dict[WorkerState, float],
) -> WorkerState | None:
valid_workers = scheduler.valid_workers(ts)
if valid_workers is not None:
Expand All @@ -587,7 +600,10 @@ def _get_thief(
potential_thieves = valid_thieves
elif not ts.loose_restrictions:
return None
return min(potential_thieves, key=partial(self.stealing_objective, ts))
return min(
potential_thieves,
key=partial(self.stealing_objective, ts, occupancies=occupancies),
)


fast_tasks = {
Expand Down
42 changes: 22 additions & 20 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1948,7 +1948,7 @@ async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers):
client=True,
config={"distributed.scheduler.worker-saturation": "inf"},
)
async def test_stealing_ogjective_accounts_for_in_flight(c, s, a):
async def test_stealing_objective_accounts_for_in_flight(c, s, a):
"""Regression test that work-stealing's objective correctly accounts for in-flight data requests"""
in_event = Event()
block_event = Event()
Expand All @@ -1973,32 +1973,34 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
wsB = s.workers[b.address]
ts = next(iter(wsA.processing))

occupancies = {ws: ws.occupancy for ws in s.workers.values()}
# No in-flight requests, so both match
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
ts, wsA
)
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
ts, wsB
)
assert extension.stealing_objective(
ts, wsA, occupancies=occupancies
) == s.worker_objective(ts, wsA)
assert extension.stealing_objective(
ts, wsB, occupancies=occupancies
) == s.worker_objective(ts, wsB)

extension.balance()
assert extension.in_flight
# We move tasks from a to b
assert extension.stealing_objective(ts, wsA) < s.worker_objective(
ts, wsA
)
assert extension.stealing_objective(ts, wsB) > s.worker_objective(
ts, wsB
)
assert extension.stealing_objective(
ts, wsA, occupancies=occupancies
) < s.worker_objective(ts, wsA)
assert extension.stealing_objective(
ts, wsB, occupancies=occupancies
) > s.worker_objective(ts, wsB)

await async_poll_for(lambda: not extension.in_flight, timeout=5)
occupancies = {ws: ws.occupancy for ws in s.workers.values()}
# No in-flight requests, so both match
assert extension.stealing_objective(ts, wsA) == s.worker_objective(
ts, wsA
)
assert extension.stealing_objective(ts, wsB) == s.worker_objective(
ts, wsB
)
assert extension.stealing_objective(
ts, wsA, occupancies=occupancies
) == s.worker_objective(ts, wsA)
assert extension.stealing_objective(
ts, wsB, occupancies=occupancies
) == s.worker_objective(ts, wsB)
finally:
await block_event.set()
finally:
Expand Down Expand Up @@ -2031,7 +2033,7 @@ def block(i: int, in_event: Event, block_event: Event) -> int:
await in_event.wait()

# This is the pre-condition for the observed problem:
# There are tasks that execute fox a long time but do not have an average
# There are tasks that execute for a long time but do not have an average
s.task_prefixes["block"].add_exec_time(100)
assert s.task_prefixes["block"].duration_average == -1

Expand Down
Loading