3131from contextlib import suppress
3232from functools import partial
3333from numbers import Number
34- from typing import Any , ClassVar , Literal , cast
34+ from typing import TYPE_CHECKING , Any , ClassVar , Literal , cast
3535
3636import psutil
3737from sortedcontainers import SortedDict , SortedSet
106106from distributed .utils_perf import disable_gc_diagnosis , enable_gc_diagnosis
107107from distributed .variable import VariableExtension
108108
109+ if TYPE_CHECKING :
110+ # TODO import from typing (requires Python >=3.10)
111+ from typing_extensions import TypeAlias
112+
113+ # TODO move out of TYPE_CHECKING (requires Python >=3.10)
114+ # Not to be confused with distributed.worker_state_machine.TaskStateState
115+ TaskStateState : TypeAlias = Literal [
116+ "released" ,
117+ "waiting" ,
118+ "no-worker" ,
119+ "processing" ,
120+ "memory" ,
121+ "erred" ,
122+ "forgotten" ,
123+ ]
124+
125+ # TODO remove quotes (requires Python >=3.9)
126+ # {task key -> finish state}
127+ # Not to be confused with distributed.worker_state_machine.Recs
128+ Recs : TypeAlias = "dict[str, TaskStateState]"
129+ else :
130+ TaskStateState = str
131+
132+ ALL_TASK_STATES : set [TaskStateState ] = {
133+ "released" ,
134+ "waiting" ,
135+ "no-worker" ,
136+ "processing" ,
137+ "memory" ,
138+ "erred" ,
139+ "forgotten" ,
140+ }
141+
109142logger = logging .getLogger (__name__ )
110143LOG_PDB = dask .config .get ("distributed.admin.pdb-on-err" )
111144DEFAULT_DATA_SIZE = parse_bytes (
129162 "stealing" : WorkStealing ,
130163}
131164
132- ALL_TASK_STATES = {"released" , "waiting" , "no-worker" , "processing" , "erred" , "memory" }
133-
134165
135166class ClientState :
136167 """A simple object holding information about a client."""
@@ -799,7 +830,7 @@ class TaskGroup:
799830
800831 #: The number of tasks in each state,
801832 #: like ``{"memory": 10, "processing": 3, "released": 4, ...}``
802- states : dict [str , int ]
833+ states : dict [TaskStateState , int ]
803834
804835 #: The other TaskGroups on which this one depends
805836 dependencies : set [TaskGroup ]
@@ -831,8 +862,7 @@ class TaskGroup:
831862 def __init__ (self , name : str ):
832863 self .name = name
833864 self .prefix = None
834- self .states = {state : 0 for state in ALL_TASK_STATES }
835- self .states ["forgotten" ] = 0
865+ self .states = dict .fromkeys (ALL_TASK_STATES , 0 )
836866 self .dependencies = set ()
837867 self .nbytes_total = 0
838868 self .duration = 0
@@ -922,7 +952,7 @@ class TaskState:
922952 priority : tuple [int , ...]
923953
924954 # Attribute underlying the state property
925- _state : str
955+ _state : TaskStateState
926956
927957 #: The set of tasks this task depends on for proper execution. Only tasks still
928958 #: alive are listed in this set. If, for whatever reason, this task also depends on
@@ -1093,11 +1123,11 @@ class TaskState:
10931123 # Instances not part of slots since class variable
10941124 _instances : ClassVar [weakref .WeakSet [TaskState ]] = weakref .WeakSet ()
10951125
1096- def __init__ (self , key : str , run_spec : object ):
1126+ def __init__ (self , key : str , run_spec : object , state : TaskStateState ):
10971127 self .key = key
10981128 self ._hash = hash (key )
10991129 self .run_spec = run_spec
1100- self ._state = None # type: ignore
1130+ self ._state = state
11011131 self .exception = None
11021132 self .exception_blame = None
11031133 self .traceback = None
@@ -1136,16 +1166,16 @@ def __eq__(self, other: object) -> bool:
11361166 return isinstance (other , TaskState ) and self .key == other .key
11371167
11381168 @property
1139- def state (self ) -> str :
1140- """This task's current state. Valid states include ``released``, ``waiting``,
1169+ def state (self ) -> TaskStateState :
1170+ """This task's current state. Valid states are ``released``, ``waiting``,
11411171 ``no-worker``, ``processing``, ``memory``, ``erred`` and ``forgotten``. If it
11421172 is ``forgotten``, the task isn't stored in the ``tasks`` dictionary anymore and
11431173 will probably disappear soon from memory.
11441174 """
11451175 return self ._state
11461176
11471177 @state .setter
1148- def state (self , value : str ) -> None :
1178+ def state (self , value : TaskStateState ) -> None :
11491179 self .group .states [self ._state ] -= 1
11501180 self .group .states [value ] += 1
11511181 self ._state = value
@@ -1405,11 +1435,14 @@ def __pdict__(self):
14051435 }
14061436
14071437 def new_task (
1408- self , key : str , spec : object , state : str , computation : Computation | None = None
1438+ self ,
1439+ key : str ,
1440+ spec : object ,
1441+ state : TaskStateState ,
1442+ computation : Computation | None = None ,
14091443 ) -> TaskState :
14101444 """Create a new task, and associated states"""
1411- ts = TaskState (key , spec )
1412- ts ._state = state
1445+ ts = TaskState (key , spec , state )
14131446
14141447 prefix_key = key_split (key )
14151448 tp = self .task_prefixes .get (prefix_key )
@@ -1451,8 +1484,8 @@ def _clear_task_state(self):
14511484 #####################
14521485
14531486 def _transition (
1454- self , key : str , finish : str , stimulus_id : str , * args , ** kwargs
1455- ) -> tuple [dict , dict , dict ]:
1487+ self , key : str , finish : TaskStateState , stimulus_id : str , * args , ** kwargs
1488+ ) -> tuple [Recs , dict , dict ]:
14561489 """Transition a key from its current state to the finish state
14571490
14581491 Examples
@@ -1576,15 +1609,10 @@ def _transition(
15761609 if ts .state == "forgotten" :
15771610 del self .tasks [ts .key ]
15781611
1579- tg : TaskGroup = ts .group
1612+ tg = ts .group
15801613 if ts .state == "forgotten" and tg .name in self .task_groups :
15811614 # Remove TaskGroup if all tasks are in the forgotten state
1582- all_forgotten : bool = True
1583- for s in ALL_TASK_STATES :
1584- if tg .states .get (s ):
1585- all_forgotten = False
1586- break
1587- if all_forgotten :
1615+ if all (v == 0 or k == "forgotten" for k , v in tg .states .items ()):
15881616 ts .prefix .groups .remove (tg )
15891617 del self .task_groups [tg .name ]
15901618
@@ -1599,17 +1627,17 @@ def _transition(
15991627
16001628 def _transitions (
16011629 self ,
1602- recommendations : dict ,
1630+ recommendations : Recs ,
16031631 client_msgs : dict ,
16041632 worker_msgs : dict ,
16051633 stimulus_id : str ,
1606- ):
1634+ ) -> None :
16071635 """Process transitions until none are left
16081636
16091637 This includes feedback from previous transitions and continues until we
16101638 reach a steady state
16111639 """
1612- keys : set = set ()
1640+ keys : set [ str ] = set ()
16131641 recommendations = recommendations .copy ()
16141642
16151643 while recommendations :
@@ -2565,7 +2593,10 @@ def transition_released_forgotten(self, key, stimulus_id):
25652593 # ) -> (recommendations, client_msgs, worker_msgs)
25662594 # }
25672595 _TRANSITIONS_TABLE : ClassVar [
2568- Mapping [tuple [str , str ], Callable [..., tuple [dict , dict , dict ]]]
2596+ Mapping [
2597+ tuple [TaskStateState , TaskStateState ],
2598+ Callable [..., tuple [Recs , dict , dict ]],
2599+ ]
25692600 ] = {
25702601 ("released" , "waiting" ): transition_released_waiting ,
25712602 ("waiting" , "released" ): transition_waiting_released ,
@@ -6719,7 +6750,14 @@ async def unregister_nanny_plugin(self, comm, name):
67196750 )
67206751 return responses
67216752
6722- def transition (self , key , finish : str , * args , stimulus_id : str , ** kwargs ):
6753+ def transition (
6754+ self ,
6755+ key : str ,
6756+ finish : TaskStateState ,
6757+ * args : Any ,
6758+ stimulus_id : str ,
6759+ ** kwargs : Any ,
6760+ ) -> Recs :
67236761 """Transition a key from its current state to the finish state
67246762
67256763 Examples
@@ -7623,7 +7661,7 @@ def decide_worker(
76237661
76247662def validate_task_state (ts : TaskState ) -> None :
76257663 """Validate the given TaskState"""
7626- assert ts .state in ALL_TASK_STATES or ts . state == "forgotten" , ts
7664+ assert ts .state in ALL_TASK_STATES , ts
76277665
76287666 if ts .waiting_on :
76297667 assert ts .waiting_on .issubset (ts .dependencies ), (
@@ -7837,8 +7875,15 @@ def update_graph(
78377875 ) -> None :
78387876 self .keys .update (keys )
78397877
7840- def transition (self , key : str , start : str , finish : str , * args , ** kwargs ) -> None :
7841- if finish == "memory" or finish == "erred" :
7878+ def transition (
7879+ self ,
7880+ key : str ,
7881+ start : TaskStateState ,
7882+ finish : TaskStateState ,
7883+ * args : Any ,
7884+ ** kwargs : Any ,
7885+ ) -> None :
7886+ if finish in ("memory" , "erred" ):
78427887 ts = self .scheduler .tasks .get (key )
78437888 if ts is not None and ts .key in self .keys :
78447889 self .metadata [key ] = ts .metadata
0 commit comments