@@ -4687,7 +4687,6 @@ def _create_taskstate_from_graph(
46874687 # FIXME: This is kind of inconsistent since it only includes global
46884688 # annotations.
46894689 computation .annotations .update (global_annotations )
4690- del global_annotations
46914690 (
46924691 runnable ,
46934692 touched_tasks ,
@@ -4709,6 +4708,7 @@ def _create_taskstate_from_graph(
47094708 keys_with_annotations = self ._apply_annotations (
47104709 tasks = new_tasks ,
47114710 annotations_by_type = annotations_by_type ,
4711+ global_annotations = global_annotations ,
47124712 )
47134713
47144714 self ._set_priorities (
@@ -4872,7 +4872,6 @@ async def update_graph(
48724872 ) = await offload (
48734873 _materialize_graph ,
48744874 expr = expr ,
4875- global_annotations = annotations or {},
48764875 validate = self .validate ,
48774876 )
48784877
@@ -4927,8 +4926,6 @@ async def update_graph(
49274926 code = code ,
49284927 span_metadata = span_metadata ,
49294928 annotations_by_type = annotations_by_type ,
4930- # FIXME: This is just used to attach to Computation
4931- # objects. This should be removed
49324929 global_annotations = annotations ,
49334930 start = start ,
49344931 stimulus_id = stimulus_id ,
@@ -5085,6 +5082,7 @@ def _apply_annotations(
50855082 self ,
50865083 tasks : Iterable [TaskState ],
50875084 annotations_by_type : dict [str , dict [Key , Any ]],
5085+ global_annotations : dict [str , Any ] | None = None ,
50885086 ) -> set [Key ]:
50895087 """Apply the provided annotations to the provided `TaskState` objects.
50905088
@@ -5104,13 +5102,18 @@ def _apply_annotations(
51045102 keys_with_annotations
51055103 """
51065104 keys_with_annotations : set [Key ] = set ()
5107- if not annotations_by_type :
5105+ if not annotations_by_type and not global_annotations :
51085106 return keys_with_annotations
51095107
51105108 for ts in tasks :
51115109 key = ts .key
51125110
51135111 ts_annotations = {}
5112+ if global_annotations :
5113+ for annot , value in global_annotations .items ():
5114+ if callable (value ):
5115+ value = value (ts .key )
5116+ ts_annotations [annot ] = value
51145117 for annot , key_value in annotations_by_type .items ():
51155118 if (value := key_value .get (key )) is not None :
51165119 ts_annotations [annot ] = value
@@ -9429,18 +9432,14 @@ def transition(
94299432
94309433def _materialize_graph (
94319434 expr : Expr ,
9432- global_annotations : dict [str , Any ],
94339435 validate : bool ,
94349436) -> tuple [dict [Key , T_runspec ], dict [Key , set [Key ]], dict [str , dict [Key , Any ]]]:
94359437 dsk : dict = expr .__dask_graph__ ()
94369438 if validate :
94379439 for k in dsk :
94389440 validate_key (k )
94399441 annotations_by_type : defaultdict [str , dict [Key , Any ]] = defaultdict (dict )
9440- for annotations_type , value in global_annotations .items ():
9441- annotations_by_type [annotations_type ].update (
9442- {k : (value (k ) if callable (value ) else value ) for k in dsk }
9443- )
9442+
94449443 for annotations_type , value in expr .__dask_annotations__ ().items ():
94459444 annotations_by_type [annotations_type ].update (value )
94469445
0 commit comments