1+ import asyncio
2+ import functools
13import io
24import json
35import logging
1012from typing import NamedTuple
1113from uuid import UUID
1214
15+ import aioboto3
1316import boto3
1417import botocore
18+ from asgiref .sync import async_to_sync
19+ from botocore .config import Config
1520from django .conf import settings
1621from django .core .exceptions import SuspiciousFileOperation , ValidationError
1722from django .db import transaction
4247
4348MAX_SPOOL_SIZE = 1_000_000_000 # 1GB
4449
50+ CONCURRENCY = 50
51+ BOTO_CONFIG = Config (max_pool_connections = 120 )
52+
4553
4654class JobParams (NamedTuple ):
4755 app_label : str
@@ -120,6 +128,36 @@ def list_and_delete_objects_from_prefix(*, s3_client, bucket, prefix):
120128 )
121129
122130
131+ async def s3_copy (
132+ * ,
133+ source_bucket ,
134+ source_key ,
135+ target_bucket ,
136+ target_key ,
137+ semaphore ,
138+ s3_client ,
139+ ):
140+ async with semaphore :
141+ await s3_client .copy (
142+ CopySource = {"Bucket" : source_bucket , "Key" : source_key },
143+ Bucket = target_bucket ,
144+ Key = target_key ,
145+ )
146+
147+
148+ async def s3_upload_content (* , content , bucket , key , semaphore , s3_client ):
149+ async with semaphore :
150+ with io .BytesIO () as f :
151+ f .write (content )
152+ f .seek (0 )
153+
154+ await s3_client .upload_fileobj (
155+ Fileobj = f ,
156+ Bucket = bucket ,
157+ Key = key ,
158+ )
159+
160+
123161class Executor (ABC ):
124162 def __init__ (
125163 self ,
@@ -150,10 +188,14 @@ def __init__(
150188 self ._ground_truth = ground_truth
151189
152190 def provision (self , * , input_civs , input_prefixes ):
153- self ._provision_inputs (
191+ # We cannot run everything async as it requires database access.
192+ # So first we gather the async tasks that need to be run,
193+ # then execute them in the event loop for the current thread
194+ # using a method wrapped in @async_to_sync.
195+ provisioning_tasks = self ._get_provisioning_tasks (
154196 input_civs = input_civs , input_prefixes = input_prefixes
155197 )
156- self ._provision_auxilliary_data ( )
198+ self ._provision ( tasks = provisioning_tasks )
157199
158200 @abstractmethod
159201 def execute (self ): ...
@@ -362,27 +404,43 @@ def _get_key_and_relative_path(self, *, civ, input_prefixes):
362404
363405 return key , relative_path
364406
365- def _provision_inputs (self , * , input_civs , input_prefixes ):
407+ @async_to_sync
408+ async def _provision (self , * , tasks ):
409+ semaphore = asyncio .Semaphore (CONCURRENCY )
410+ session = aioboto3 .Session ()
411+
412+ async with session .client (
413+ "s3" , endpoint_url = settings .AWS_S3_ENDPOINT_URL , config = BOTO_CONFIG
414+ ) as s3_client :
415+ async with asyncio .TaskGroup () as task_group :
416+ for task in tasks :
417+ task_group .create_task (
418+ task (
419+ semaphore = semaphore ,
420+ s3_client = s3_client ,
421+ )
422+ )
423+
424+ def _get_provisioning_tasks (self , * , input_civs , input_prefixes ):
425+ input_provisioning_tasks = self ._get_input_provisioning_tasks (
426+ input_civs = input_civs , input_prefixes = input_prefixes
427+ )
428+
429+ return (
430+ input_provisioning_tasks + self ._auxiliary_data_provisioning_tasks
431+ )
432+
433+ def _get_input_provisioning_tasks (self , * , input_civs , input_prefixes ):
366434 invocation_inputs = []
367435
436+ tasks = []
437+
368438 for civ in self ._with_inputs_json (input_civs = input_civs ):
369439 key , relative_path = self ._get_key_and_relative_path (
370440 civ = civ , input_prefixes = input_prefixes
371441 )
372442
373- if civ .image :
374- self ._copy_input_file (src = civ .image_file , dest_key = key )
375- elif civ .file :
376- self ._copy_input_file (src = civ .file , dest_key = key )
377- else :
378- with io .BytesIO () as f :
379- f .write (json .dumps (civ .value ).encode ("utf-8" ))
380- f .seek (0 )
381- self ._s3_client .upload_fileobj (
382- Fileobj = f ,
383- Bucket = settings .COMPONENTS_INPUT_BUCKET_NAME ,
384- Key = key ,
385- )
443+ tasks .append (self ._get_civ_input_provisioning_task (civ , key ))
386444
387445 invocation_inputs .append (
388446 {
@@ -393,7 +451,68 @@ def _provision_inputs(self, *, input_civs, input_prefixes):
393451 }
394452 )
395453
396- self ._create_invocation_json (inputs = invocation_inputs )
454+ tasks .append (
455+ self ._get_create_invocation_json_task (
456+ invocation_inputs = invocation_inputs
457+ )
458+ )
459+
460+ return tasks
461+
462+ def _get_civ_input_provisioning_task (self , civ , key ):
463+ if civ .interface .super_kind == civ .interface .SuperKind .IMAGE :
464+ return self ._get_copy_input_object_task (
465+ src = civ .image_file , target_key = key
466+ )
467+ elif civ .interface .super_kind == civ .interface .SuperKind .FILE :
468+ return self ._get_copy_input_object_task (
469+ src = civ .file , target_key = key
470+ )
471+ elif civ .interface .super_kind == civ .interface .SuperKind .VALUE :
472+ return self ._get_upload_input_content_task (
473+ content = json .dumps (civ .value ).encode ("utf-8" ),
474+ key = key ,
475+ )
476+ else :
477+ raise NotImplementedError (
478+ f"Unknown interface super kind: { civ .interface .super_kind } "
479+ )
480+
481+ def _get_create_invocation_json_task (self , * , invocation_inputs ):
482+ return self ._get_upload_input_content_task (
483+ content = json .dumps (
484+ [
485+ {
486+ "pk" : self ._job_id ,
487+ "inputs" : invocation_inputs ,
488+ "output_bucket_name" : settings .COMPONENTS_OUTPUT_BUCKET_NAME ,
489+ "output_prefix" : self ._io_prefix ,
490+ }
491+ ]
492+ ).encode ("utf-8" ),
493+ key = self ._invocation_key ,
494+ )
495+
496+ @property
497+ def _auxiliary_data_provisioning_tasks (self ):
498+ tasks = []
499+
500+ if self ._algorithm_model :
501+ tasks .append (
502+ self ._get_copy_input_object_task (
503+ src = self ._algorithm_model ,
504+ target_key = self ._algorithm_model_key ,
505+ )
506+ )
507+
508+ if self ._ground_truth :
509+ tasks .append (
510+ self ._get_copy_input_object_task (
511+ src = self ._ground_truth , target_key = self ._ground_truth_key
512+ )
513+ )
514+
515+ return tasks
397516
398517 def _with_inputs_json (self , * , input_civs ):
399518 """
@@ -411,38 +530,23 @@ def _with_inputs_json(self, *, input_civs):
411530 ),
412531 )
413532
414- def _create_invocation_json (self , * , inputs ):
415- f = io .BytesIO (
416- json .dumps (
417- [
418- {
419- "pk" : self ._job_id ,
420- "inputs" : inputs ,
421- "output_bucket_name" : settings .COMPONENTS_OUTPUT_BUCKET_NAME ,
422- "output_prefix" : self ._io_prefix ,
423- }
424- ]
425- ).encode ("utf-8" )
426- )
427- self ._s3_client .upload_fileobj (
428- f , settings .COMPONENTS_INPUT_BUCKET_NAME , self ._invocation_key
533+ @staticmethod
534+ def _get_copy_input_object_task (* , src , target_key ):
535+ return functools .partial (
536+ s3_copy ,
537+ source_bucket = src .storage .bucket .name ,
538+ source_key = src .name ,
539+ target_bucket = settings .COMPONENTS_INPUT_BUCKET_NAME ,
540+ target_key = target_key ,
429541 )
430542
431- def _provision_auxilliary_data (self ):
432- if self ._algorithm_model :
433- self ._copy_input_file (
434- src = self ._algorithm_model , dest_key = self ._algorithm_model_key
435- )
436- if self ._ground_truth :
437- self ._copy_input_file (
438- src = self ._ground_truth , dest_key = self ._ground_truth_key
439- )
440-
441- def _copy_input_file (self , * , src , dest_key ):
442- self ._s3_client .copy (
443- CopySource = {"Bucket" : src .storage .bucket .name , "Key" : src .name },
444- Bucket = settings .COMPONENTS_INPUT_BUCKET_NAME ,
445- Key = dest_key ,
543+ @staticmethod
544+ def _get_upload_input_content_task (* , content , key ):
545+ return functools .partial (
546+ s3_upload_content ,
547+ content = content ,
548+ bucket = settings .COMPONENTS_INPUT_BUCKET_NAME ,
549+ key = key ,
446550 )
447551
448552 def _get_task_return_code (self ):
0 commit comments