Skip to content

Commit 343b3af

Browse files
committed
Cleanup
1 parent fb4fc2d commit 343b3af

1 file changed

Lines changed: 73 additions & 93 deletions

File tree

  • app/grandchallenge/components/backends

app/grandchallenge/components/backends/base.py

Lines changed: 73 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import functools
23
import io
34
import json
45
import logging
@@ -15,6 +16,7 @@
1516
import boto3
1617
import botocore
1718
from asgiref.sync import async_to_sync
19+
from botocore.config import Config
1820
from django.conf import settings
1921
from django.core.exceptions import SuspiciousFileOperation, ValidationError
2022
from django.db import transaction
@@ -44,7 +46,9 @@
4446
logger = logging.getLogger(__name__)
4547

4648
MAX_SPOOL_SIZE = 1_000_000_000 # 1GB
47-
CONCURRENCY = 128
49+
50+
CONCURRENCY = 50
51+
BOTO_CONFIG = Config(max_pool_connections=120)
4852

4953

5054
class JobParams(NamedTuple):
@@ -124,41 +128,34 @@ def list_and_delete_objects_from_prefix(*, s3_client, bucket, prefix):
124128
)
125129

126130

127-
class AsyncTaskDefinition(NamedTuple):
128-
# semaphore and session (aioboto3) will get added to
129-
# the kwargs by the caller
130-
method: callable
131-
kwargs: dict
132-
133-
134131
async def s3_copy(
135-
*, source_bucket, source_key, target_bucket, target_key, semaphore, session
132+
*,
133+
source_bucket,
134+
source_key,
135+
target_bucket,
136+
target_key,
137+
semaphore,
138+
s3_client,
136139
):
137140
async with semaphore:
138-
async with session.client(
139-
"s3", endpoint_url=settings.AWS_S3_ENDPOINT_URL
140-
) as s3_client:
141-
await s3_client.copy(
142-
CopySource={"Bucket": source_bucket, "Key": source_key},
143-
Bucket=target_bucket,
144-
Key=target_key,
145-
)
141+
await s3_client.copy(
142+
CopySource={"Bucket": source_bucket, "Key": source_key},
143+
Bucket=target_bucket,
144+
Key=target_key,
145+
)
146146

147147

148-
async def s3_upload_content(*, content, bucket, key, semaphore, session):
148+
async def s3_upload_content(*, content, bucket, key, semaphore, s3_client):
149149
async with semaphore:
150-
async with session.client(
151-
"s3", endpoint_url=settings.AWS_S3_ENDPOINT_URL
152-
) as s3_client:
153-
with io.BytesIO() as f:
154-
f.write(content)
155-
f.seek(0)
150+
with io.BytesIO() as f:
151+
f.write(content)
152+
f.seek(0)
156153

157-
await s3_client.upload_fileobj(
158-
Fileobj=f,
159-
Bucket=bucket,
160-
Key=key,
161-
)
154+
await s3_client.upload_fileobj(
155+
Fileobj=f,
156+
Bucket=bucket,
157+
Key=key,
158+
)
162159

163160

