|
14 | 14 | import aioboto3 |
15 | 15 | import boto3 |
16 | 16 | import botocore |
| 17 | +from asgiref.sync import async_to_sync |
17 | 18 | from django.conf import settings |
18 | 19 | from django.core.exceptions import SuspiciousFileOperation, ValidationError |
19 | 20 | from django.db import transaction |
@@ -192,16 +193,14 @@ def __init__( |
192 | 193 | def provision(self, *, input_civs, input_prefixes): |
193 | 194 | # We cannot run everything async as it requires database access. |
194 | 195 | # So first we gather the definitions of the async tasks that |
195 | | - # need to be run, then execute them in a new asyncio loop. |
| 196 | + # need to be run, then execute them in the event loop for |
| 197 | + # the current thread using @async_to_sync. |
196 | 198 | provisioning_task_definitions = ( |
197 | 199 | self._get_provisioning_task_definitions( |
198 | 200 | input_civs=input_civs, input_prefixes=input_prefixes |
199 | 201 | ) |
200 | 202 | ) |
201 | | - |
202 | | - asyncio.run( |
203 | | - self._provision(task_definitions=provisioning_task_definitions) |
204 | | - ) |
| 203 | + self._provision(task_definitions=provisioning_task_definitions) |
205 | 204 |
|
206 | 205 | @abstractmethod |
207 | 206 | def execute(self): ... |
@@ -410,24 +409,20 @@ def _get_key_and_relative_path(self, *, civ, input_prefixes): |
410 | 409 |
|
411 | 410 | return key, relative_path |
412 | 411 |
|
| 412 | + @async_to_sync |
413 | 413 | async def _provision(self, *, task_definitions): |
414 | 414 | semaphore = asyncio.Semaphore(CONCURRENCY) |
415 | 415 | session = aioboto3.Session() |
416 | 416 |
|
417 | | - provisioning_tasks = set() |
418 | | - |
419 | | - for task_definition in task_definitions: |
420 | | - aio_task = asyncio.create_task( |
421 | | - task_definition.method( |
422 | | - **task_definition.kwargs, |
423 | | - semaphore=semaphore, |
424 | | - session=session, |
| 417 | + async with asyncio.TaskGroup() as task_group: |
| 418 | + for task_definition in task_definitions: |
| 419 | + task_group.create_task( |
| 420 | + task_definition.method( |
| 421 | + **task_definition.kwargs, |
| 422 | + semaphore=semaphore, |
| 423 | + session=session, |
| 424 | + ) |
425 | 425 | ) |
426 | | - ) |
427 | | - provisioning_tasks.add(aio_task) |
428 | | - aio_task.add_done_callback(provisioning_tasks.discard) |
429 | | - |
430 | | - await asyncio.gather(*provisioning_tasks) |
431 | 426 |
|
432 | 427 | def _get_provisioning_task_definitions( |
433 | 428 | self, *, input_civs, input_prefixes |
|
0 commit comments