11import asyncio
2+ import functools
23import io
34import json
45import logging
1516import boto3
1617import botocore
1718from asgiref .sync import async_to_sync
19+ from botocore .config import Config
1820from django .conf import settings
1921from django .core .exceptions import SuspiciousFileOperation , ValidationError
2022from django .db import transaction
4446logger = logging .getLogger (__name__ )
4547
MAX_SPOOL_SIZE = 1_000_000_000  # 1GB

# Upper bound on the number of S3 operations in flight at once; it sizes
# the asyncio.Semaphore that the provisioning tasks share.
CONCURRENCY = 50
# Configuration for the S3 client shared by the provisioning tasks. The
# connection pool is larger than CONCURRENCY -- presumably because a single
# managed copy/upload may use more than one connection (TODO confirm).
BOTO_CONFIG = Config(max_pool_connections=120)
4852
4953
5054class JobParams (NamedTuple ):
@@ -124,41 +128,34 @@ def list_and_delete_objects_from_prefix(*, s3_client, bucket, prefix):
124128 )
125129
126130
127- class AsyncTaskDefinition (NamedTuple ):
128- # semaphore and session (aioboto3) will get added to
129- # the kwargs by the caller
130- method : callable
131- kwargs : dict
132-
133-
async def s3_copy(
    *,
    source_bucket,
    source_key,
    target_bucket,
    target_key,
    semaphore,
    s3_client,
):
    """Copy one S3 object to another bucket/key.

    The copy is awaited while holding ``semaphore``, so the caller
    controls how many of these coroutines run at the same time.

    Parameters
    ----------
    source_bucket, source_key
        Location of the object to copy.
    target_bucket, target_key
        Destination of the copy.
    semaphore
        Async context manager used to limit concurrency.
    s3_client
        Client object exposing an awaitable ``copy`` method; it is
        supplied by the caller rather than created per call.
    """
    async with semaphore:
        await s3_client.copy(
            CopySource={"Bucket": source_bucket, "Key": source_key},
            Bucket=target_bucket,
            Key=target_key,
        )
146146
147147
async def s3_upload_content(*, content, bucket, key, semaphore, s3_client):
    """Upload in-memory bytes to ``bucket``/``key``.

    The upload is awaited while holding ``semaphore``, so the caller
    controls how many of these coroutines run at the same time.

    Parameters
    ----------
    content
        Bytes-like payload to upload.
    bucket, key
        Destination of the upload.
    semaphore
        Async context manager used to limit concurrency.
    s3_client
        Client object exposing an awaitable ``upload_fileobj`` method; it
        is supplied by the caller rather than created per call.
    """
    async with semaphore:
        # BytesIO(content) starts positioned at offset 0, which is the
        # same state as writing the content and seeking back to the start.
        with io.BytesIO(content) as f:
            await s3_client.upload_fileobj(
                Fileobj=f,
                Bucket=bucket,
                Key=key,
            )