164161
class Executor(ABC):
@@ -192,15 +189,13 @@ def __init__(
192189

193190
def provision(self, *, input_civs, input_prefixes):
194191
# We cannot run everything async as it requires database access.
195-
# So first we gather the definitions of the async tasks that
196-
# need to be run, then execute them in the event loop for
197-
# the current thread using @async_to_sync.
198-
provisioning_task_definitions = (
199-
self._get_provisioning_task_definitions(
200-
input_civs=input_civs, input_prefixes=input_prefixes
201-
)
192+
# So first we gather the async tasks that need to be run,
193+
# then execute them in the event loop for the current thread
194+
# using a method wrapped in @async_to_sync.
195+
provisioning_tasks = self._get_provisioning_tasks(
196+
input_civs=input_civs, input_prefixes=input_prefixes
202197
)
203-
self._provision(task_definitions=provisioning_task_definitions)
198+
self._provision(tasks=provisioning_tasks)
204199

205200
@abstractmethod
206201
def execute(self): ...
@@ -410,40 +405,32 @@ def _get_key_and_relative_path(self, *, civ, input_prefixes):
410405
return key, relative_path
411406

412407
@async_to_sync
413-
async def _provision(self, *, task_definitions):
408+
async def _provision(self, *, tasks):
414409
semaphore = asyncio.Semaphore(CONCURRENCY)
415410
session = aioboto3.Session()
416411

417-
async with asyncio.TaskGroup() as task_group:
418-
for task_definition in task_definitions:
419-
task_group.create_task(
420-
task_definition.method(
421-
**task_definition.kwargs,
422-
semaphore=semaphore,
423-
session=session,
412+
async with session.client(
413+
"s3", endpoint_url=settings.AWS_S3_ENDPOINT_URL, config=BOTO_CONFIG
414+
) as s3_client:
415+
async with asyncio.TaskGroup() as task_group:
416+
for task in tasks:
417+
task_group.create_task(
418+
task(
419+
semaphore=semaphore,
420+
s3_client=s3_client,
421+
)
424422
)
425-
)
426423

427-
def _get_provisioning_task_definitions(
428-
self, *, input_civs, input_prefixes
429-
):
430-
input_provisioning_task_definitions = (
431-
self._get_input_provisioning_task_definitions(
432-
input_civs=input_civs, input_prefixes=input_prefixes
433-
)
434-
)
435-
auxiliary_data_provisioning_task_definitions = (
436-
self._get_auxiliary_data_provisioning_task_definitions()
424+
def _get_provisioning_tasks(self, *, input_civs, input_prefixes):
425+
input_provisioning_tasks = self._get_input_provisioning_tasks(
426+
input_civs=input_civs, input_prefixes=input_prefixes
437427
)
438428

439429
return (
440-
input_provisioning_task_definitions
441-
+ auxiliary_data_provisioning_task_definitions
430+
input_provisioning_tasks + self._auxiliary_data_provisioning_tasks
442431
)
443432

444-
def _get_input_provisioning_task_definitions(
445-
self, *, input_civs, input_prefixes
446-
):
433+
def _get_input_provisioning_tasks(self, *, input_civs, input_prefixes):
447434
invocation_inputs = []
448435

449436
tasks = []
@@ -453,9 +440,7 @@ def _get_input_provisioning_task_definitions(
453440
civ=civ, input_prefixes=input_prefixes
454441
)
455442

456-
tasks.append(
457-
self._get_civ_input_provisioning_task_definition(civ, key)
458-
)
443+
tasks.append(self._get_civ_input_provisioning_task(civ, key))
459444

460445
invocation_inputs.append(
461446
{
@@ -467,24 +452,24 @@ def _get_input_provisioning_task_definitions(
467452
)
468453

469454
tasks.append(
470-
self._get_create_invocation_json_task_definition(
455+
self._get_create_invocation_json_task(
471456
invocation_inputs=invocation_inputs
472457
)
473458
)
474459

475460
return tasks
476461

477-
def _get_civ_input_provisioning_task_definition(self, civ, key):
462+
def _get_civ_input_provisioning_task(self, civ, key):
478463
if civ.interface.super_kind == civ.interface.SuperKind.IMAGE:
479-
return self._get_copy_input_object_task_definition(
464+
return self._get_copy_input_object_task(
480465
src=civ.image_file, target_key=key
481466
)
482467
elif civ.interface.super_kind == civ.interface.SuperKind.FILE:
483-
return self._get_copy_input_object_task_definition(
468+
return self._get_copy_input_object_task(
484469
src=civ.file, target_key=key
485470
)
486471
elif civ.interface.super_kind == civ.interface.SuperKind.VALUE:
487-
return self._get_upload_input_content_task_definition(
472+
return self._get_upload_input_content_task(
488473
content=json.dumps(civ.value).encode("utf-8"),
489474
key=key,
490475
)
@@ -493,10 +478,8 @@ def _get_civ_input_provisioning_task_definition(self, civ, key):
493478
f"Unknown interface super kind: {civ.interface.super_kind}"
494479
)
495480

496-
def _get_create_invocation_json_task_definition(
497-
self, *, invocation_inputs
498-
):
499-
return self._get_upload_input_content_task_definition(
481+
def _get_create_invocation_json_task(self, *, invocation_inputs):
482+
return self._get_upload_input_content_task(
500483
content=json.dumps(
501484
[
502485
{
@@ -510,20 +493,21 @@ def _get_create_invocation_json_task_definition(
510493
key=self._invocation_key,
511494
)
512495

513-
def _get_auxiliary_data_provisioning_task_definitions(self):
496+
@property
497+
def _auxiliary_data_provisioning_tasks(self):
514498
tasks = []
515499

516500
if self._algorithm_model:
517501
tasks.append(
518-
self._get_copy_input_object_task_definition(
502+
self._get_copy_input_object_task(
519503
src=self._algorithm_model,
520504
target_key=self._algorithm_model_key,
521505
)
522506
)
523507

524508
if self._ground_truth:
525509
tasks.append(
526-
self._get_copy_input_object_task_definition(
510+
self._get_copy_input_object_task(
527511
src=self._ground_truth, target_key=self._ground_truth_key
528512
)
529513
)
@@ -547,26 +531,22 @@ def _with_inputs_json(self, *, input_civs):
547531
)
548532

549533
@staticmethod
550-
def _get_copy_input_object_task_definition(*, src, target_key):
551-
return AsyncTaskDefinition(
552-
method=s3_copy,
553-
kwargs={
554-
"source_bucket": src.storage.bucket.name,
555-
"source_key": src.name,
556-
"target_bucket": settings.COMPONENTS_INPUT_BUCKET_NAME,
557-
"target_key": target_key,
558-
},
534+
def _get_copy_input_object_task(*, src, target_key):
535+
return functools.partial(
536+
s3_copy,
537+
source_bucket=src.storage.bucket.name,
538+
source_key=src.name,
539+
target_bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
540+
target_key=target_key,
559541
)
560542

561543
@staticmethod
562-
def _get_upload_input_content_task_definition(*, content, key):
563-
return AsyncTaskDefinition(
564-
method=s3_upload_content,
565-
kwargs={
566-
"content": content,
567-
"bucket": settings.COMPONENTS_INPUT_BUCKET_NAME,
568-
"key": key,
569-
},
544+
def _get_upload_input_content_task(*, content, key):
545+
return functools.partial(
546+
s3_upload_content,
547+
content=content,
548+
bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
549+
key=key,
570550
)
571551

572552
def _get_task_return_code(self):

0 commit comments

Comments
 (0)