Skip to content

Commit d1cffbb

Browse files
committed
create protected_task function
1 parent 5de250f commit d1cffbb

File tree

3 files changed

+46
-17
lines changed

3 files changed

+46
-17
lines changed

src/main.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import asyncio
44
import logging
55
import sys
6-
import traceback
7-
from typing import Coroutine
86

97
import uvloop
108

@@ -19,18 +17,11 @@
1917
import registry as registry
2018
import utils.app as app
2119
import utils.log as log
20+
from utils.exception_handling import protected_task
2221

2322
_logger = logging.getLogger("main")
2423

2524

26-
async def protected_task(task: Coroutine[None, None, None]) -> None:
27-
try:
28-
await task
29-
except Exception:
30-
_logger.error(f"Exception with task '{task}'")
31-
_logger.error(traceback.format_exc().strip())
32-
33-
3425
async def init_plugins_services(controller_enabled: bool, executor_enabled: bool) -> None:
3526
"""Initialize the plugins services"""
3627
for plugin_name, plugin in plugins.loaded_plugins.items():
@@ -86,11 +77,11 @@ async def stop_plugins_services() -> None:
8677

8778
async def finish() -> None:
8879
"""Finish the application, making sure any exception won't impact other closing tasks"""
89-
await protected_task(http_server.wait_stop())
90-
await protected_task(monitors_loader.wait_stop())
91-
await protected_task(databases.close())
92-
await protected_task(internal_database.close())
93-
await protected_task(stop_plugins_services())
80+
await protected_task(_logger, http_server.wait_stop())
81+
await protected_task(_logger, monitors_loader.wait_stop())
82+
await protected_task(_logger, databases.close())
83+
await protected_task(_logger, internal_database.close())
84+
await protected_task(_logger, plugins.services.stop())
9485

9586

9687
async def main() -> None:

src/utils/exception_handling.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import traceback
44
from contextlib import contextmanager
5-
from typing import Generator
5+
from typing import Coroutine, Generator
66

77
from base_exception import BaseSentinelaException
88

@@ -29,3 +29,11 @@ def catch_exceptions(
2929
if error_message:
3030
logger.error(error_message)
3131
logger.info("Exception caught successfully, going on")
32+
33+
34+
async def protected_task(logger: logging.Logger, task: Coroutine[None, None, None]) -> None:
35+
try:
36+
await task
37+
except Exception:
38+
logger.error(f"Exception with task '{task}'")
39+
logger.error(traceback.format_exc().strip())

tests/utils/test_exception_handling.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from base_exception import BaseSentinelaException
88
from tests.test_utils import assert_message_in_log, assert_message_not_in_log
9-
from utils.exception_handling import catch_exceptions
9+
from utils.exception_handling import catch_exceptions, protected_task
1010

1111
pytestmark = pytest.mark.asyncio(loop_scope="session")
1212

@@ -78,3 +78,33 @@ async def error() -> None:
7878
assert logger_error_spy.call_count == 2
7979
logger_error_spy.assert_called_with("error function raised exception")
8080
logger_info_spy.assert_called_once_with("Exception caught successfully, going on")
81+
82+
83+
async def test_protected_task(caplog, mocker):
84+
"""'protected_task' should do nothing if the execution doesn't raise any errors"""
85+
logger = logging.getLogger("test_protected_task")
86+
logger_error_spy: MagicMock = mocker.spy(logger, "error")
87+
88+
async def no_error() -> None:
89+
pass
90+
91+
await protected_task(logger, no_error())
92+
93+
assert_message_not_in_log(caplog, "Exception with task")
94+
logger_error_spy.assert_not_called()
95+
96+
97+
async def test_protected_task_error(caplog, mocker):
98+
"""'protected_task' should log the exception message if an exception is raised"""
99+
logger = logging.getLogger("test_protected_task")
100+
logger_error_spy: MagicMock = mocker.spy(logger, "error")
101+
102+
async def error() -> None:
103+
raise ValueError("should be raised")
104+
105+
await protected_task(logger, error())
106+
107+
assert_message_in_log(caplog, "Exception with task")
108+
assert_message_in_log(caplog, "ValueError: should be raised")
109+
110+
assert logger_error_spy.call_count == 2

0 commit comments

Comments
 (0)