Skip to content

Commit 18f11ed

Browse files
authored
Async provisioning of component inputs (#4342)
See DIAGNijmegen/rse-roadmap#436
1 parent 22c31bb commit 18f11ed

6 files changed

Lines changed: 324 additions & 183 deletions

File tree

app/grandchallenge/components/backends/base.py

Lines changed: 151 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import functools
13
import io
24
import json
35
import logging
@@ -10,8 +12,11 @@
1012
from typing import NamedTuple
1113
from uuid import UUID
1214

15+
import aioboto3
1316
import boto3
1417
import botocore
18+
from asgiref.sync import async_to_sync
19+
from botocore.config import Config
1520
from django.conf import settings
1621
from django.core.exceptions import SuspiciousFileOperation, ValidationError
1722
from django.db import transaction
@@ -42,6 +47,9 @@
4247

4348
MAX_SPOOL_SIZE = 1_000_000_000 # 1GB
4449

50+
# Size of the asyncio.Semaphore created in `_provision`; each S3 copy/upload
# task holds the semaphore while it runs, so this caps the operations in flight.
CONCURRENCY = 50
51+
# botocore client configuration handed to the aioboto3 S3 client in
# `_provision`; the connection pool (120) is larger than CONCURRENCY (50).
BOTO_CONFIG = Config(max_pool_connections=120)
52+
4553

4654
class JobParams(NamedTuple):
4755
app_label: str
@@ -120,6 +128,36 @@ def list_and_delete_objects_from_prefix(*, s3_client, bucket, prefix):
120128
)
121129

122130

131+
async def s3_copy(
    *,
    source_bucket,
    source_key,
    target_bucket,
    target_key,
    semaphore,
    s3_client,
):
    """Copy a single S3 object to a new bucket/key while holding ``semaphore``.

    Args:
        source_bucket: Name of the bucket that contains the object to copy.
        source_key: Key of the object to copy.
        target_bucket: Name of the bucket to copy the object into.
        target_key: Key under which the copy is stored.
        semaphore: Async context manager that is held for the whole copy,
            so callers can bound how many copies run at once.
        s3_client: Object exposing an awaitable ``copy(CopySource=...,
            Bucket=..., Key=...)`` method.

    Returns:
        None. Any exception raised by ``s3_client.copy`` propagates.
    """
    async with semaphore:
        await s3_client.copy(
            CopySource={"Bucket": source_bucket, "Key": source_key},
            Bucket=target_bucket,
            Key=target_key,
        )
146+
147+
148+
async def s3_upload_content(*, content, bucket, key, semaphore, s3_client):
    """Upload in-memory bytes to ``bucket``/``key`` while holding ``semaphore``.

    Args:
        content: The bytes to upload.
        bucket: Name of the destination bucket.
        key: Destination object key.
        semaphore: Async context manager that is held for the whole upload,
            so callers can bound how many uploads run at once.
        s3_client: Object exposing an awaitable ``upload_fileobj(Fileobj=...,
            Bucket=..., Key=...)`` method.

    Returns:
        None. Any exception raised by ``s3_client.upload_fileobj`` propagates.
    """
    async with semaphore:
        # BytesIO(content) starts positioned at offset 0, so no explicit
        # write()/seek(0) round trip is needed before handing it to the client.
        with io.BytesIO(content) as f:
            await s3_client.upload_fileobj(
                Fileobj=f,
                Bucket=bucket,
                Key=key,
            )
