1+ import asyncio
2+ import multiprocessing
3+ import logging
4+
5+ logger = logging .getLogger (__name__ )
6+
7+
8+ class AsyncTaskQueue :
9+ poll_interval = 0.1
10+
11+ def __init__ (self , num_workers = None ):
12+ if num_workers is None :
13+ num_workers = multiprocessing .cpu_count ()
14+
15+ self .tasks = []
16+ self .tasks_todo = []
17+ self .results = []
18+ self .num_workers = num_workers
19+
20+ @property
21+ def is_full (self ):
22+ return len (self .tasks ) >= self .num_workers
23+
24+ def start_task (self , coroutine ):
25+ task = asyncio .create_task (coroutine )
26+ self .tasks .append (task )
27+ task .add_done_callback (self .done_callback )
28+
29+ async def add (self , coroutine , wait = True ):
30+ if wait :
31+ while self .is_full :
32+ await asyncio .sleep (self .poll_interval )
33+ elif self .is_full :
34+ # store for later (end of another task)
35+ self .tasks_todo .append (coroutine )
36+ return
37+
38+ self .start_task (coroutine )
39+
40+ def done_callback (self , task ):
41+ logger .info (f"finished { task } " )
42+
43+ self .tasks .remove (task )
44+ self .results .append (task .result ())
45+
46+ if len (self .tasks_todo ):
47+ coroutine = self .tasks_todo .pop ()
48+ self .start_task (coroutine )
49+
50+ async def finish (self ):
51+ while len (self .tasks ) > 0 :
52+ await asyncio .sleep (self .num_workers )
53+
54+ return self .results
0 commit comments