Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class WorkerState:
# Reference to scheduler task_groups
scheduler_ref: weakref.ref[SchedulerState] | None
task_prefix_count: defaultdict[str, int]
_network_occ: float
_network_occ: int
_occupancy_cache: float | None

#: Keys that may need to be fetched to this worker, and the number of tasks that need them.
Expand Down Expand Up @@ -823,8 +823,11 @@ def _dec_needs_replica(self, ts: TaskState) -> None:
if self.needs_what[ts] == 0:
del self.needs_what[ts]
nbytes = ts.get_nbytes()
self._network_occ -= nbytes
self.scheduler._network_occ_global -= nbytes
# FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
self._network_occ -= min(nbytes, self._network_occ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is `_network_occ` an integer, or are we dealing with floating-point values at this point?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like `_network_occ_global` is defined as a float. I guess this won't be a problem, since we only ever add and subtract integers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've adjusted the types and initial values to match the assumption of integer values.

self.scheduler._network_occ_global -= min(
nbytes, self.scheduler._network_occ_global
)

def add_replica(self, ts: TaskState) -> None:
"""The worker acquired a replica of task"""
Expand All @@ -835,8 +838,11 @@ def add_replica(self, ts: TaskState) -> None:
nbytes = ts.get_nbytes()
if ts in self.needs_what:
del self.needs_what[ts]
self._network_occ -= nbytes
self.scheduler._network_occ_global -= nbytes
# FIXME: ts.get_nbytes may change if non-deterministic tasks get recomputed, causing drift
self._network_occ -= min(nbytes, self._network_occ)
self.scheduler._network_occ_global -= min(
nbytes, self.scheduler._network_occ_global
)
ts.who_has.add(self)
self.nbytes += nbytes
self._has_what[ts] = None
Expand Down Expand Up @@ -1709,7 +1715,7 @@ class SchedulerState:
transition_counter_max: int | Literal[False]

_task_prefix_count_global: defaultdict[str, int]
_network_occ_global: float
_network_occ_global: int
######################
# Cached configuration
######################
Expand Down Expand Up @@ -1778,7 +1784,7 @@ def __init__(
self.validate = validate
self.workers = workers
self._task_prefix_count_global = defaultdict(int)
self._network_occ_global = 0.0
self._network_occ_global = 0
self.running = {
ws for ws in self.workers.values() if ws.status == Status.running
}
Expand Down Expand Up @@ -1958,7 +1964,9 @@ def _calc_occupancy(
duration = self._get_prefix_duration(self.task_prefixes[prefix_name])
res += duration * count
occ = res + network_occ / self.bandwidth
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
if self.validate:
assert occ >= 0, (occ, res, network_occ, self.bandwidth)
occ = max(occ, 0)
return occ

#####################
Expand Down
Loading