@@ -501,7 +501,7 @@ class WorkerState:
501501 # Reference to scheduler task_groups
502502 scheduler_ref : weakref .ref [SchedulerState ] | None
503503 task_prefix_count : defaultdict [str , int ]
504- _network_occ : float
504+ _network_occ : int
505505 _occupancy_cache : float | None
506506
507507 #: Keys that may need to be fetched to this worker, and the number of tasks that need them.
@@ -822,8 +822,11 @@ def _dec_needs_replica(self, ts: TaskState) -> None:
822822 if self .needs_what [ts ] == 0 :
823823 del self .needs_what [ts ]
824824 nbytes = ts .get_nbytes ()
825- self ._network_occ -= nbytes
826- self .scheduler ._network_occ_global -= nbytes
825+ # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
826+ self ._network_occ -= min (nbytes , self ._network_occ )
827+ self .scheduler ._network_occ_global -= min (
828+ nbytes , self .scheduler ._network_occ_global
829+ )
827830
828831 def add_replica (self , ts : TaskState ) -> None :
829832 """The worker acquired a replica of task"""
@@ -834,8 +837,11 @@ def add_replica(self, ts: TaskState) -> None:
834837 nbytes = ts .get_nbytes ()
835838 if ts in self .needs_what :
836839 del self .needs_what [ts ]
837- self ._network_occ -= nbytes
838- self .scheduler ._network_occ_global -= nbytes
840+ # FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
841+ self ._network_occ -= min (nbytes , self ._network_occ )
842+ self .scheduler ._network_occ_global -= min (
843+ nbytes , self .scheduler ._network_occ_global
844+ )
839845 ts .who_has .add (self )
840846 self .nbytes += nbytes
841847 self ._has_what [ts ] = None
@@ -1708,7 +1714,7 @@ class SchedulerState:
17081714 transition_counter_max : int | Literal [False ]
17091715
17101716 _task_prefix_count_global : defaultdict [str , int ]
1711- _network_occ_global : float
1717+ _network_occ_global : int
17121718 ######################
17131719 # Cached configuration
17141720 ######################
@@ -1777,7 +1783,7 @@ def __init__(
17771783 self .validate = validate
17781784 self .workers = workers
17791785 self ._task_prefix_count_global = defaultdict (int )
1780- self ._network_occ_global = 0.0
1786+ self ._network_occ_global = 0
17811787 self .running = {
17821788 ws for ws in self .workers .values () if ws .status == Status .running
17831789 }
@@ -1957,7 +1963,9 @@ def _calc_occupancy(
19571963 duration = self ._get_prefix_duration (self .task_prefixes [prefix_name ])
19581964 res += duration * count
19591965 occ = res + network_occ / self .bandwidth
1960- assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
1966+ if self .validate :
1967+ assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
1968+ occ = max (occ , 0 )
19611969 return occ
19621970
19631971 #####################
0 commit comments