Skip to content

Commit ecc1b94

Browse files
committed
Batch task definition lookups in get_ancestors
Add a `get_task_definitions` function to the taskcluster util module, sharing a cache with the existing `get_task_definition`, and making one request per batch of 1000 tasks instead of one request per task.
1 parent 37d045e commit ecc1b94

File tree

2 files changed

+111
-42
lines changed

2 files changed

+111
-42
lines changed

src/taskgraph/util/taskcluster.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,40 @@ def get_task_url(task_id):
268268
return task_tmpl.format(task_id)
269269

270270

271-
@functools.cache
272-
def get_task_definition(task_id):
273-
queue = get_taskcluster_client("queue")
274-
return queue.task(task_id)
271+
class TaskDefinitionsCache:
272+
def __init__(self):
273+
self.cache = {}
274+
275+
def get_task_definition(self, task_id):
276+
if task_id not in self.cache:
277+
queue = get_taskcluster_client("queue")
278+
self.cache[task_id] = queue.task(task_id)
279+
return self.cache[task_id]
280+
281+
def get_task_definitions(self, task_ids):
282+
missing_task_ids = list(set(task_ids) - set(self.cache))
283+
if missing_task_ids:
284+
queue = get_taskcluster_client("queue")
285+
286+
def pagination_handler(response):
287+
self.cache.update(
288+
{task["taskId"]: task["task"] for task in response["tasks"]}
289+
)
290+
291+
queue.tasks(
292+
payload={"taskIds": missing_task_ids},
293+
paginationHandler=pagination_handler,
294+
)
295+
return {
296+
task_id: self.cache[task_id]
297+
for task_id in task_ids
298+
if task_id in self.cache
299+
}
300+
301+
302+
_task_definitions_cache = TaskDefinitionsCache()
303+
get_task_definition = _task_definitions_cache.get_task_definition
304+
get_task_definitions = _task_definitions_cache.get_task_definitions
275305

276306

277307
def cancel_task(task_id):
@@ -430,22 +460,20 @@ def pagination_handler(response):
430460
return incomplete_tasks
431461

432462

433-
@functools.cache
434463
def _get_deps(task_ids):
435464
upstream_tasks = {}
436-
for task_id in task_ids:
437-
task_def = get_task_definition(task_id)
438-
if not task_def:
439-
continue
440-
465+
task_defs = get_task_definitions(task_ids)
466+
dependencies = set()
467+
for task_id, task_def in task_defs.items():
441468
metadata = task_def.get("metadata", {}) # type: ignore
442469
name = metadata.get("name") # type: ignore
443470
if name:
444471
upstream_tasks[task_id] = name
445472

446-
dependencies = task_def.get("dependencies", [])
447-
if dependencies:
448-
upstream_tasks.update(_get_deps(tuple(dependencies)))
473+
dependencies |= set(task_def.get("dependencies", []))
474+
475+
if dependencies:
476+
upstream_tasks.update(_get_deps(tuple(dependencies)))
449477

450478
return upstream_tasks
451479

@@ -464,18 +492,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
464492
if isinstance(task_ids, str):
465493
task_ids = [task_ids]
466494

467-
for task_id in task_ids:
468-
try:
469-
task_def = get_task_definition(task_id)
470-
except taskcluster.TaskclusterRestFailure as e:
471-
# Task has most likely expired, which means it's no longer a
472-
# dependency for the purposes of this function.
473-
if e.status_code == 404:
474-
continue
475-
raise e
476-
477-
dependencies = task_def.get("dependencies", [])
478-
if dependencies:
479-
upstream_tasks.update(_get_deps(tuple(dependencies)))
495+
task_defs = get_task_definitions(task_ids)
496+
dependencies = set()
497+
498+
for task_id, task_def in task_defs.items():
499+
dependencies |= set(task_def.get("dependencies", []))
500+
if dependencies:
501+
upstream_tasks.update(_get_deps(tuple(dependencies)))
480502

481503
return copy.deepcopy(upstream_tasks)

test/test_util_taskcluster.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import datetime
6+
import json
67
from unittest.mock import MagicMock
78

