Skip to content

Commit 10cf945

Browse files
authored
Task state domain on the scheduler side (#6929)
1 parent 54453c5 commit 10cf945

File tree

4 files changed

+88
-35
lines changed

4 files changed

+88
-35
lines changed

distributed/diagnostics/plugin.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dask.utils import funcname, tmpfile
1414

1515
if TYPE_CHECKING:
16-
from distributed.scheduler import Scheduler # circular import
16+
from distributed.scheduler import Scheduler, TaskStateState # circular imports
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -84,7 +84,12 @@ def restart(self, scheduler: Scheduler) -> None:
8484
"""Run when the scheduler restarts itself"""
8585

8686
def transition(
87-
self, key: str, start: str, finish: str, *args: Any, **kwargs: Any
87+
self,
88+
key: str,
89+
start: TaskStateState,
90+
finish: TaskStateState,
91+
*args: Any,
92+
**kwargs: Any,
8893
) -> None:
8994
"""Run whenever a task changes state
9095

distributed/http/scheduler/prometheus/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def collect(self):
6969
)
7070

7171
for state in ALL_TASK_STATES:
72-
tasks.add_metric([state], task_counter.get(state, 0.0))
72+
if state != "forgotten":
73+
tasks.add_metric([state], task_counter.get(state, 0.0))
7374
yield tasks
7475

7576

distributed/scheduler.py

Lines changed: 77 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from contextlib import suppress
3232
from functools import partial
3333
from numbers import Number
34-
from typing import Any, ClassVar, Literal, cast
34+
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
3535

3636
import psutil
3737
from sortedcontainers import SortedDict, SortedSet
@@ -106,6 +106,39 @@
106106
from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
107107
from distributed.variable import VariableExtension
108108

109+
if TYPE_CHECKING:
110+
# TODO import from typing (requires Python >=3.10)
111+
from typing_extensions import TypeAlias
112+
113+
# TODO move out of TYPE_CHECKING (requires Python >=3.10)
114+
# Not to be confused with distributed.worker_state_machine.TaskStateState
115+
TaskStateState: TypeAlias = Literal[
116+
"released",
117+
"waiting",
118+
"no-worker",
119+
"processing",
120+
"memory",
121+
"erred",
122+
"forgotten",
123+
]
124+
125+
# TODO remove quotes (requires Python >=3.9)
126+
# {task key -> finish state}
127+
# Not to be confused with distributed.worker_state_machine.Recs
128+
Recs: TypeAlias = "dict[str, TaskStateState]"
129+
else:
130+
TaskStateState = str
131+
132+
ALL_TASK_STATES: set[TaskStateState] = {
133+
"released",
134+
"waiting",
135+
"no-worker",
136+
"processing",
137+
"memory",
138+
"erred",
139+
"forgotten",
140+
}
141+
109142
logger = logging.getLogger(__name__)
110143
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
111144
DEFAULT_DATA_SIZE = parse_bytes(
@@ -129,8 +162,6 @@
129162
"stealing": WorkStealing,
130163
}
131164

132-
ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"}
133-
134165

135166
class ClientState:
136167
"""A simple object holding information about a client."""
@@ -799,7 +830,7 @@ class TaskGroup:
799830

800831
#: The number of tasks in each state,
801832
#: like ``{"memory": 10, "processing": 3, "released": 4, ...}``
802-
states: dict[str, int]
833+
states: dict[TaskStateState, int]
803834

804835
#: The other TaskGroups on which this one depends
805836
dependencies: set[TaskGroup]
@@ -831,8 +862,7 @@ class TaskGroup:
831862
def __init__(self, name: str):
832863
self.name = name
833864
self.prefix = None
834-
self.states = {state: 0 for state in ALL_TASK_STATES}
835-
self.states["forgotten"] = 0
865+
self.states = dict.fromkeys(ALL_TASK_STATES, 0)
836866
self.dependencies = set()
837867
self.nbytes_total = 0
838868
self.duration = 0
@@ -922,7 +952,7 @@ class TaskState:
922952
priority: tuple[int, ...]
923953

924954
# Attribute underlying the state property
925-
_state: str
955+
_state: TaskStateState
926956

927957
#: The set of tasks this task depends on for proper execution. Only tasks still
928958
#: alive are listed in this set. If, for whatever reason, this task also depends on
@@ -1093,11 +1123,11 @@ class TaskState:
10931123
# Instances not part of slots since class variable
10941124
_instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()
10951125

1096-
def __init__(self, key: str, run_spec: object):
1126+
def __init__(self, key: str, run_spec: object, state: TaskStateState):
10971127
self.key = key
10981128
self._hash = hash(key)
10991129
self.run_spec = run_spec
1100-
self._state = None # type: ignore
1130+
self._state = state
11011131
self.exception = None
11021132
self.exception_blame = None
11031133
self.traceback = None
@@ -1136,16 +1166,16 @@ def __eq__(self, other: object) -> bool:
11361166
return isinstance(other, TaskState) and self.key == other.key
11371167

11381168
@property
1139-
def state(self) -> str:
1140-
"""This task's current state. Valid states include ``released``, ``waiting``,
1169+
def state(self) -> TaskStateState:
1170+
"""This task's current state. Valid states are ``released``, ``waiting``,
11411171
``no-worker``, ``processing``, ``memory``, ``erred`` and ``forgotten``. If it
11421172
is ``forgotten``, the task isn't stored in the ``tasks`` dictionary anymore and
11431173
will probably disappear soon from memory.
11441174
"""
11451175
return self._state
11461176

11471177
@state.setter
1148-
def state(self, value: str) -> None:
1178+
def state(self, value: TaskStateState) -> None:
11491179
self.group.states[self._state] -= 1
11501180
self.group.states[value] += 1
11511181
self._state = value
@@ -1405,11 +1435,14 @@ def __pdict__(self):
14051435
}
14061436

