Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions src/taskgraph/util/taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,40 @@ def get_task_url(task_id):
return task_tmpl.format(task_id)


# NOTE(review): this is the pre-change implementation shown as removed lines in
# the diff; it was replaced by TaskDefinitionsCache below.
@functools.cache
def get_task_definition(task_id):
    # Fetch a single task definition from the Taskcluster queue service;
    # functools.cache memoizes per task_id for the process lifetime.
    queue = get_taskcluster_client("queue")
    return queue.task(task_id)
class TaskDefinitionsCache:
    """In-memory cache of Taskcluster task definitions, keyed by task id.

    Replaces per-function ``functools.cache`` memoization so that single and
    bulk lookups share one cache.
    """

    def __init__(self):
        # Maps task id -> task definition payload.
        self.cache = {}

    def get_task_definition(self, task_id):
        """Return the definition for ``task_id``, fetching it on a cache miss."""
        if task_id in self.cache:
            return self.cache[task_id]
        queue = get_taskcluster_client("queue")
        definition = queue.task(task_id)
        self.cache[task_id] = definition
        return definition

    def get_task_definitions(self, task_ids):
        """Return ``{task_id: definition}`` for the given ids.

        Ids not already cached are fetched from the queue service in a single
        bulk (paginated) request. Ids the service does not return (e.g.
        expired tasks) are silently omitted from the result.
        """
        to_fetch = [tid for tid in set(task_ids) if tid not in self.cache]
        if to_fetch:
            queue = get_taskcluster_client("queue")

            def handle_page(response):
                # Each page carries a list of {"taskId": ..., "task": ...}.
                for entry in response["tasks"]:
                    self.cache[entry["taskId"]] = entry["task"]

            queue.tasks(
                payload={"taskIds": to_fetch},
                paginationHandler=handle_page,
            )
        return {tid: self.cache[tid] for tid in task_ids if tid in self.cache}


# Module-level singleton so every caller shares one cache; the bound methods
# are re-exported as plain functions to keep the original public API
# (get_task_definition / get_task_definitions) unchanged.
_task_definitions_cache = TaskDefinitionsCache()
get_task_definition = _task_definitions_cache.get_task_definition
get_task_definitions = _task_definitions_cache.get_task_definitions


def cancel_task(task_id):
Expand Down Expand Up @@ -430,22 +460,20 @@ def pagination_handler(response):
return incomplete_tasks


@functools.cache
def _get_deps(task_ids):
upstream_tasks = {}
for task_id in task_ids:
task_def = get_task_definition(task_id)
if not task_def:
continue

task_defs = get_task_definitions(task_ids)
dependencies = set()
for task_id, task_def in task_defs.items():
metadata = task_def.get("metadata", {}) # type: ignore
name = metadata.get("name") # type: ignore
if name:
upstream_tasks[task_id] = name

dependencies = task_def.get("dependencies", [])
if dependencies:
upstream_tasks.update(_get_deps(tuple(dependencies)))
dependencies |= set(task_def.get("dependencies", []))

if dependencies:
upstream_tasks.update(_get_deps(tuple(dependencies)))

return upstream_tasks

Expand All @@ -464,18 +492,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
if isinstance(task_ids, str):
task_ids = [task_ids]

for task_id in task_ids:
try:
task_def = get_task_definition(task_id)
except taskcluster.TaskclusterRestFailure as e:
# Task has most likely expired, which means it's no longer a
# dependency for the purposes of this function.
if e.status_code == 404:
continue
raise e

dependencies = task_def.get("dependencies", [])
if dependencies:
upstream_tasks.update(_get_deps(tuple(dependencies)))
task_defs = get_task_definitions(task_ids)
dependencies = set()

for task_id, task_def in task_defs.items():
dependencies |= set(task_def.get("dependencies", []))
if dependencies:
upstream_tasks.update(_get_deps(tuple(dependencies)))