159+
160+
123161
class Executor(ABC):
124162
def __init__(
125163
self,
@@ -150,10 +188,14 @@ def __init__(
150188
self._ground_truth = ground_truth
151189

152190
def provision(self, *, input_civs, input_prefixes):
153-
self._provision_inputs(
191+
# We cannot run everything async as it requires database access.
192+
# So first we gather the async tasks that need to be run,
193+
# then execute them in the event loop for the current thread
194+
# using a method wrapped in @async_to_sync.
195+
provisioning_tasks = self._get_provisioning_tasks(
154196
input_civs=input_civs, input_prefixes=input_prefixes
155197
)
156-
self._provision_auxilliary_data()
198+
self._provision(tasks=provisioning_tasks)
157199

158200
@abstractmethod
159201
def execute(self): ...
@@ -362,27 +404,43 @@ def _get_key_and_relative_path(self, *, civ, input_prefixes):
362404

363405
return key, relative_path
364406

365-
def _provision_inputs(self, *, input_civs, input_prefixes):
407+
@async_to_sync
async def _provision(self, *, tasks):
    """Run every provisioning task concurrently against one shared S3 client.

    Wrapped in ``async_to_sync`` so that synchronous callers can invoke it
    like a normal method.

    Args:
        tasks: Iterable of callables. Each is called with ``semaphore`` and
            ``s3_client`` keyword arguments and must return an awaitable
            (for example the ``functools.partial`` objects built around
            ``s3_copy`` / ``s3_upload_content``).
    """
    # One semaphore shared by all tasks bounds the number of S3 operations
    # that are in flight at the same time.
    semaphore = asyncio.Semaphore(CONCURRENCY)
    session = aioboto3.Session()

    async with session.client(
        "s3", endpoint_url=settings.AWS_S3_ENDPOINT_URL, config=BOTO_CONFIG
    ) as s3_client:
        # Leaving the TaskGroup block waits for every scheduled task.
        async with asyncio.TaskGroup() as task_group:
            for task in tasks:
                task_group.create_task(
                    task(semaphore=semaphore, s3_client=s3_client)
                )
423+
424+
def _get_provisioning_tasks(self, *, input_civs, input_prefixes):
    """Collect every task needed to provision this job.

    Args:
        input_civs: Passed through to ``_get_input_provisioning_tasks``.
        input_prefixes: Passed through to ``_get_input_provisioning_tasks``.

    Returns:
        The input provisioning tasks followed by the auxiliary data
        provisioning tasks, as a single list.
    """
    input_provisioning_tasks = self._get_input_provisioning_tasks(
        input_civs=input_civs, input_prefixes=input_prefixes
    )

    return input_provisioning_tasks + self._auxiliary_data_provisioning_tasks
432+
433+
def _get_input_provisioning_tasks(self, *, input_civs, input_prefixes):
366434
invocation_inputs = []
367435

436+
tasks = []
437+
368438
for civ in self._with_inputs_json(input_civs=input_civs):
369439
key, relative_path = self._get_key_and_relative_path(
370440
civ=civ, input_prefixes=input_prefixes
371441
)
372442

373-
if civ.image:
374-
self._copy_input_file(src=civ.image_file, dest_key=key)
375-
elif civ.file:
376-
self._copy_input_file(src=civ.file, dest_key=key)
377-
else:
378-
with io.BytesIO() as f:
379-
f.write(json.dumps(civ.value).encode("utf-8"))
380-
f.seek(0)
381-
self._s3_client.upload_fileobj(
382-
Fileobj=f,
383-
Bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
384-
Key=key,
385-
)
443+
tasks.append(self._get_civ_input_provisioning_task(civ, key))
386444

387445
invocation_inputs.append(
388446
{
@@ -393,7 +451,68 @@ def _provision_inputs(self, *, input_civs, input_prefixes):
393451
}
394452
)
395453

396-
self._create_invocation_json(inputs=invocation_inputs)
454+
tasks.append(
455+
self._get_create_invocation_json_task(
456+
invocation_inputs=invocation_inputs
457+
)
458+
)
459+
460+
return tasks
461+
462+
def _get_civ_input_provisioning_task(self, civ, key):
    """Build the provisioning task for one component interface value.

    Args:
        civ: The value to provision; its ``interface.super_kind`` selects
            how the data is transferred.
        key: Object key the input is written to.

    Returns:
        A copy task for ``civ.image_file`` (IMAGE) or ``civ.file`` (FILE),
        or an upload task carrying ``civ.value`` serialised as UTF-8 JSON
        (VALUE).

    Raises:
        NotImplementedError: If the super kind is none of IMAGE, FILE or
            VALUE.
    """
    interface = civ.interface
    super_kind = interface.super_kind

    if super_kind == interface.SuperKind.IMAGE:
        return self._get_copy_input_object_task(
            src=civ.image_file, target_key=key
        )
    elif super_kind == interface.SuperKind.FILE:
        return self._get_copy_input_object_task(src=civ.file, target_key=key)
    elif super_kind == interface.SuperKind.VALUE:
        return self._get_upload_input_content_task(
            content=json.dumps(civ.value).encode("utf-8"),
            key=key,
        )
    else:
        raise NotImplementedError(
            f"Unknown interface super kind: {civ.interface.super_kind}"
        )
480+
481+
def _get_create_invocation_json_task(self, *, invocation_inputs):
    """Build the task that uploads the invocation JSON document.

    Args:
        invocation_inputs: JSON-serialisable description of the inputs,
            stored under the ``"inputs"`` key.

    Returns:
        An upload task whose content is a UTF-8 encoded JSON list holding a
        single object with the job pk, the inputs, the output bucket name
        and the output prefix, targeting ``self._invocation_key``.
    """
    invocation = [
        {
            "pk": self._job_id,
            "inputs": invocation_inputs,
            "output_bucket_name": settings.COMPONENTS_OUTPUT_BUCKET_NAME,
            "output_prefix": self._io_prefix,
        }
    ]

    return self._get_upload_input_content_task(
        content=json.dumps(invocation).encode("utf-8"),
        key=self._invocation_key,
    )
495+
496+
@property
def _auxiliary_data_provisioning_tasks(self):
    """Copy tasks for the optional algorithm model and ground truth.

    Returns:
        A new list containing a copy task for ``self._algorithm_model``
        (to ``self._algorithm_model_key``) and/or ``self._ground_truth``
        (to ``self._ground_truth_key``), one for each that is truthy.
        The list is empty when neither is set.
    """
    tasks = []

    if self._algorithm_model:
        tasks.append(
            self._get_copy_input_object_task(
                src=self._algorithm_model,
                target_key=self._algorithm_model_key,
            )
        )

    if self._ground_truth:
        tasks.append(
            self._get_copy_input_object_task(
                src=self._ground_truth, target_key=self._ground_truth_key
            )
        )

    return tasks
397516

398517
def _with_inputs_json(self, *, input_civs):
399518
"""
@@ -411,38 +530,23 @@ def _with_inputs_json(self, *, input_civs):
411530
),
412531
)
413532

414-
def _create_invocation_json(self, *, inputs):
415-
f = io.BytesIO(
416-
json.dumps(
417-
[
418-
{
419-
"pk": self._job_id,
420-
"inputs": inputs,
421-
"output_bucket_name": settings.COMPONENTS_OUTPUT_BUCKET_NAME,
422-
"output_prefix": self._io_prefix,
423-
}
424-
]
425-
).encode("utf-8")
426-
)
427-
self._s3_client.upload_fileobj(
428-
f, settings.COMPONENTS_INPUT_BUCKET_NAME, self._invocation_key
533+
@staticmethod
def _get_copy_input_object_task(*, src, target_key):
    """Build a deferred ``s3_copy`` call into the components input bucket.

    Args:
        src: Stored file to copy; ``src.storage.bucket.name`` and
            ``src.name`` supply the source bucket and key.
        target_key: Key to copy to in
            ``settings.COMPONENTS_INPUT_BUCKET_NAME``.

    Returns:
        A ``functools.partial`` of ``s3_copy`` that still needs the
        ``semaphore`` and ``s3_client`` keyword arguments.
    """
    return functools.partial(
        s3_copy,
        source_bucket=src.storage.bucket.name,
        source_key=src.name,
        target_bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
        target_key=target_key,
    )
430542

431-
def _provision_auxilliary_data(self):
432-
if self._algorithm_model:
433-
self._copy_input_file(
434-
src=self._algorithm_model, dest_key=self._algorithm_model_key
435-
)
436-
if self._ground_truth:
437-
self._copy_input_file(
438-
src=self._ground_truth, dest_key=self._ground_truth_key
439-
)
440-
441-
def _copy_input_file(self, *, src, dest_key):
442-
self._s3_client.copy(
443-
CopySource={"Bucket": src.storage.bucket.name, "Key": src.name},
444-
Bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
445-
Key=dest_key,
543+
@staticmethod
def _get_upload_input_content_task(*, content, key):
    """Build a deferred ``s3_upload_content`` call into the input bucket.

    Args:
        content: Bytes to upload.
        key: Key to upload to in ``settings.COMPONENTS_INPUT_BUCKET_NAME``.

    Returns:
        A ``functools.partial`` of ``s3_upload_content`` that still needs
        the ``semaphore`` and ``s3_client`` keyword arguments.
    """
    return functools.partial(
        s3_upload_content,
        content=content,
        bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
        key=key,
    )
447551

448552
def _get_task_return_code(self):

app/tests/components_tests/test_backends.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_filter_members,
1212
user_error,
1313
)
14+
from grandchallenge.components.models import InterfaceKindChoices
1415
from grandchallenge.components.schemas import GPUTypeChoices
1516
from tests.components_tests.factories import ComponentInterfaceValueFactory
1617
from tests.components_tests.resources.backends import IOCopyExecutor
@@ -128,7 +129,9 @@ def test_inputs_json(settings):
128129
use_warm_pool=False,
129130
)
130131

131-
civ1, civ2 = ComponentInterfaceValueFactory.create_batch(2)
132+
civ1, civ2 = ComponentInterfaceValueFactory.create_batch(
133+
2, interface__kind=InterfaceKindChoices.ANY
134+
)
132135

133136
executor.provision(input_civs=[civ1, civ2], input_prefixes={})
134137

app/tests/conftest.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import warnings
32
import zipfile
43
from collections import namedtuple
54
from pathlib import Path
@@ -84,15 +83,6 @@ def django_db_setup(django_db_setup, django_db_blocker):
8483
site.save()
8584

8685

87-
def pytest_itemcollected(item):
88-
if item.get_closest_marker("playwright") is not None:
89-
# See https://github.com/microsoft/playwright-pytest/issues/29
90-
warnings.warn( # noqa: B028
91-
"Setting DJANGO_ALLOW_ASYNC_UNSAFE for playwright support"
92-
)
93-
os.environ["DJANGO_ALLOW_ASYNC_UNSAFE"] = "true"
94-
95-
9686
class ChallengeSet(NamedTuple):
9787
challenge: ChallengeFactory
9888
creator: UserFactory

0 commit comments

Comments
 (0)