Skip to content

Commit 525f0c2

Browse files
committed
Batch task definition lookups in get_ancestors
1 parent 37d045e commit 525f0c2

File tree

2 files changed

+63
-39
lines changed

2 files changed

+63
-39
lines changed

src/taskgraph/util/taskcluster.py

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,40 @@ def get_task_url(task_id):
268268
return task_tmpl.format(task_id)
269269

270270

271-
@functools.cache
272-
def get_task_definition(task_id):
273-
queue = get_taskcluster_client("queue")
274-
return queue.task(task_id)
271+
class TaskDefinitionsCache:
    """In-memory cache of Taskcluster task definitions keyed by task id.

    Definitions can be looked up one at a time (``get_task_definition``) or
    in bulk through the queue's batched ``tasks`` endpoint
    (``get_task_definitions``); both paths share the same cache dict.
    """

    def __init__(self):
        # Maps task_id -> task definition dict.
        self.cache = {}

    def get_task_definition(self, task_id):
        """Return the definition for a single task, fetching it on a miss."""
        if task_id not in self.cache:
            self.cache[task_id] = get_taskcluster_client("queue").task(task_id)
        return self.cache[task_id]

    def get_task_definitions(self, task_ids):
        """Return a mapping of task id -> definition for ``task_ids``.

        Ids not already cached are fetched in a single batched
        ``queue.tasks`` call. Ids the service does not return (e.g. expired
        tasks) are silently omitted from the result.
        """
        to_fetch = [tid for tid in set(task_ids) if tid not in self.cache]
        if to_fetch:
            queue = get_taskcluster_client("queue")

            def absorb_page(response):
                # Each page carries a list of {"taskId": ..., "task": ...} records.
                for record in response["tasks"]:
                    self.cache[record["taskId"]] = record["task"]

            queue.tasks(
                payload={"taskIds": to_fetch},
                paginationHandler=absorb_page,
            )
        return {tid: self.cache[tid] for tid in task_ids if tid in self.cache}


# Module-level singleton so cached definitions persist across calls,
# mirroring the lifetime of the @functools.cache this cache replaced.
_task_definitions_cache = TaskDefinitionsCache()
get_task_definition = _task_definitions_cache.get_task_definition
get_task_definitions = _task_definitions_cache.get_task_definitions
275305

276306

277307
def cancel_task(task_id):
@@ -433,19 +463,18 @@ def pagination_handler(response):
433463
@functools.cache
def _get_deps(task_ids):
    """Recursively resolve the metadata names of all upstream dependencies.

    Args:
        task_ids (tuple): Task ids whose dependency trees should be walked.
            Must be a tuple (hashable) so results can be memoized.

    Returns:
        dict: Maps each reachable upstream task id to its metadata name.
            Tasks without a metadata name contribute their dependencies to
            the walk but do not appear in the result themselves.
    """
    upstream_tasks = {}
    # One batched request per level of the tree; expired/unknown tasks are
    # simply absent from the result and thus dropped from the walk.
    task_defs = get_task_definitions(task_ids)
    dependencies = set()
    for task_id, task_def in task_defs.items():
        metadata = task_def.get("metadata", {})  # type: ignore
        name = metadata.get("name")  # type: ignore
        if name:
            upstream_tasks[task_id] = name

        dependencies |= set(task_def.get("dependencies", []))

    if dependencies:
        # Sort before tupling: set iteration order is non-deterministic, so
        # an unsorted tuple would make identical dependency sets hash to
        # different @functools.cache keys and defeat the memoization.
        upstream_tasks.update(_get_deps(tuple(sorted(dependencies))))

    return upstream_tasks
451480

@@ -464,18 +493,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
464493
if isinstance(task_ids, str):
465494
task_ids = [task_ids]
466495

467-
for task_id in task_ids:
468-
try:
469-
task_def = get_task_definition(task_id)
470-
except taskcluster.TaskclusterRestFailure as e:
471-
# Task has most likely expired, which means it's no longer a
472-
# dependency for the purposes of this function.
473-
if e.status_code == 404:
474-
continue
475-
raise e
476-
477-
dependencies = task_def.get("dependencies", [])
478-
if dependencies:
479-
upstream_tasks.update(_get_deps(tuple(dependencies)))
496+
task_defs = get_task_definitions(task_ids)
497+
dependencies = set()
498+
499+
for task_id, task_def in task_defs.items():
500+
dependencies |= set(task_def.get("dependencies", []))
501+
if dependencies:
502+
upstream_tasks.update(_get_deps(tuple(dependencies)))
480503

481504
return copy.deepcopy(upstream_tasks)

test/test_util_taskcluster.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import datetime
6+
import json
67
from unittest.mock import MagicMock
78

89
import pytest
@@ -487,7 +488,6 @@ def test_list_task_group_incomplete_tasks(responses, root_url):
487488

488489

489490
def test_get_ancestors(responses, root_url):
490-
tc.get_task_definition.cache_clear()
491491
tc._get_deps.cache_clear()
492492
tc.get_taskcluster_client.cache_clear()
493493

@@ -518,12 +518,13 @@ def test_get_ancestors(responses, root_url):
518518
},
519519
}
520520

521-
# Mock API responses for each task definition
522-
for task_id, definition in task_definitions.items():
523-
responses.get(
524-
f"{root_url}/api/queue/v1/task/{task_id}",
525-
json=definition,
526-
)
521+
# Mock API response for task definitions
522+
def tasks_callback(request):
523+
payload = json.loads(request.body)
524+
resp_body = {"tasks": [{"taskId": task_id, "task": task_definitions[task_id]} for task_id in payload["taskIds"]]}
525+
return (200, [], json.dumps(resp_body))
526+
527+
responses.add_callback(responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback)
527528

528529
got = tc.get_ancestors(["bbb", "fff"])
529530
expected = {
@@ -536,7 +537,6 @@ def test_get_ancestors(responses, root_url):
536537

537538

538539
def test_get_ancestors_string(responses, root_url):
539-
tc.get_task_definition.cache_clear()
540540
tc._get_deps.cache_clear()
541541
tc.get_taskcluster_client.cache_clear()
542542

@@ -567,12 +567,13 @@ def test_get_ancestors_string(responses, root_url):
567567
},
568568
}
569569

570-
# Mock API responses for each task definition
571-
for task_id, definition in task_definitions.items():
572-
responses.get(
573-
f"{root_url}/api/queue/v1/task/{task_id}",
574-
json=definition,
575-
)
570+
# Mock API response for task definitions
571+
def tasks_callback(request):
572+
payload = json.loads(request.body)
573+
resp_body = {"tasks": [{"taskId": task_id, "task": task_definitions[task_id]} for task_id in payload["taskIds"]]}
574+
return (200, [], json.dumps(resp_body))
575+
576+
responses.add_callback(responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback)
576577

577578
got = tc.get_ancestors("fff")
578579
expected = {

0 commit comments

Comments
 (0)