14071437
def new_task(
1408-
self, key: str, spec: object, state: str, computation: Computation | None = None
1438+
self,
1439+
key: str,
1440+
spec: object,
1441+
state: TaskStateState,
1442+
computation: Computation | None = None,
14091443
) -> TaskState:
14101444
"""Create a new task, and associated states"""
1411-
ts = TaskState(key, spec)
1412-
ts._state = state
1445+
ts = TaskState(key, spec, state)
14131446

14141447
prefix_key = key_split(key)
14151448
tp = self.task_prefixes.get(prefix_key)
@@ -1451,8 +1484,8 @@ def _clear_task_state(self):
14511484
#####################
14521485

14531486
def _transition(
1454-
self, key: str, finish: str, stimulus_id: str, *args, **kwargs
1455-
) -> tuple[dict, dict, dict]:
1487+
self, key: str, finish: TaskStateState, stimulus_id: str, *args, **kwargs
1488+
) -> tuple[Recs, dict, dict]:
14561489
"""Transition a key from its current state to the finish state
14571490
14581491
Examples
@@ -1576,15 +1609,10 @@ def _transition(
15761609
if ts.state == "forgotten":
15771610
del self.tasks[ts.key]
15781611

1579-
tg: TaskGroup = ts.group
1612+
tg = ts.group
15801613
if ts.state == "forgotten" and tg.name in self.task_groups:
15811614
# Remove TaskGroup if all tasks are in the forgotten state
1582-
all_forgotten: bool = True
1583-
for s in ALL_TASK_STATES:
1584-
if tg.states.get(s):
1585-
all_forgotten = False
1586-
break
1587-
if all_forgotten:
1615+
if all(v == 0 or k == "forgotten" for k, v in tg.states.items()):
15881616
ts.prefix.groups.remove(tg)
15891617
del self.task_groups[tg.name]
15901618

@@ -1599,17 +1627,17 @@ def _transition(
15991627

16001628
def _transitions(
16011629
self,
1602-
recommendations: dict,
1630+
recommendations: Recs,
16031631
client_msgs: dict,
16041632
worker_msgs: dict,
16051633
stimulus_id: str,
1606-
):
1634+
) -> None:
16071635
"""Process transitions until none are left
16081636
16091637
This includes feedback from previous transitions and continues until we
16101638
reach a steady state
16111639
"""
1612-
keys: set = set()
1640+
keys: set[str] = set()
16131641
recommendations = recommendations.copy()
16141642

16151643
while recommendations:
@@ -2565,7 +2593,10 @@ def transition_released_forgotten(self, key, stimulus_id):
25652593
# ) -> (recommendations, client_msgs, worker_msgs)
25662594
# }
25672595
_TRANSITIONS_TABLE: ClassVar[
2568-
Mapping[tuple[str, str], Callable[..., tuple[dict, dict, dict]]]
2596+
Mapping[
2597+
tuple[TaskStateState, TaskStateState],
2598+
Callable[..., tuple[Recs, dict, dict]],
2599+
]
25692600
] = {
25702601
("released", "waiting"): transition_released_waiting,
25712602
("waiting", "released"): transition_waiting_released,
@@ -6719,7 +6750,14 @@ async def unregister_nanny_plugin(self, comm, name):
67196750
)
67206751
return responses
67216752

6722-
def transition(self, key, finish: str, *args, stimulus_id: str, **kwargs):
6753+
def transition(
6754+
self,
6755+
key: str,
6756+
finish: TaskStateState,
6757+
*args: Any,
6758+
stimulus_id: str,
6759+
**kwargs: Any,
6760+
) -> Recs:
67236761
"""Transition a key from its current state to the finish state
67246762
67256763
Examples
@@ -7623,7 +7661,7 @@ def decide_worker(
76237661

76247662
def validate_task_state(ts: TaskState) -> None:
76257663
"""Validate the given TaskState"""
7626-
assert ts.state in ALL_TASK_STATES or ts.state == "forgotten", ts
7664+
assert ts.state in ALL_TASK_STATES, ts
76277665

76287666
if ts.waiting_on:
76297667
assert ts.waiting_on.issubset(ts.dependencies), (
@@ -7837,8 +7875,15 @@ def update_graph(
78377875
) -> None:
78387876
self.keys.update(keys)
78397877

7840-
def transition(self, key: str, start: str, finish: str, *args, **kwargs) -> None:
7841-
if finish == "memory" or finish == "erred":
7878+
def transition(
7879+
self,
7880+
key: str,
7881+
start: TaskStateState,
7882+
finish: TaskStateState,
7883+
*args: Any,
7884+
**kwargs: Any,
7885+
) -> None:
7886+
if finish in ("memory", "erred"):
78427887
ts = self.scheduler.tasks.get(key)
78437888
if ts is not None and ts.key in self.keys:
78447889
self.metadata[key] = ts.metadata

distributed/worker_state_machine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from distributed.worker import Worker
5151

5252
# TODO move out of TYPE_CHECKING (requires Python >=3.10)
53+
# Not to be confused with distributed.scheduler.TaskStateState
5354
TaskStateState: TypeAlias = Literal[
5455
"cancelled",
5556
"constrained",
@@ -1014,6 +1015,7 @@ class SecedeEvent(StateMachineEvent):
10141015
# TODO remove quotes (requires Python >=3.9)
10151016
# TODO get out of TYPE_CHECKING (requires Python >=3.10)
10161017
# {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)}
1018+
# Not to be confused with distributed.scheduler.Recs
10171019
Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]"
10181020
Instructions: TypeAlias = "list[Instruction]"
10191021
RecsInstrs: TypeAlias = "tuple[Recs, Instructions]"

0 commit comments

Comments
 (0)