@@ -268,10 +268,40 @@ def get_task_url(task_id):
268268 return task_tmpl .format (task_id )
269269
270270
271- @functools .cache
272- def get_task_definition (task_id ):
273- queue = get_taskcluster_client ("queue" )
274- return queue .task (task_id )
271+ class TaskDefinitionsCache :
272+ def __init__ (self ):
273+ self .cache = {}
274+
275+ def get_task_definition (self , task_id ):
276+ if task_id not in self .cache :
277+ queue = get_taskcluster_client ("queue" )
278+ self .cache [task_id ] = queue .task (task_id )
279+ return self .cache [task_id ]
280+
281+ def get_task_definitions (self , task_ids ):
282+ missing_task_ids = list (set (task_ids ) - set (self .cache ))
283+ if missing_task_ids :
284+ queue = get_taskcluster_client ("queue" )
285+
286+ def pagination_handler (response ):
287+ self .cache .update (
288+ {task ["taskId" ]: task ["task" ] for task in response ["tasks" ]}
289+ )
290+
291+ queue .tasks (
292+ payload = {"taskIds" : missing_task_ids },
293+ paginationHandler = pagination_handler ,
294+ )
295+ return {
296+ task_id : self .cache [task_id ]
297+ for task_id in task_ids
298+ if task_id in self .cache
299+ }
300+
301+
302+ _task_definitions_cache = TaskDefinitionsCache ()
303+ get_task_definition = _task_definitions_cache .get_task_definition
304+ get_task_definitions = _task_definitions_cache .get_task_definitions
275305
276306
277307def cancel_task (task_id ):
@@ -433,19 +463,18 @@ def pagination_handler(response):
433463@functools .cache
434464def _get_deps (task_ids ):
435465 upstream_tasks = {}
436- for task_id in task_ids :
437- task_def = get_task_definition (task_id )
438- if not task_def :
439- continue
440-
466+ task_defs = get_task_definitions (task_ids )
467+ dependencies = set ()
468+ for task_id , task_def in task_defs .items ():
441469 metadata = task_def .get ("metadata" , {}) # type: ignore
442470 name = metadata .get ("name" ) # type: ignore
443471 if name :
444472 upstream_tasks [task_id ] = name
445473
446- dependencies = task_def .get ("dependencies" , [])
447- if dependencies :
448- upstream_tasks .update (_get_deps (tuple (dependencies )))
474+ dependencies |= set (task_def .get ("dependencies" , []))
475+
476+ if dependencies :
477+ upstream_tasks .update (_get_deps (tuple (dependencies )))
449478
450479 return upstream_tasks
451480
@@ -464,18 +493,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
464493 if isinstance (task_ids , str ):
465494 task_ids = [task_ids ]
466495
467- for task_id in task_ids :
468- try :
469- task_def = get_task_definition (task_id )
470- except taskcluster .TaskclusterRestFailure as e :
471- # Task has most likely expired, which means it's no longer a
472- # dependency for the purposes of this function.
473- if e .status_code == 404 :
474- continue
475- raise e
476-
477- dependencies = task_def .get ("dependencies" , [])
478- if dependencies :
479- upstream_tasks .update (_get_deps (tuple (dependencies )))
496+ task_defs = get_task_definitions (task_ids )
497+ dependencies = set ()
498+
499+ for task_id , task_def in task_defs .items ():
500+ dependencies |= set (task_def .get ("dependencies" , []))
501+ if dependencies :
502+ upstream_tasks .update (_get_deps (tuple (dependencies )))
480503
481504 return copy .deepcopy (upstream_tasks )
0 commit comments