diff --git a/src/taskgraph/util/taskcluster.py b/src/taskgraph/util/taskcluster.py
index b8cf8d56..e313c5bd 100644
--- a/src/taskgraph/util/taskcluster.py
+++ b/src/taskgraph/util/taskcluster.py
@@ -268,10 +268,40 @@ def get_task_url(task_id):
     return task_tmpl.format(task_id)
 
 
-@functools.cache
-def get_task_definition(task_id):
-    queue = get_taskcluster_client("queue")
-    return queue.task(task_id)
+class TaskDefinitionsCache:
+    def __init__(self):
+        self.cache = {}
+
+    def get_task_definition(self, task_id):
+        if task_id not in self.cache:
+            queue = get_taskcluster_client("queue")
+            self.cache[task_id] = queue.task(task_id)
+        return self.cache[task_id]
+
+    def get_task_definitions(self, task_ids):
+        missing_task_ids = list(set(task_ids) - set(self.cache))
+        if missing_task_ids:
+            queue = get_taskcluster_client("queue")
+
+            def pagination_handler(response):
+                self.cache.update(
+                    {task["taskId"]: task["task"] for task in response["tasks"]}
+                )
+
+            queue.tasks(
+                payload={"taskIds": missing_task_ids},
+                paginationHandler=pagination_handler,
+            )
+        return {
+            task_id: self.cache[task_id]
+            for task_id in task_ids
+            if task_id in self.cache
+        }
+
+
+_task_definitions_cache = TaskDefinitionsCache()
+get_task_definition = _task_definitions_cache.get_task_definition
+get_task_definitions = _task_definitions_cache.get_task_definitions
 
 
 def cancel_task(task_id):
@@ -430,22 +460,20 @@ def pagination_handler(response):
     return incomplete_tasks
 
 
-@functools.cache
 def _get_deps(task_ids):
     upstream_tasks = {}
-    for task_id in task_ids:
-        task_def = get_task_definition(task_id)
-        if not task_def:
-            continue
-
+    task_defs = get_task_definitions(task_ids)
+    dependencies = set()
+    for task_id, task_def in task_defs.items():
         metadata = task_def.get("metadata", {})  # type: ignore
         name = metadata.get("name")  # type: ignore
         if name:
             upstream_tasks[task_id] = name
 
-        dependencies = task_def.get("dependencies", [])
-        if dependencies:
-            upstream_tasks.update(_get_deps(tuple(dependencies)))
+        dependencies |= set(task_def.get("dependencies", []))
+
+    if dependencies:
+        upstream_tasks.update(_get_deps(tuple(dependencies)))
 
     return upstream_tasks
 
@@ -464,18 +492,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
     if isinstance(task_ids, str):
         task_ids = [task_ids]
 
-    for task_id in task_ids:
-        try:
-            task_def = get_task_definition(task_id)
-        except taskcluster.TaskclusterRestFailure as e:
-            # Task has most likely expired, which means it's no longer a
-            # dependency for the purposes of this function.
-            if e.status_code == 404:
-                continue
-            raise e
-
-        dependencies = task_def.get("dependencies", [])
-        if dependencies:
-            upstream_tasks.update(_get_deps(tuple(dependencies)))
+    task_defs = get_task_definitions(task_ids)
+    dependencies = set()
+
+    for task_id, task_def in task_defs.items():
+        dependencies |= set(task_def.get("dependencies", []))
+    if dependencies:
+        upstream_tasks.update(_get_deps(tuple(dependencies)))
 
     return copy.deepcopy(upstream_tasks)
diff --git a/test/test_util_taskcluster.py b/test/test_util_taskcluster.py
index 88752585..d6c8c00b 100644
--- a/test/test_util_taskcluster.py
+++ b/test/test_util_taskcluster.py
@@ -3,6 +3,7 @@
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 import datetime
+import json
 from unittest.mock import MagicMock
 
 import pytest
