@@ -1674,9 +1674,6 @@ class SchedulerState:
16741674 #: Subset of tasks that exist in memory on more than one worker
16751675 replicated_tasks : set [TaskState ]
16761676
1677- #: Tasks with unknown duration, grouped by prefix
1678- #: {task prefix: {ts, ts, ...}}
1679- unknown_durations : dict [str , set [TaskState ]]
16801677 task_groups : dict [str , TaskGroup ]
16811678 task_prefixes : dict [str , TaskPrefix ]
16821679 task_metadata : dict [Key , Any ]
@@ -1776,7 +1773,6 @@ def __init__(
17761773 self .task_metadata = {}
17771774 self .total_nthreads = 0
17781775 self .total_nthreads_history = [(time (), 0 )]
1779- self .unknown_durations = {}
17801776 self .queued = queued
17811777 self .unrunnable = unrunnable
17821778 self .validate = validate
@@ -1855,7 +1851,6 @@ def __pdict__(self) -> dict[str, Any]:
18551851 "unrunnable" : self .unrunnable ,
18561852 "queued" : self .queued ,
18571853 "n_tasks" : self .n_tasks ,
1858- "unknown_durations" : self .unknown_durations ,
18591854 "validate" : self .validate ,
18601855 "tasks" : self .tasks ,
18611856 "task_groups" : self .task_groups ,
@@ -1907,7 +1902,6 @@ def _clear_task_state(self) -> None:
19071902 self .task_prefixes ,
19081903 self .task_groups ,
19091904 self .task_metadata ,
1910- self .unknown_durations ,
19111905 self .replicated_tasks ,
19121906 ):
19131907 collection .clear ()
@@ -1931,22 +1925,37 @@ def total_occupancy(self) -> float:
19311925 self ._network_occ_global ,
19321926 )
19331927
1928+ def _get_prefix_duration (self , prefix : TaskPrefix ) -> float :
1929+ """Get the estimated computation cost of the given task prefix
1930+ (not including any communication cost).
1931+
1932+ If no data has been observed, the value of
1933+ `distributed.scheduler.default-task-durations` is used. If none is set
1934+ for this prefix, `2 * prefix.max_exec_time` is used when that is positive,
1935+ and `distributed.scheduler.unknown-task-duration` otherwise.
1936+
1937+ See Also
1938+ --------
1939+ WorkStealing.get_task_duration
1940+ """
1941+ # TODO: Deal with unknown tasks better
1942+ assert prefix is not None
1943+ duration = prefix .duration_average
1944+ if duration < 0 :
1945+ if prefix .max_exec_time > 0 :
1946+ duration = 2 * prefix .max_exec_time
1947+ else :
1948+ duration = self .UNKNOWN_TASK_DURATION
1949+ return duration
1950+
19341951 def _calc_occupancy (
19351952 self ,
19361953 task_prefix_count : dict [str , int ],
19371954 network_occ : float ,
19381955 ) -> float :
19391956 res = 0.0
19401957 for prefix_name , count in task_prefix_count .items ():
1941- # TODO: Deal with unknown tasks better
1942- prefix = self .task_prefixes [prefix_name ]
1943- assert prefix is not None
1944- duration = prefix .duration_average
1945- if duration < 0 :
1946- if prefix .max_exec_time > 0 :
1947- duration = 2 * prefix .max_exec_time
1948- else :
1949- duration = self .UNKNOWN_TASK_DURATION
1958+ duration = self ._get_prefix_duration (self .task_prefixes [prefix_name ])
19501959 res += duration * count
19511960 occ = res + network_occ / self .bandwidth
19521961 assert occ >= 0 , (occ , res , network_occ , self .bandwidth )
@@ -2536,13 +2545,6 @@ def _transition_processing_memory(
25362545 action = startstop ["action" ],
25372546 )
25382547
2539- s = self .unknown_durations .pop (ts .prefix .name , set ())
2540- steal = self .extensions .get ("stealing" )
2541- if steal :
2542- for tts in s :
2543- if tts .processing_on :
2544- steal .recalculate_cost (tts )
2545-
25462548 ############################
25472549 # Update State Information #
25482550 ############################
@@ -3171,26 +3173,6 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> float:
31713173 nbytes = sum (dts .nbytes for dts in deps )
31723174 return nbytes / self .bandwidth
31733175
3174- def get_task_duration (self , ts : TaskState ) -> float :
3175- """Get the estimated computation cost of the given task (not including
3176- any communication cost).
3177-
3178- If no data has been observed, value of
3179- `distributed.scheduler.default-task-durations` are used. If none is set
3180- for this task, `distributed.scheduler.unknown-task-duration` is used
3181- instead.
3182- """
3183- prefix = ts .prefix
3184- duration : float = prefix .duration_average
3185- if duration >= 0 :
3186- return duration
3187-
3188- s = self .unknown_durations .get (prefix .name )
3189- if s is None :
3190- self .unknown_durations [prefix .name ] = s = set ()
3191- s .add (ts )
3192- return self .UNKNOWN_TASK_DURATION
3193-
31943176 def valid_workers (self , ts : TaskState ) -> set [WorkerState ] | None :
31953177 """Return set of currently valid workers for key
31963178
@@ -3569,20 +3551,15 @@ def _client_releases_keys(
35693551 elif ts .state != "erred" and not ts .waiters :
35703552 recommendations [ts .key ] = "released"
35713553
3572- def _task_to_msg (self , ts : TaskState , duration : float = - 1 ) -> dict [str , Any ]:
3554+ def _task_to_msg (self , ts : TaskState ) -> dict [str , Any ]:
35733555 """Convert a single computational task to a message"""
3574- # FIXME: The duration attribute is not used on worker. We could save ourselves the
3575- # time to compute and submit this
3576- if duration < 0 :
3577- duration = self .get_task_duration (ts )
35783556 ts .run_id = next (TaskState ._run_id_iterator )
35793557 assert ts .priority , ts
35803558 msg : dict [str , Any ] = {
35813559 "op" : "compute-task" ,
35823560 "key" : ts .key ,
35833561 "run_id" : ts .run_id ,
35843562 "priority" : ts .priority ,
3585- "duration" : duration ,
35863563 "stimulus_id" : f"compute-task-{ time ()} " ,
35873564 "who_has" : {
35883565 dts .key : tuple (ws .address for ws in (dts .who_has or ()))
@@ -6003,12 +5980,10 @@ async def remove_client_from_events() -> None:
60035980 cleanup_delay , remove_client_from_events
60045981 )
60055982
6006- def send_task_to_worker (
6007- self , worker : str , ts : TaskState , duration : float = - 1
6008- ) -> None :
5983+ def send_task_to_worker (self , worker : str , ts : TaskState ) -> None :
60095984 """Send a single computational task to a worker"""
60105985 try :
6011- msg = self ._task_to_msg (ts , duration )
5986+ msg = self ._task_to_msg (ts )
60125987 self .worker_send (worker , msg )
60135988 except Exception as e :
60145989 logger .exception (e )
@@ -8859,10 +8834,7 @@ def adaptive_target(self, target_duration=None):
88598834 queued = take (100 , concat ([self .queued , self .unrunnable .keys ()]))
88608835 queued_occupancy = 0
88618836 for ts in queued :
8862- if ts .prefix .duration_average == - 1 :
8863- queued_occupancy += self .UNKNOWN_TASK_DURATION
8864- else :
8865- queued_occupancy += ts .prefix .duration_average
8837+ queued_occupancy += self ._get_prefix_duration (ts .prefix )
88668838
88678839 tasks_ready = len (self .queued ) + len (self .unrunnable )
88688840 if tasks_ready > 100 :
0 commit comments