Skip to content

Commit 2c0c8ec

Browse files
authored
Merge pull request #79 from GabrielSalla/create-task-manager
Create task manager
2 parents bbf6601 + 9075cac commit 2c0c8ec

File tree

12 files changed

+576
-101
lines changed

12 files changed

+576
-101
lines changed

src/components/controller/controller.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import logging
33
import traceback
44
from datetime import datetime
5-
from typing import Any
5+
from typing import Any, Coroutine
66

77
import prometheus_client
88

9+
import components.task_manager as task_manager
910
import message_queue as message_queue
1011
import registry as registry
1112
import utils.app as app
@@ -115,7 +116,7 @@ async def _run_task(semaphore: asyncio.Semaphore, monitor: Monitor) -> None:
115116

116117
async def _create_process_task(
117118
semaphore: asyncio.Semaphore, monitor: Monitor
118-
) -> asyncio.Task[Any] | None:
119+
) -> Coroutine[Any, Any, Any] | None:
119120
"""Create a task to process the monitor"""
120121
# Instead of registering the monitor, skip if it's not registered yet
121122
# If processing a monitor that is not yet registered, the executor won't have
@@ -126,15 +127,20 @@ async def _create_process_task(
126127
return None
127128

128129
# Process monitors concurrently
129-
# Use '_run_task' to keep the semaphore lock while the monitor is being processed
130+
# Use '_run_task' to hold the semaphore while the monitor is being processed
130131
async with semaphore:
131-
return asyncio.create_task(_run_task(semaphore, monitor))
132+
return _run_task(semaphore, monitor)
132133

133134

134135
async def run() -> None:
135136
global last_loop_at
136137
global running
137138

139+
current_task = asyncio.current_task()
140+
if current_task is None:
141+
_logger.error("Could not get the current asyncio task, finishing")
142+
return
143+
138144
running = True
139145

140146
_logger.info("Controller running")
@@ -145,32 +151,27 @@ async def run() -> None:
145151
# Queue setup
146152
semaphore = asyncio.Semaphore(configs.controller_concurrency)
147153

148-
tasks: list[asyncio.Task[Any]] = []
149-
150154
while app.running():
151155
with catch_exceptions(_logger):
152156
# Wait for monitors to be ready
153157
await registry.wait_monitors_ready()
154158

155-
# Tasks cleaning
156-
tasks = [task for task in tasks if not task.done()]
157-
158159
last_loop_at = now()
159160

160161
# Run the procedures in the background
161-
procedures_task = asyncio.create_task(run_procedures())
162-
tasks.append(procedures_task)
162+
task_manager.create_task(run_procedures(), parent_task=current_task)
163163

164164
# Loop through all monitors
165165
enabled_monitors = await Monitor.get_all(Monitor.enabled.is_(True))
166166
for monitor in enabled_monitors:
167-
task = await _create_process_task(semaphore, monitor)
168-
if task is not None:
169-
tasks.append(task)
167+
coroutine = await _create_process_task(semaphore, monitor)
168+
if coroutine is not None:
169+
task_manager.create_task(coroutine, parent_task=current_task)
170170

171171
# Sleep until next scheduling decision, if necessary
172172
if not is_triggered(controller_process_schedule, last_loop_at):
173173
sleep_time = time_until_next_trigger(controller_process_schedule)
174174
await app.sleep(sleep_time)
175175

176176
_logger.info("Finishing")
177+
await task_manager.wait_for_tasks(parent_task=current_task)

src/components/executor/executor.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime
66
from typing import Any
77

8+
import components.task_manager as task_manager
89
import message_queue as message_queue
910
import registry as registry
1011
import utils.app as app
@@ -40,23 +41,15 @@ async def diagnostics() -> tuple[dict[str, Any], list[str]]:
4041
return status, issues
4142

4243

43-
def count_running(tasks: list[asyncio.Task[Any]]) -> int:
44-
"""Count the number of running tasks"""
45-
return len([task for task in tasks if not task.done()])
46-
47-
48-
async def wait_for_tasks(tasks: list[asyncio.Task[Any]]) -> None:
49-
"""Wait for all running tasks to finish"""
50-
while count_running(tasks) > 0:
51-
_logger.info(f"Waiting for {count_running(tasks)} tasks to finish")
52-
await asyncio.sleep(TASKS_FINISH_CHECK_TIME)
53-
54-
5544
async def run() -> None:
5645
global last_message_at
5746
global running
5847

59-
tasks: list[asyncio.Task[Any]] = []
48+
current_task = asyncio.current_task()
49+
if current_task is None:
50+
_logger.error("Could not get the current asyncio task, finishing")
51+
return
52+
6053
runner_id = 0
6154
running = True
6255
semaphore = asyncio.Semaphore(configs.executor_concurrency)
@@ -65,9 +58,6 @@ async def run() -> None:
6558

6659
while app.running():
6760
with catch_exceptions(_logger):
68-
# Tasks cleaning
69-
tasks = [task for task in tasks if not task.done()]
70-
7161
async with semaphore:
7262
message = await message_queue.get_message()
7363

@@ -79,11 +69,10 @@ async def run() -> None:
7969

8070
runner_id += 1
8171
runner = Runner(runner_id, message)
82-
runner_task = asyncio.create_task(runner.process(semaphore))
83-
tasks.append(runner_task)
72+
task_manager.create_task(runner.process(semaphore), parent_task=current_task)
8473

8574
# Give control back to the event loop
8675
await asyncio.sleep(0)
8776

8877
_logger.info("Finishing")
89-
await wait_for_tasks(tasks)
78+
await task_manager.wait_for_tasks(parent_task=current_task)

src/components/executor/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import prometheus_client
88

9+
import components.task_manager as task_manager
910
import message_queue as message_queue
1011
import registry as registry
1112
import utils.app as app
@@ -76,7 +77,9 @@ async def process_message(
7677

7778
# Create a looping task that will keep the message not visible while it's been
7879
# handled
79-
change_visibility_task = asyncio.create_task(_change_visibility_loop(self.message))
80+
change_visibility_task = task_manager.create_task(
81+
_change_visibility_loop(self.message), parent_task=asyncio.current_task()
82+
)
8083

8184
# Protect execution from exceptions
8285
try:

src/components/monitors_loader/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
check_monitor,
44
init,
55
register_monitor,
6-
wait_stop,
6+
run,
77
)
88

99
__all__ = [
1010
"MonitorValidationError",
1111
"check_monitor",
1212
"init",
1313
"register_monitor",
14-
"wait_stop",
14+
"run",
1515
]

src/components/monitors_loader/monitors_loader.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import shutil
44
from datetime import datetime, timedelta
55
from pathlib import Path
6-
from typing import Any, Generator, cast
6+
from typing import Generator, cast
77

88
from pydantic.dataclasses import dataclass
99

@@ -27,8 +27,6 @@
2727
EARLY_LOAD_TIME = 5
2828
COOL_DOWN_TIME = 2
2929

30-
_task: asyncio.Task[Any]
31-
3230

3331
@dataclass
3432
class AdditionalFile:
@@ -289,7 +287,7 @@ async def _load_monitors(last_load_time: datetime | None) -> None:
289287
registry.monitors_pending.clear()
290288

291289

292-
async def _run() -> None:
290+
async def run() -> None:
293291
"""Monitors loading loop, loading them recurrently. Stops automatically when the app stops"""
294292
last_load_time: datetime | None = None
295293

@@ -326,21 +324,13 @@ async def _run() -> None:
326324
if time_since_last_load < COOL_DOWN_TIME:
327325
await app.sleep(COOL_DOWN_TIME - time_since_last_load)
328326

327+
_logger.info("Removing temporary monitors paths")
328+
shutil.rmtree(Path(RELATIVE_PATH) / MONITORS_LOAD_PATH, ignore_errors=True)
329+
shutil.rmtree(Path(RELATIVE_PATH) / MONITORS_PATH, ignore_errors=True)
330+
329331

330332
async def init(controller_enabled: bool) -> None:
331333
"""Load the internal monitors and sample monitors if controller is enabled, and start the
332334
monitors load task"""
333335
if controller_enabled:
334336
await _register_monitors()
335-
336-
global _task
337-
_task = asyncio.create_task(_run())
338-
339-
340-
async def wait_stop() -> None:
341-
"""Wait for the Monitors load task to finish"""
342-
global _task
343-
await _task
344-
_logger.info("Removing temporary monitors paths")
345-
shutil.rmtree(Path(RELATIVE_PATH) / MONITORS_LOAD_PATH, ignore_errors=True)
346-
shutil.rmtree(Path(RELATIVE_PATH) / MONITORS_PATH, ignore_errors=True)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .task_manager import create_task, run, wait_for_tasks
2+
3+
__all__ = [
4+
"create_task",
5+
"run",
6+
"wait_for_all_tasks",
7+
"wait_for_tasks",
8+
]
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import asyncio
2+
import logging
3+
from functools import partial
4+
from typing import Any, Coroutine
5+
6+
import utils.app as app
7+
8+
TASKS_FINISH_CHECK_TIME = 1
9+
10+
_logger = logging.getLogger("task_manager")
11+
12+
_tasks: dict[asyncio.Task[Any] | None, list[asyncio.Task[Any]]] = {}
13+
14+
15+
def _on_parent_done(parent_task: asyncio.Task[Any], task: asyncio.Task[Any]) -> None:
16+
"""Callback when parent task is done to cancel child task if it's still running"""
17+
if not task.done():
18+
_logger.error(f"Cancelling task '{task.get_name()}' as parent task is done")
19+
task.cancel()
20+
21+
22+
def create_task(
23+
coro: Coroutine[Any, Any, Any], parent_task: asyncio.Task[Any] | None = None
24+
) -> asyncio.Task[Any]:
25+
"""Create a task that will be executed in the background with an optional 'parent' attribute. If
26+
the parent task is done while the child task is running, the child task will be canceled"""
27+
task = asyncio.create_task(coro, name=coro.__name__)
28+
_tasks.setdefault(parent_task, []).append(task)
29+
30+
if parent_task is not None:
31+
parent_task.add_done_callback(partial(_on_parent_done, task=task))
32+
33+
return task
34+
35+
36+
def _clear_completed() -> None:
37+
"""Remove completed tasks from the global task list"""
38+
global _tasks
39+
_tasks = {
40+
parent: [task for task in tasks if not task.done()] for parent, tasks in _tasks.items()
41+
}
42+
_tasks = {parent: tasks for parent, tasks in _tasks.items() if len(tasks) > 0}
43+
44+
45+
async def wait_for_tasks(
46+
parent_task: asyncio.Task[Any] | None, timeout: float | None = None, cancel: bool = False
47+
) -> bool:
48+
"""Wait for all running tasks started by the parent task to finish. If all tasks finish before
49+
the timeout, the function will return True. If the timeout is 'None', the function will wait
50+
until all tasks finish. If the timeout is reached, the function will return False.
51+
If cancel is True, all pending tasks will be cancelled on timeout."""
52+
tasks = _tasks.get(parent_task, [])
53+
if len(tasks) == 0:
54+
return True
55+
56+
done, pending = await asyncio.wait(tasks, timeout=timeout)
57+
58+
if len(pending) > 0:
59+
if cancel:
60+
for task in pending:
61+
_logger.info(f"Task '{task.get_name()}' timed out")
62+
task.cancel()
63+
return False
64+
65+
return True
66+
67+
68+
async def wait_for_all_tasks(timeout: float | None = None, cancel: bool = False) -> None:
69+
"""Wait for all running tasks to finish"""
70+
for parent_task in _tasks.keys():
71+
await wait_for_tasks(parent_task=parent_task, timeout=timeout, cancel=cancel)
72+
73+
74+
def _count_running(tasks: dict[Any, list[asyncio.Task[Any]]]) -> int:
75+
"""Count the number of running tasks"""
76+
running_tasks = 0
77+
for task_list in tasks.values():
78+
running_tasks += len([task for task in task_list if not task.done()])
79+
return running_tasks
80+
81+
82+
async def _wait_to_finish(tasks: dict[Any, list[asyncio.Task[Any]]]) -> None:
83+
"""Wait for all running tasks to finish"""
84+
while True:
85+
running_tasks = _count_running(tasks)
86+
if running_tasks == 0:
87+
break
88+
_logger.info(f"Waiting for {running_tasks} tasks to finish")
89+
await asyncio.sleep(TASKS_FINISH_CHECK_TIME)
90+
91+
92+
async def run() -> None:
93+
_logger.info("Task manager running")
94+
95+
while app.running():
96+
_clear_completed()
97+
await app.sleep(60)
98+
99+
_logger.info("Finishing")
100+
await _wait_to_finish(_tasks)

src/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Still missing tests for main.py, so it's been ignored in the .coveragerc file
22

3-
import asyncio
43
import logging
54
import sys
65

@@ -10,6 +9,7 @@
109
import components.executor as executor
1110
import components.http_server as http_server
1211
import components.monitors_loader as monitors_loader
12+
import components.task_manager as task_manager
1313
import databases as databases
1414
import internal_database as internal_database
1515
import message_queue as message_queue
@@ -44,7 +44,6 @@ async def init(controller_enabled: bool, executor_enabled: bool) -> None:
4444
async def finish(controller_enabled: bool, executor_enabled: bool) -> None:
4545
"""Finish the application, making sure any exception won't impact other closing tasks"""
4646
await protected_task(_logger, http_server.wait_stop())
47-
await protected_task(_logger, monitors_loader.wait_stop())
4847
await protected_task(_logger, databases.close())
4948
await protected_task(_logger, internal_database.close())
5049
await protected_task(
@@ -69,8 +68,11 @@ async def main() -> None:
6968
"executor": executor.run,
7069
}
7170

72-
tasks = [modes[mode]() for mode in operation_modes]
73-
await asyncio.gather(*tasks)
71+
for mode in operation_modes:
72+
task_manager.create_task(modes[mode]())
73+
task_manager.create_task(monitors_loader.run())
74+
75+
await task_manager.run()
7476

7577
await finish(
7678
controller_enabled="controller" in operation_modes,

0 commit comments

Comments
 (0)