 from tornado.ioloop import IOLoop

 import dask
-import dask.utils
 from dask._expr import LLGExpr
-from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
+from dask._task_spec import GraphNode, convert_legacy_graph
 from dask.core import istask, validate_key
 from dask.typing import Key, no_default
 from dask.utils import (
@@ -4705,7 +4704,6 @@ def _create_taskstate_from_graph(
         *,
         start: float,
         dsk: dict[Key, T_runspec],
-        dependencies: dict,
         keys: set[Key],
         ordered: dict[Key, int],
         client: str,
@@ -4744,14 +4742,12 @@ def _create_taskstate_from_graph(
         # annotations.
         computation.annotations.update(global_annotations)
         (
-            runnable,
             touched_tasks,
             new_tasks,
             colliding_task_count,
         ) = self._generate_taskstates(
             keys=keys,
             dsk=dsk,
-            dependencies=dependencies,
             computation=computation,
         )

@@ -4773,7 +4769,7 @@ def _create_taskstate_from_graph(
             user_priority=user_priority,
             fifo_timeout=fifo_timeout,
             start=start,
-            tasks=runnable,
+            tasks=touched_tasks,
         )

         self.client_desires_keys(keys=keys, client=client)
@@ -4787,19 +4783,17 @@ def _create_taskstate_from_graph(

         # Compute recommendations
         recommendations: Recs = {}
-        priority = dict()
         for ts in sorted(
-            runnable,
+            filter(
+                lambda ts: ts.state == "released",
+                map(self.tasks.__getitem__, keys),
+            ),
             key=operator.attrgetter("priority"),
             reverse=True,
         ):
-            assert ts.priority  # mypy
-            priority[ts.key] = ts.priority
-            assert ts.run_spec
-            if ts.state == "released":
-                recommendations[ts.key] = "waiting"
+            recommendations[ts.key] = "waiting"

-        for ts in runnable:
+        for ts in touched_tasks:
             for dts in ts.dependencies:
                 if dts.exception_blame:
                     ts.exception_blame = dts.exception_blame
@@ -4820,7 +4814,7 @@ def _create_taskstate_from_graph(
             # TaskState may have also been created by client_desires_keys or scatter,
             # and only later gained a run_spec.
            span_annotations = spans_ext.observe_tasks(
-                runnable, span_metadata=span_metadata, code=code
+                touched_tasks, span_metadata=span_metadata, code=code
            )
            # In case of TaskGroup collision, spans may have changed
            # FIXME: Is this used anywhere besides tests?
@@ -4829,16 +4823,17 @@ def _create_taskstate_from_graph(
             else:
                 annotations_for_plugin.pop("span", None)

+        tasks_for_plugin = [ts.key for ts in touched_tasks]
+        priorities_for_plugin = {ts.key: ts.priority for ts in touched_tasks}
         for plugin in list(self.plugins.values()):
             try:
                 plugin.update_graph(
                     self,
                     client=client,
-                    tasks=[ts.key for ts in touched_tasks],
+                    tasks=tasks_for_plugin,
                     keys=keys,
-                    dependencies=dependencies,
-                    annotations=dict(annotations_for_plugin),
-                    priority=priority,
+                    annotations=annotations_for_plugin,
+                    priority=priorities_for_plugin,
                     stimulus_id=stimulus_id,
                 )
             except Exception as e:
@@ -4852,42 +4847,6 @@ def _create_taskstate_from_graph(

         return metrics

-    def _remove_done_tasks_from_dsk(
-        self,
-        dsk: dict[Key, T_runspec],
-        dependencies: dict[Key, set[Key]],
-    ) -> None:
-        # Avoid computation that is already finished
-        done = set()  # tasks that are already done
-        for k, v in dependencies.items():
-            if v and k in self.tasks:
-                ts = self.tasks[k]
-                if ts.state in ("memory", "erred"):
-                    done.add(k)
-        if done:
-            dependents = dask.core.reverse_dict(dependencies)
-            stack = list(done)
-            while stack:  # remove unnecessary dependencies
-                key = stack.pop()
-                try:
-                    deps = dependencies[key]
-                except KeyError:
-                    deps = {ts.key for ts in self.tasks[key].dependencies}
-                for dep in deps:
-                    if dep in dependents:
-                        child_deps = dependents[dep]
-                    elif dep in self.tasks:
-                        child_deps = {ts.key for ts in self.tasks[key].dependencies}
-                    else:
-                        child_deps = set()
-                    if all(d in done for d in child_deps):
-                        if dep in self.tasks and dep not in done:
-                            done.add(dep)
-                            stack.append(dep)
-            for anc in done:
-                dsk.pop(anc, None)
-                dependencies.pop(anc, None)
-
     @log_errors
     async def update_graph(
         self,
@@ -4924,7 +4883,6 @@ async def update_graph(
                 raise RuntimeError(textwrap.dedent(msg)) from e
             (
                 dsk,
-                dependencies,
                 annotations_by_type,
             ) = await offload(
                 _materialize_graph,
@@ -4976,12 +4934,9 @@ async def update_graph(

         before = len(self.tasks)

-        self._remove_done_tasks_from_dsk(dsk, dependencies)
-
         metrics = self._create_taskstate_from_graph(
             dsk=dsk,
             client=client,
-            dependencies=dependencies,
             keys=set(keys),
             ordered=internal_priority or {},
             submitting_task=submitting_task,
@@ -5045,17 +5000,16 @@ def _generate_taskstates(
         self,
         keys: set[Key],
         dsk: dict[Key, T_runspec],
-        dependencies: dict[Key, set[Key]],
         computation: Computation,
     ) -> tuple:
         # Get or create task states
-        runnable = list()
         new_tasks = []
         stack = list(keys)
         touched_keys = set()
         touched_tasks = []
         tgs_with_bad_run_spec = set()
         colliding_task_count = 0
+        collisions = set()
         while stack:
             k = stack.pop()
             if k in touched_keys:
@@ -5078,18 +5032,13 @@ def _generate_taskstates(
             elif k in dsk:
                 # Check dependency names.
                 deps_lhs = {dts.key for dts in ts.dependencies}
-                deps_rhs = dependencies[k]
+                deps_rhs = dsk[k].dependencies

                 # FIXME It would be a really healthy idea to change this to a hard
                 # failure. However, this is not possible at the moment because of
                 # https://github.com/dask/dask/issues/9888
                 if deps_lhs != deps_rhs:
-                    # Retain old run_spec and dependencies; rerun them if necessary.
-                    # This sweeps the issue of collision under the carpet as long as the
-                    # old and new task produce the same output - such as in
-                    # dask/dask#9888.
-                    dependencies[k] = deps_lhs
-
+                    collisions.add(k)
                     colliding_task_count += 1
                     if ts.group not in tgs_with_bad_run_spec:
                         tgs_with_bad_run_spec.add(ts.group)
@@ -5120,18 +5069,17 @@ def _generate_taskstates(
                             "two consecutive calls to `update_graph`."
                         )

-            if ts.run_spec:
-                runnable.append(ts)
             touched_keys.add(k)
             touched_tasks.append(ts)
-            stack.extend(dependencies.get(k, ()))
+            if tspec := dsk.get(k, ()):
+                stack.extend(tspec.dependencies)

         # Add dependencies
-        for key, deps in dependencies.items():
+        for key, tspec in dsk.items():
             ts = self.tasks.get(key)
-            if ts is None or ts.dependencies:
+            if ts is None or key in collisions:
                 continue
-            for dep in deps:
+            for dep in tspec.dependencies:
                 dts = self.tasks[dep]
                 ts.add_dependency(dts)

@@ -5141,7 +5089,7 @@ def _generate_taskstates(
             len(touched_tasks),
             len(keys),
         )
-        return runnable, touched_tasks, new_tasks, colliding_task_count
+        return touched_tasks, new_tasks, colliding_task_count

     def _apply_annotations(
         self,
@@ -9509,7 +9457,7 @@ def transition(
 def _materialize_graph(
     expr: Expr,
     validate: bool,
-) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
+) -> tuple[dict[Key, T_runspec], dict[str, dict[Key, Any]]]:
     dsk: dict = expr.__dask_graph__()
     if validate:
         for k in dsk:
@@ -9520,10 +9468,7 @@ def _materialize_graph(
         annotations_by_type[annotations_type].update(value)

     dsk2 = convert_legacy_graph(dsk)
-    # FIXME: There should be no need to fully materialize and copy this but some
-    # sections in the scheduler are mutating it.
-    dependencies = {k: set(v) for k, v in DependenciesMapping(dsk2).items()}
-    return dsk2, dependencies, annotations_by_type
+    return dsk2, annotations_by_type


 def _cull(dsk: dict[Key, GraphNode], keys: set[Key]) -> dict[Key, GraphNode]:
0 commit comments