162159
163160
164161class Executor (ABC ):
@@ -192,15 +189,13 @@ def __init__(
192189
def provision(self, *, input_civs, input_prefixes):
    """Collect the provisioning tasks and run them.

    ``input_civs`` and ``input_prefixes`` are forwarded unchanged to
    ``_get_provisioning_tasks``; the tasks it returns are handed to
    ``_provision``. Nothing is returned.
    """
    # We cannot run everything async as it requires database access.
    # So first we gather the async tasks that need to be run,
    # then execute them in the event loop for the current thread
    # using a method wrapped in @async_to_sync.
    provisioning_tasks = self._get_provisioning_tasks(
        input_civs=input_civs, input_prefixes=input_prefixes
    )
    self._provision(tasks=provisioning_tasks)
204199
@abstractmethod
def execute(self):
    """Abstract hook with no behaviour here; subclasses must override it."""
    ...
@@ -410,40 +405,32 @@ def _get_key_and_relative_path(self, *, civ, input_prefixes):
410405 return key , relative_path
411406
@async_to_sync
async def _provision(self, *, tasks):
    """Run all provisioning tasks concurrently against one S3 client.

    ``tasks`` is an iterable of callables. Each is called with the
    keyword arguments ``semaphore`` and ``s3_client`` and must return a
    coroutine, which is scheduled in a task group.

    A single client is opened for the whole batch and shared by every
    task; the semaphore, sized by ``CONCURRENCY``, is passed to each task
    so it can limit how many run at once.
    """
    semaphore = asyncio.Semaphore(CONCURRENCY)
    session = aioboto3.Session()

    async with session.client(
        "s3", endpoint_url=settings.AWS_S3_ENDPOINT_URL, config=BOTO_CONFIG
    ) as s3_client:
        # Leaving the task group waits for every scheduled task, so the
        # client is only closed once all of them have finished.
        async with asyncio.TaskGroup() as task_group:
            for task in tasks:
                task_group.create_task(
                    task(
                        semaphore=semaphore,
                        s3_client=s3_client,
                    )
                )
426423
427- def _get_provisioning_task_definitions (
428- self , * , input_civs , input_prefixes
429- ):
430- input_provisioning_task_definitions = (
431- self ._get_input_provisioning_task_definitions (
432- input_civs = input_civs , input_prefixes = input_prefixes
433- )
434- )
435- auxiliary_data_provisioning_task_definitions = (
436- self ._get_auxiliary_data_provisioning_task_definitions ()
424+ def _get_provisioning_tasks (self , * , input_civs , input_prefixes ):
425+ input_provisioning_tasks = self ._get_input_provisioning_tasks (
426+ input_civs = input_civs , input_prefixes = input_prefixes
437427 )
438428
439429 return (
440- input_provisioning_task_definitions
441- + auxiliary_data_provisioning_task_definitions
430+ input_provisioning_tasks + self ._auxiliary_data_provisioning_tasks
442431 )
443432
444- def _get_input_provisioning_task_definitions (
445- self , * , input_civs , input_prefixes
446- ):
433+ def _get_input_provisioning_tasks (self , * , input_civs , input_prefixes ):
447434 invocation_inputs = []
448435
449436 tasks = []
@@ -453,9 +440,7 @@ def _get_input_provisioning_task_definitions(
453440 civ = civ , input_prefixes = input_prefixes
454441 )
455442
456- tasks .append (
457- self ._get_civ_input_provisioning_task_definition (civ , key )
458- )
443+ tasks .append (self ._get_civ_input_provisioning_task (civ , key ))
459444
460445 invocation_inputs .append (
461446 {
@@ -467,24 +452,24 @@ def _get_input_provisioning_task_definitions(
467452 )
468453
469454 tasks .append (
470- self ._get_create_invocation_json_task_definition (
455+ self ._get_create_invocation_json_task (
471456 invocation_inputs = invocation_inputs
472457 )
473458 )
474459
475460 return tasks
476461
477- def _get_civ_input_provisioning_task_definition (self , civ , key ):
462+ def _get_civ_input_provisioning_task (self , civ , key ):
478463 if civ .interface .super_kind == civ .interface .SuperKind .IMAGE :
479- return self ._get_copy_input_object_task_definition (
464+ return self ._get_copy_input_object_task (
480465 src = civ .image_file , target_key = key
481466 )
482467 elif civ .interface .super_kind == civ .interface .SuperKind .FILE :
483- return self ._get_copy_input_object_task_definition (
468+ return self ._get_copy_input_object_task (
484469 src = civ .file , target_key = key
485470 )
486471 elif civ .interface .super_kind == civ .interface .SuperKind .VALUE :
487- return self ._get_upload_input_content_task_definition (
472+ return self ._get_upload_input_content_task (
488473 content = json .dumps (civ .value ).encode ("utf-8" ),
489474 key = key ,
490475 )
@@ -493,10 +478,8 @@ def _get_civ_input_provisioning_task_definition(self, civ, key):
493478 f"Unknown interface super kind: { civ .interface .super_kind } "
494479 )
495480
496- def _get_create_invocation_json_task_definition (
497- self , * , invocation_inputs
498- ):
499- return self ._get_upload_input_content_task_definition (
481+ def _get_create_invocation_json_task (self , * , invocation_inputs ):
482+ return self ._get_upload_input_content_task (
500483 content = json .dumps (
501484 [
502485 {
@@ -510,20 +493,21 @@ def _get_create_invocation_json_task_definition(
510493 key = self ._invocation_key ,
511494 )
512495
513- def _get_auxiliary_data_provisioning_task_definitions (self ):
496+ @property
497+ def _auxiliary_data_provisioning_tasks (self ):
514498 tasks = []
515499
516500 if self ._algorithm_model :
517501 tasks .append (
518- self ._get_copy_input_object_task_definition (
502+ self ._get_copy_input_object_task (
519503 src = self ._algorithm_model ,
520504 target_key = self ._algorithm_model_key ,
521505 )
522506 )
523507
524508 if self ._ground_truth :
525509 tasks .append (
526- self ._get_copy_input_object_task_definition (
510+ self ._get_copy_input_object_task (
527511 src = self ._ground_truth , target_key = self ._ground_truth_key
528512 )
529513 )
@@ -547,26 +531,22 @@ def _with_inputs_json(self, *, input_civs):
547531 )
548532
@staticmethod
def _get_copy_input_object_task(*, src, target_key):
    """Build a deferred ``s3_copy`` call for ``src``.

    The source bucket and key are read from ``src.storage.bucket.name``
    and ``src.name``; the target bucket is
    ``settings.COMPONENTS_INPUT_BUCKET_NAME``. The returned partial still
    needs the ``semaphore`` and ``s3_client`` keyword arguments before it
    yields a coroutine.
    """
    return functools.partial(
        s3_copy,
        source_bucket=src.storage.bucket.name,
        source_key=src.name,
        target_bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
        target_key=target_key,
    )
560542
@staticmethod
def _get_upload_input_content_task(*, content, key):
    """Build a deferred ``s3_upload_content`` call for ``content``.

    The target bucket is ``settings.COMPONENTS_INPUT_BUCKET_NAME``. The
    returned partial still needs the ``semaphore`` and ``s3_client``
    keyword arguments before it yields a coroutine.
    """
    return functools.partial(
        s3_upload_content,
        content=content,
        bucket=settings.COMPONENTS_INPUT_BUCKET_NAME,
        key=key,
    )
571551
572552 def _get_task_return_code (self ):
0 commit comments