Skip to content

Commit 17995e8

Browse files
committed
Batch task definition lookups in get_ancestors
Add a `get_task_definitions` function to the taskcluster util module. It shares a cache with the existing `get_task_definition` and issues one request per batch of 1000 tasks instead of one request per task.
1 parent 37d045e commit 17995e8

File tree

2 files changed

+114
-42
lines changed

2 files changed

+114
-42
lines changed

src/taskgraph/util/taskcluster.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -268,10 +268,40 @@ def get_task_url(task_id):
268268
return task_tmpl.format(task_id)
269269

270270

271-
@functools.cache
272-
def get_task_definition(task_id):
273-
queue = get_taskcluster_client("queue")
274-
return queue.task(task_id)
271+
class TaskDefinitionsCache:
    """In-memory cache of task definitions, keyed by task id.

    A single instance backs both the one-at-a-time and the batched lookup
    helpers, so a definition fetched through either path is never requested
    from the queue service again.
    """

    def __init__(self):
        # Maps task_id -> task definition as returned by the queue service.
        self.cache = {}

    def get_task_definition(self, task_id):
        """Return the definition for a single task, fetching it on a miss."""
        try:
            return self.cache[task_id]
        except KeyError:
            definition = get_taskcluster_client("queue").task(task_id)
            self.cache[task_id] = definition
            return definition

    def get_task_definitions(self, task_ids):
        """Return a dict of task_id -> definition for the given ids.

        Uncached definitions are fetched in bulk through the queue's
        ``tasks`` endpoint. Ids the service does not return (e.g. expired
        tasks) are silently omitted from the result.
        """
        to_fetch = [tid for tid in set(task_ids) if tid not in self.cache]
        if to_fetch:
            queue = get_taskcluster_client("queue")

            def _store_page(response):
                # Called once per page of results; fold each page into the
                # shared cache as it arrives.
                for entry in response["tasks"]:
                    self.cache[entry["taskId"]] = entry["task"]

            queue.tasks(
                payload={"taskIds": to_fetch},
                paginationHandler=_store_page,
            )

        result = {}
        for tid in task_ids:
            if tid in self.cache:
                result[tid] = self.cache[tid]
        return result


# Module-level singleton; the two function aliases below preserve the
# original function-style API for existing callers.
_task_definitions_cache = TaskDefinitionsCache()
get_task_definition = _task_definitions_cache.get_task_definition
get_task_definitions = _task_definitions_cache.get_task_definitions
275305

276306

277307
def cancel_task(task_id):
@@ -430,22 +460,20 @@ def pagination_handler(response):
430460
return incomplete_tasks
431461

432462

433-
def _get_deps(task_ids, _seen=None):
    """Recursively resolve the names of every transitive dependency.

    Walks the dependency graph level by level: fetches the definitions for
    ``task_ids`` in one batch, records each task's ``metadata.name`` (tasks
    without a name are traversed but not recorded), then recurses on the
    union of their dependencies.

    The ``@functools.cache`` decorator was dropped when batching was
    introduced, so without a guard a task reachable via paths of different
    lengths would be re-fetched and re-traversed at every level it appears.
    ``_seen`` (internal) carries the ids already processed so each task is
    visited at most once.

    Returns a dict of task_id -> task name.
    """
    seen = set() if _seen is None else _seen
    seen.update(task_ids)

    upstream_tasks = {}
    task_defs = get_task_definitions(task_ids)
    dependencies = set()
    for task_id, task_def in task_defs.items():
        metadata = task_def.get("metadata", {})  # type: ignore
        name = metadata.get("name")  # type: ignore
        if name:
            upstream_tasks[task_id] = name

        dependencies |= set(task_def.get("dependencies", []))

    # Only recurse into ids we have not already processed; names are
    # deterministic, so skipping revisits cannot change the result.
    dependencies -= seen
    if dependencies:
        upstream_tasks.update(_get_deps(tuple(dependencies), _seen=seen))

    return upstream_tasks
451479

@@ -464,18 +492,12 @@ def get_ancestors(task_ids: Union[list[str], str]) -> dict[str, str]:
464492
if isinstance(task_ids, str):
465493
task_ids = [task_ids]
466494

467-
for task_id in task_ids:
468-
try:
469-
task_def = get_task_definition(task_id)
470-
except taskcluster.TaskclusterRestFailure as e:
471-
# Task has most likely expired, which means it's no longer a
472-
# dependency for the purposes of this function.
473-
if e.status_code == 404:
474-
continue
475-
raise e
476-
477-
dependencies = task_def.get("dependencies", [])
478-
if dependencies:
479-
upstream_tasks.update(_get_deps(tuple(dependencies)))
495+
task_defs = get_task_definitions(task_ids)
496+
dependencies = set()
497+
498+
for task_id, task_def in task_defs.items():
499+
dependencies |= set(task_def.get("dependencies", []))
500+
if dependencies:
501+
upstream_tasks.update(_get_deps(tuple(dependencies)))
480502

