@@ -426,6 +426,10 @@ def balance(self) -> None:
426426 log = []
427427 start = time ()
428428
429+ # Pre-calculate all occupancies once, they don't change during balancing
430+ occupancies = {ws : ws .occupancy for ws in s .workers .values ()}
431+ combined_occupancy = partial (self ._combined_occupancy , occupancies = occupancies )
432+
429433 i = 0
430434 # Paused and closing workers must never become thieves
431435 potential_thieves = set (s .idle .values ())
@@ -434,21 +438,19 @@ def balance(self) -> None:
434438 victim : WorkerState | None
435439 potential_victims : set [WorkerState ] | list [WorkerState ] = s .saturated
436440 if not potential_victims :
437- potential_victims = topk (
438- 10 , s .workers .values (), key = self ._combined_occupancy
439- )
441+ potential_victims = topk (10 , s .workers .values (), key = combined_occupancy )
440442 potential_victims = [
441443 ws
442444 for ws in potential_victims
443- if self . _combined_occupancy (ws ) > 0.2
445+ if combined_occupancy (ws ) > 0.2
444446 and self ._combined_nprocessing (ws ) > ws .nthreads
445447 and ws not in potential_thieves
446448 ]
447449 if not potential_victims :
448450 return
449451 if len (potential_victims ) < 20 :
450452 potential_victims = sorted (
451- potential_victims , key = self . _combined_occupancy , reverse = True
453+ potential_victims , key = combined_occupancy , reverse = True
452454 )
453455 assert potential_victims
454456 assert potential_thieves
@@ -472,11 +474,15 @@ def balance(self) -> None:
472474 stealable .discard (ts )
473475 continue
474476 i += 1
475- if not (thief := _get_thief (s , ts , potential_thieves )):
477+ if not (
478+ thief := self ._get_thief (
479+ s , ts , potential_thieves , occupancies = occupancies
480+ )
481+ ):
476482 continue
477483
478- occ_thief = self . _combined_occupancy (thief )
479- occ_victim = self . _combined_occupancy (victim )
484+ occ_thief = combined_occupancy (thief )
485+ occ_victim = combined_occupancy (victim )
480486 comm_cost_thief = self .scheduler .get_comm_cost (ts , thief )
481487 comm_cost_victim = self .scheduler .get_comm_cost (ts , victim )
482488 compute = self .scheduler ._get_prefix_duration (ts .prefix )
@@ -501,7 +507,7 @@ def balance(self) -> None:
501507 self .metrics ["request_count_total" ][level ] += 1
502508 self .metrics ["request_cost_total" ][level ] += cost
503509
504- occ_thief = self . _combined_occupancy (thief )
510+ occ_thief = combined_occupancy (thief )
505511 nproc_thief = self ._combined_nprocessing (thief )
506512
507513 # FIXME: In the worst case, the victim may have 3x the amount of work
@@ -515,7 +521,7 @@ def balance(self) -> None:
515521 # properly clean up, we would not need this
516522 stealable .discard (ts )
517523 self .scheduler .check_idle_saturated (
518- victim , occ = self . _combined_occupancy (victim )
524+ victim , occ = combined_occupancy (victim )
519525 )
520526
521527 if log :
@@ -525,8 +531,10 @@ def balance(self) -> None:
525531 if s .digests :
526532 s .digests ["steal-duration" ].add (stop - start )
527533
528- def _combined_occupancy (self , ws : WorkerState ) -> float :
529- return ws .occupancy + self .in_flight_occupancy [ws ]
534+ def _combined_occupancy (
535+ self , ws : WorkerState , * , occupancies : dict [WorkerState , float ]
536+ ) -> float :
537+ return occupancies [ws ] + self .in_flight_occupancy [ws ]
530538
531539 def _combined_nprocessing (self , ws : WorkerState ) -> int :
532540 return len (ws .processing ) + self .in_flight_tasks [ws ]
@@ -552,18 +560,50 @@ def story(self, *keys_or_ts: str | TaskState) -> list:
552560 out .append (t )
553561 return out
554562
def stealing_objective(
    self, ts: TaskState, ws: WorkerState, *, occupancies: dict[WorkerState, float]
) -> tuple[float, ...]:
    """Score ``ws`` as a destination for ``ts``; smaller tuples win.

    The primary criterion is the expected start time of the task on the
    worker: combined occupancy divided by the worker's thread count, plus
    the communication cost of moving the task there. Ties are broken by the
    amount of data stored on the worker. For actor tasks the number of
    actors already on the worker takes precedence over both.

    Parameters
    ----------
    ts
        Task that is a candidate for stealing.
    ws
        Worker being considered as the thief.
    occupancies
        Snapshot of worker occupancies, forwarded to
        ``_combined_occupancy``.

    Notes
    -----
    This method is a modified version of Scheduler.worker_objective that
    accounts for in-flight requests. It must be kept in sync for
    work-stealing to work correctly.

    See Also
    --------
    Scheduler.worker_objective
    """
    stack_time = self._combined_occupancy(ws, occupancies=occupancies) / ws.nthreads
    start_time = stack_time + self.scheduler.get_comm_cost(ts, ws)
    if not ts.actor:
        return (start_time, ws.nbytes)
    return (len(ws.actors), start_time, ws.nbytes)
587+
588+ def _get_thief (
589+ self ,
590+ scheduler : SchedulerState ,
591+ ts : TaskState ,
592+ potential_thieves : set [WorkerState ],
593+ * ,
594+ occupancies : dict [WorkerState , float ],
595+ ) -> WorkerState | None :
596+ valid_workers = scheduler .valid_workers (ts )
597+ if valid_workers is not None :
598+ valid_thieves = potential_thieves & valid_workers
599+ if valid_thieves :
600+ potential_thieves = valid_thieves
601+ elif not ts .loose_restrictions :
602+ return None
603+ return min (
604+ potential_thieves ,
605+ key = partial (self .stealing_objective , ts , occupancies = occupancies ),
606+ )
567607
568608
569609fast_tasks = {
0 commit comments