@@ -268,10 +268,40 @@ def get_task_url(task_id):
268268 return task_tmpl .format (task_id )
269269
270270
271- @functools .cache
272- def get_task_definition (task_id ):
273- queue = get_taskcluster_client ("queue" )
274- return queue .task (task_id )
271+ class TaskDefinitionsCache :
272+ def __init__ (self ):
273+ self .cache = {}
274+
275+ def get_task_definition (self , task_id ):
276+ if task_id not in self .cache :
277+ queue = get_taskcluster_client ("queue" )
278+ self .cache [task_id ] = queue .task (task_id )
279+ return self .cache [task_id ]
280+
281+ def get_task_definitions (self , task_ids ):
282+ missing_task_ids = list (set (task_ids ) - set (self .cache ))
283+ if missing_task_ids :
284+ queue = get_taskcluster_client ("queue" )
285+
286+ def pagination_handler (response ):
287+ self .cache .update (
288+ {task ["taskId" ]: task ["task" ] for task in response ["tasks" ]}
289+ )
290+
291+ queue .tasks (
292+ payload = {"taskIds" : missing_task_ids },
293+ paginationHandler = pagination_handler ,
294+ )
295+ return {
296+ task_id : self .cache [task_id ]
297+ for task_id in task_ids
298+ if task_id in self .cache
299+ }
300+
301+
302+ _task_definitions_cache = TaskDefinitionsCache ()
303+ get_task_definition = _task_definitions_cache .get_task_definition
304+ get_task_definitions = _task_definitions_cache .get_task_definitions
275305
276306
277307def cancel_task (task_id ):
@@ -430,22 +460,20 @@ def pagination_handler(response):
430460 return incomplete_tasks
431461
432462
433- @functools .cache
434463def _get_deps (task_ids ):
435464 upstream_tasks = {}
436- for task_id in task_ids :
437- task_def = get_task_definition (task_id )
438- if not task_def :
439- continue
440-
465+ task_defs = get_task_definitions (task_ids )
466+ dependencies = set ()
467+ for task_id , task_def in task_defs .items ():
441468 metadata = task_def .get ("metadata" , {}) # type: ignore
442469 name = metadata .get ("name" ) # type: ignore
443470 if name :
444471 upstream_tasks [task_id ] = name
445472
446- dependencies = task_def .get ("dependencies" , [])
447- if dependencies :
448- upstream_tasks .update (_get_deps (tuple (dependencies )))
473+ dependencies |= set (task_def .get ("dependencies" , []))
474+
475+ if dependencies :
476+ upstream_tasks .update (_get_deps (tuple (dependencies )))
449477
450478 return upstream_tasks
451479
@@ -464,18 +492,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
464492 if isinstance (task_ids , str ):
465493 task_ids = [task_ids ]
466494
467- for task_id in task_ids :
468- try :
469- task_def = get_task_definition (task_id )
470- except taskcluster .TaskclusterRestFailure as e :
471- # Task has most likely expired, which means it's no longer a
472- # dependency for the purposes of this function.
473- if e .status_code == 404 :
474- continue
475- raise e
476-
477- dependencies = task_def .get ("dependencies" , [])
478- if dependencies :
479- upstream_tasks .update (_get_deps (tuple (dependencies )))
495+ task_defs = get_task_definitions (task_ids )
496+ dependencies = set ()
497+
498+ for task_id , task_def in task_defs .items ():
499+ dependencies |= set (task_def .get ("dependencies" , []))
500+ if dependencies :
501+ upstream_tasks .update (_get_deps (tuple (dependencies )))
480502
481503 return copy .deepcopy (upstream_tasks )
0 commit comments