481503
return copy.deepcopy(upstream_tasks)

test/test_util_taskcluster.py

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import datetime
6+
import json
67
from unittest.mock import MagicMock
78

89
import pytest
@@ -309,6 +310,7 @@ def test_get_task_url(root_url):
309310
def test_get_task_definition(responses, root_url):
310311
tid = "abc123"
311312
tc.get_taskcluster_client.cache_clear()
313+
tc._task_definitions_cache.cache.clear()
312314

313315
responses.get(
314316
f"{root_url}/api/queue/v1/task/{tid}",
@@ -318,6 +320,40 @@ def test_get_task_definition(responses, root_url):
318320
assert result == {"payload": "blah"}
319321

320322

323+
def test_get_task_definitions(responses, root_url):
    """Batched lookup returns all definitions and primes the shared cache."""
    task_ids = (
        "abc123",
        "def456",
    )
    tc.get_taskcluster_client.cache_clear()
    tc._task_definitions_cache.cache.clear()

    task_definitions = {
        "abc123": {"payload": "blah"},
        "def456": {"payload": "foobar"},
    }

    def tasks_callback(request):
        # Echo back a definition for every id the client asked for, in the
        # shape the queue's `tasks` endpoint uses.
        requested = json.loads(request.body)["taskIds"]
        body = {
            "tasks": [
                {"taskId": tid, "task": task_definitions[tid]}
                for tid in requested
            ]
        }
        return (200, [], json.dumps(body))

    responses.add_callback(
        responses.POST,
        f"{root_url}/api/queue/v1/tasks",
        callback=tasks_callback,
    )

    # The batch call resolves every requested definition...
    assert tc.get_task_definitions(task_ids) == task_definitions
    # ...and the single-task helper is served from the now-warm cache.
    assert tc.get_task_definition(task_ids[0]) == {"payload": "blah"}
356+
321357
def test_cancel_task(responses, root_url):
322358
tid = "abc123"
323359
tc.get_taskcluster_client.cache_clear()
@@ -487,8 +523,7 @@ def test_list_task_group_incomplete_tasks(responses, root_url):
487523

488524

489525
def test_get_ancestors(responses, root_url):
490-
tc.get_task_definition.cache_clear()
491-
tc._get_deps.cache_clear()
526+
tc._task_definitions_cache.cache.clear()
492527
tc.get_taskcluster_client.cache_clear()
493528

494529
task_definitions = {
@@ -518,12 +553,20 @@ def test_get_ancestors(responses, root_url):
518553
},
519554
}
520555

521-
# Mock API responses for each task definition
522-
for task_id, definition in task_definitions.items():
523-
responses.get(
524-
f"{root_url}/api/queue/v1/task/{task_id}",
525-
json=definition,
526-
)
556+
# Mock API response for task definitions
557+
def tasks_callback(request):
558+
payload = json.loads(request.body)
559+
resp_body = {
560+
"tasks": [
561+
{"taskId": task_id, "task": task_definitions[task_id]}
562+
for task_id in payload["taskIds"]
563+
]
564+
}
565+
return (200, [], json.dumps(resp_body))
566+
567+
responses.add_callback(
568+
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
569+
)
527570

528571
got = tc.get_ancestors(["bbb", "fff"])
529572
expected = {
@@ -536,8 +579,7 @@ def test_get_ancestors(responses, root_url):
536579

537580

538581
def test_get_ancestors_string(responses, root_url):
539-
tc.get_task_definition.cache_clear()
540-
tc._get_deps.cache_clear()
582+
tc._task_definitions_cache.cache.clear()
541583
tc.get_taskcluster_client.cache_clear()
542584

543585
task_definitions = {
@@ -567,12 +609,20 @@ def test_get_ancestors_string(responses, root_url):
567609
},
568610
}
569611

570-
# Mock API responses for each task definition
571-
for task_id, definition in task_definitions.items():
572-
responses.get(
573-
f"{root_url}/api/queue/v1/task/{task_id}",
574-
json=definition,
575-
)
612+
# Mock API response for task definitions
613+
def tasks_callback(request):
614+
payload = json.loads(request.body)
615+
resp_body = {
616+
"tasks": [
617+
{"taskId": task_id, "task": task_definitions[task_id]}
618+
for task_id in payload["taskIds"]
619+
]
620+
}
621+
return (200, [], json.dumps(resp_body))
622+
623+
responses.add_callback(
624+
responses.POST, f"{root_url}/api/queue/v1/tasks", callback=tasks_callback
625+
)
576626

577627
got = tc.get_ancestors("fff")
578628
expected = {

0 commit comments

Comments
 (0)