@@ -309,6 +310,7 @@ def test_get_task_url(root_url):
 def test_get_task_definition(responses, root_url):
     tid = "abc123"
     tc.get_taskcluster_client.cache_clear()
+    tc._task_definitions_cache.cache.clear()
 
     responses.get(
         f"{root_url}/api/queue/v1/task/{tid}",
@@ -318,6 +320,40 @@ def test_get_task_definition(responses, root_url):
     assert result == {"payload": "blah"}
 
 
+def test_get_task_definitions(responses, root_url):
+    tid = (
+        "abc123",
+        "def456",
+    )
+    tc.get_taskcluster_client.cache_clear()
+    tc._task_definitions_cache.cache.clear()
+
+    task_definitions = {
+        "abc123": {"payload": "blah"},
+        "def456": {"payload": "foobar"},
+    }
+
+    def tasks_callback(request):
+        payload = json.loads(request.body)
+        resp_body = {
+            "tasks": [
+                {"taskId": task_id, "task": task_definitions[task_id]}
+                for task_id in payload["taskIds"]
+            ]
+        }
+        return (200, [], json.dumps(resp_body))
+
+    responses.add_callback(
+        responses.POST,
+        f"{root_url}/api/queue/v1/tasks",
+        callback=tasks_callback,
+    )
+    result = tc.get_task_definitions(tid)
+    assert result == task_definitions
+    result = tc.get_task_definition(tid[0])
+    assert result == {"payload": "blah"}
+
+
 def test_cancel_task(responses, root_url):
     tid = "abc123"
     tc.get_taskcluster_client.cache_clear()
@@ -487,8 +523,7 @@ def test_list_task_group_incomplete_tasks(responses, root_url):
 
 
 def test_get_ancestors(responses, root_url):
-    tc.get_task_definition.cache_clear()
-    tc._get_deps.cache_clear()
+    tc._task_definitions_cache.cache.clear()
     tc.get_taskcluster_client.cache_clear()
 
     task_definitions = {
@@ -518,12 +553,20 @@ def test_get_ancestors(responses, root_url):
         },
     }
 
-    # Mock API responses for each task definition
-    for task_id, definition in task_definitions.items():
-        responses.get(
-            f"{root_url}/api/queue/v1/task/{task_id}",
-            json=definition,
-        )
+    # Mock API response for task definitions
+    def tasks_callback(request):
+        payload = json.loads(request.body)
+        resp_body = {
+            "tasks": [
+                {"taskId": task_id, "task": task_definitions[task_id]}
+                for task_id in payload["taskIds"]
+            ]
+        }
+        return (200, [], json.dumps(resp_body))
+
+    responses.add_callback(
+        responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
+    )
 
     got = tc.get_ancestors(["bbb", "fff"])
     expected = {
@@ -536,8 +579,7 @@ def test_get_ancestors(responses, root_url):
 
 
 def test_get_ancestors_string(responses, root_url):
-    tc.get_task_definition.cache_clear()
-    tc._get_deps.cache_clear()
+    tc._task_definitions_cache.cache.clear()
     tc.get_taskcluster_client.cache_clear()
 
     task_definitions = {
@@ -567,12 +609,20 @@ def test_get_ancestors_string(responses, root_url):
         },
     }
 
-    # Mock API responses for each task definition
-    for task_id, definition in task_definitions.items():
-        responses.get(
-            f"{root_url}/api/queue/v1/task/{task_id}",
-            json=definition,
-        )
+    # Mock API response for task definitions
+    def tasks_callback(request):
+        payload = json.loads(request.body)
+        resp_body = {
+            "tasks": [
+                {"taskId": task_id, "task": task_definitions[task_id]}
+                for task_id in payload["taskIds"]
+            ]
+        }
+        return (200, [], json.dumps(resp_body))
+
+    responses.add_callback(
+        responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
+    )
 
     got = tc.get_ancestors("fff")
     expected = {