return copy.deepcopy(upstream_tasks)
82 changes: 66 additions & 16 deletions test/test_util_taskcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import datetime
import json
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -309,6 +310,7 @@ def test_get_task_url(root_url):
def test_get_task_definition(responses, root_url):
tid = "abc123"
tc.get_taskcluster_client.cache_clear()
tc._task_definitions_cache.cache.clear()

responses.get(
f"{root_url}/api/queue/v1/task/{tid}",
Expand All @@ -318,6 +320,40 @@ def test_get_task_definition(responses, root_url):
assert result == {"payload": "blah"}


def test_get_task_definitions(responses, root_url):
    """Bulk fetch populates the shared cache and single lookups reuse it."""
    tid = (
        "abc123",
        "def456",
    )
    # Reset module-level caches so earlier tests can't leak state in.
    tc.get_taskcluster_client.cache_clear()
    tc._task_definitions_cache.cache.clear()

    task_definitions = {
        "abc123": {"payload": "blah"},
        "def456": {"payload": "foobar"},
    }

    def tasks_callback(request):
        # Echo back a definition for every requested id, mimicking the
        # queue service's POST /tasks bulk endpoint response shape.
        payload = json.loads(request.body)
        resp_body = {
            "tasks": [
                {"taskId": task_id, "task": task_definitions[task_id]}
                for task_id in payload["taskIds"]
            ]
        }
        return (200, [], json.dumps(resp_body))

    responses.add_callback(
        responses.POST,
        f"{root_url}/api/queue/v1/tasks",
        callback=tasks_callback,
    )
    result = tc.get_task_definitions(tid)
    assert result == task_definitions
    # Only the POST mock is registered, so this single-id lookup must be
    # served from the cache — a network call here would fail the test.
    result = tc.get_task_definition(tid[0])
    assert result == {"payload": "blah"}


def test_cancel_task(responses, root_url):
tid = "abc123"
tc.get_taskcluster_client.cache_clear()
Expand Down Expand Up @@ -487,8 +523,7 @@ def test_list_task_group_incomplete_tasks(responses, root_url):


def test_get_ancestors(responses, root_url):
tc.get_task_definition.cache_clear()
tc._get_deps.cache_clear()
tc._task_definitions_cache.cache.clear()
tc.get_taskcluster_client.cache_clear()

task_definitions = {
Expand Down Expand Up @@ -518,12 +553,20 @@ def test_get_ancestors(responses, root_url):
},
}

# Mock API responses for each task definition
for task_id, definition in task_definitions.items():
responses.get(
f"{root_url}/api/queue/v1/task/{task_id}",
json=definition,
)
# Mock API response for task definitions
def tasks_callback(request):
payload = json.loads(request.body)
resp_body = {
"tasks": [
{"taskId": task_id, "task": task_definitions[task_id]}
for task_id in payload["taskIds"]
]
}
return (200, [], json.dumps(resp_body))

responses.add_callback(
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
)

got = tc.get_ancestors(["bbb", "fff"])
expected = {
Expand All @@ -536,8 +579,7 @@ def test_get_ancestors(responses, root_url):


def test_get_ancestors_string(responses, root_url):
tc.get_task_definition.cache_clear()
tc._get_deps.cache_clear()
tc._task_definitions_cache.cache.clear()
tc.get_taskcluster_client.cache_clear()

task_definitions = {
Expand Down Expand Up @@ -567,12 +609,20 @@ def test_get_ancestors_string(responses, root_url):
},
}

# Mock API responses for each task definition
for task_id, definition in task_definitions.items():
responses.get(
f"{root_url}/api/queue/v1/task/{task_id}",
json=definition,
)
# Mock API response for task definitions
def tasks_callback(request):
payload = json.loads(request.body)
resp_body = {
"tasks": [
{"taskId": task_id, "task": task_definitions[task_id]}
for task_id in payload["taskIds"]
]
}
return (200, [], json.dumps(resp_body))

responses.add_callback(
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
)

got = tc.get_ancestors("fff")
expected = {
Expand Down
Loading