89
import pytest
@@ -309,6 +310,7 @@ def test_get_task_url(root_url):
309310
def test_get_task_definition(responses, root_url):
310311
tid = "abc123"
311312
tc.get_taskcluster_client.cache_clear()
313+
tc._task_definitions_cache.cache.clear()
312314

313315
responses.get(
314316
f"{root_url}/api/queue/v1/task/{tid}",
@@ -318,6 +320,37 @@ def test_get_task_definition(responses, root_url):
318320
assert result == {"payload": "blah"}
319321

320322

323+
def test_get_task_definitions(responses, root_url):
324+
tid = ("abc123", "def456",)
325+
tc.get_taskcluster_client.cache_clear()
326+
tc._task_definitions_cache.cache.clear()
327+
328+
task_definitions = {
329+
"abc123": {"payload": "blah"},
330+
"def456": {"payload": "foobar"},
331+
}
332+
333+
def tasks_callback(request):
334+
payload = json.loads(request.body)
335+
resp_body = {
336+
"tasks": [
337+
{"taskId": task_id, "task": task_definitions[task_id]}
338+
for task_id in payload["taskIds"]
339+
]
340+
}
341+
return (200, [], json.dumps(resp_body))
342+
343+
responses.add_callback(
344+
responses.POST,
345+
f"{root_url}/api/queue/v1/tasks",
346+
callback=tasks_callback,
347+
)
348+
result = tc.get_task_definitions(tid)
349+
assert result == task_definitions
350+
result = tc.get_task_definition(tid[0])
351+
assert result == {"payload": "blah"}
352+
353+
321354
def test_cancel_task(responses, root_url):
322355
tid = "abc123"
323356
tc.get_taskcluster_client.cache_clear()
@@ -487,8 +520,7 @@ def test_list_task_group_incomplete_tasks(responses, root_url):
487520

488521

489522
def test_get_ancestors(responses, root_url):
490-
tc.get_task_definition.cache_clear()
491-
tc._get_deps.cache_clear()
523+
tc._task_definitions_cache.cache.clear()
492524
tc.get_taskcluster_client.cache_clear()
493525

494526
task_definitions = {
@@ -518,12 +550,20 @@ def test_get_ancestors(responses, root_url):
518550
},
519551
}
520552

521-
# Mock API responses for each task definition
522-
for task_id, definition in task_definitions.items():
523-
responses.get(
524-
f"{root_url}/api/queue/v1/task/{task_id}",
525-
json=definition,
526-
)
553+
# Mock API response for task definitions
554+
def tasks_callback(request):
555+
payload = json.loads(request.body)
556+
resp_body = {
557+
"tasks": [
558+
{"taskId": task_id, "task": task_definitions[task_id]}
559+
for task_id in payload["taskIds"]
560+
]
561+
}
562+
return (200, [], json.dumps(resp_body))
563+
564+
responses.add_callback(
565+
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
566+
)
527567

528568
got = tc.get_ancestors(["bbb", "fff"])
529569
expected = {
@@ -536,8 +576,7 @@ def test_get_ancestors(responses, root_url):
536576

537577

538578
def test_get_ancestors_string(responses, root_url):
539-
tc.get_task_definition.cache_clear()
540-
tc._get_deps.cache_clear()
579+
tc._task_definitions_cache.cache.clear()
541580
tc.get_taskcluster_client.cache_clear()
542581

543582
task_definitions = {
@@ -567,12 +606,20 @@ def test_get_ancestors_string(responses, root_url):
567606
},
568607
}
569608

570-
# Mock API responses for each task definition
571-
for task_id, definition in task_definitions.items():
572-
responses.get(
573-
f"{root_url}/api/queue/v1/task/{task_id}",
574-
json=definition,
575-
)
609+
# Mock API response for task definitions
610+
def tasks_callback(request):
611+
payload = json.loads(request.body)
612+
resp_body = {
613+
"tasks": [
614+
{"taskId": task_id, "task": task_definitions[task_id]}
615+
for task_id in payload["taskIds"]
616+
]
617+
}
618+
return (200, [], json.dumps(resp_body))
619+
620+
responses.add_callback(
621+
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
622+
)
576623

577624
got = tc.get_ancestors("fff")
578625
expected = {

0 commit comments

Comments
 (0)