Skip to content

Commit 0bb5ce2

Browse files
authored
Merge pull request #55 from GabrielSalla/add-message-validation-executor-handlers
Add message validation to executor handlers
2 parents 7bcc724 + 6d0d754 commit 0bb5ce2

File tree

17 files changed

+232
-96
lines changed

17 files changed

+232
-96
lines changed

docs/plugins.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,22 @@ Actions are used as custom behaviors to requests received by sentinela. If senti
2323

2424
Actions must have the following signature:
2525
```python
26-
async def action_name(message_payload: dict[Any, Any]):
26+
from data_models.request_payload import RequestPayload
27+
28+
29+
async def action_name(message_payload: RequestPayload):
2730
```
2831

32+
The `RequestPayload` object contains the action name and the parameters sent by the request. The parameters will vary depending on the action.
33+
2934
An example of the action call made by Sentinela is:
3035
```python
31-
await plugin.my_plugin.actions.action_name({"key": "value"})
36+
from data_models.request_payload import RequestPayload
37+
38+
39+
await plugin.my_plugin.actions.action_name(
40+
RequestPayload(action="my_plugin.action_name", params={"key": "value"})
41+
)
3242
```
3343

3444
## Notifications

src/commands/requests.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def alert_acknowledge(alert_id: int) -> None:
4040
type="request",
4141
payload={
4242
"action": "alert_acknowledge",
43-
"target_id": alert_id,
43+
"params": {"target_id": alert_id},
4444
},
4545
)
4646

@@ -51,7 +51,7 @@ async def alert_lock(alert_id: int) -> None:
5151
type="request",
5252
payload={
5353
"action": "alert_lock",
54-
"target_id": alert_id,
54+
"params": {"target_id": alert_id},
5555
},
5656
)
5757

@@ -62,7 +62,7 @@ async def alert_solve(alert_id: int) -> None:
6262
type="request",
6363
payload={
6464
"action": "alert_solve",
65-
"target_id": alert_id,
65+
"params": {"target_id": alert_id},
6666
},
6767
)
6868

@@ -73,6 +73,6 @@ async def issue_drop(issue_id: int) -> None:
7373
type="request",
7474
payload={
7575
"action": "issue_drop",
76-
"target_id": issue_id,
76+
"params": {"target_id": issue_id},
7777
},
7878
)

src/components/executor/monitor_handler.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import asyncio
2+
import json
23
import logging
34
import traceback
45
from datetime import datetime
5-
from typing import Any, cast
6+
from typing import Any, Literal, cast
67

78
import prometheus_client
9+
from pydantic import ValidationError
810

911
import registry as registry
1012
from base_exception import BaseSentinelaException
13+
from data_models.process_monitor_payload import ProcessMonitorPayload
1114
from internal_database import get_session
1215
from models import Alert, Issue, Monitor
1316
from utils.async_tools import do_concurrently
@@ -279,7 +282,7 @@ async def _alerts_routine(monitor: Monitor) -> None:
279282
await do_concurrently(*[alert.update() for alert in monitor.active_alerts])
280283

281284

282-
async def _run_routines(monitor: Monitor, tasks: list[str]) -> None:
285+
async def _run_routines(monitor: Monitor, tasks: list[Literal["search", "update"]]) -> None:
283286
"""Run all routines for a monitor, based on a list of tasks"""
284287
# Monitor instrumentation metrics
285288
prometheus_labels = {
@@ -321,9 +324,16 @@ async def _run_routines(monitor: Monitor, tasks: list[str]) -> None:
321324
async def run(message: dict[Any, Any]) -> None:
322325
"""Process a message with type 'process_monitor', loading the monitor and executing it's
323326
routines, while also detecting errors and reporting them accordingly"""
324-
message_payload = message["payload"]
327+
try:
328+
message_payload = ProcessMonitorPayload(**message["payload"])
329+
except KeyError:
330+
_logger.error(f"Message '{json.dumps(message)}' missing 'payload' field")
331+
return
332+
except ValidationError as e:
333+
_logger.error(f"Invalid payload: {e}")
334+
return
325335

326-
monitor_id = message_payload["monitor_id"]
336+
monitor_id = message_payload.monitor_id
327337
monitor = await Monitor.get_by_id(monitor_id)
328338
if monitor is None:
329339
_logger.error(f"Monitor {monitor_id} not found. Skipping message")
@@ -340,17 +350,17 @@ async def run(message: dict[Any, Any]) -> None:
340350
"monitor_name": monitor.name,
341351
}
342352

343-
try:
344-
monitor_running = prometheus_monitor_running.labels(**prometheus_labels)
345-
monitor_running.inc()
353+
monitor_running = prometheus_monitor_running.labels(**prometheus_labels)
354+
monitor_running.inc()
346355

356+
try:
347357
monitor.set_running(True)
348358
await monitor.save()
349359

350360
monitor_execution_time = prometheus_monitor_execution_time.labels(**prometheus_labels)
351361
with monitor_execution_time.time():
352362
await asyncio.wait_for(
353-
_run_routines(monitor, message_payload["tasks"]), monitor.options.execution_timeout
363+
_run_routines(monitor, message_payload.tasks), monitor.options.execution_timeout
354364
)
355365
except asyncio.TimeoutError:
356366
monitor_timeout_count = prometheus_monitor_timeout_count.labels(**prometheus_labels)

src/components/executor/reaction_handler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any
77

88
import prometheus_client
9+
from pydantic import ValidationError
910

1011
import registry as registry
1112
from base_exception import BaseSentinelaException
@@ -35,7 +36,15 @@
3536
async def run(message: dict[Any, Any]) -> None:
3637
"""Process a message with type 'event' using the monitor's defined list of reactions for the
3738
event. The execution timeout is for each function individually"""
38-
event_payload = EventPayload(**message["payload"])
39+
try:
40+
event_payload = EventPayload(**message["payload"])
41+
except KeyError:
42+
_logger.error(f"Message '{json.dumps(message)}' missing 'payload' field")
43+
return
44+
except ValidationError as e:
45+
_logger.error(f"Invalid payload: {e}")
46+
return
47+
3948
monitor_id = event_payload.event_source_monitor_id
4049
event_name = event_payload.event_name
4150

src/components/executor/request_handler.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from typing import Any, Callable, Coroutine, cast
66

77
import prometheus_client
8+
from pydantic import ValidationError
89

910
import plugins
1011
import registry as registry
1112
from base_exception import BaseSentinelaException
1213
from configs import configs
14+
from data_models.request_payload import RequestPayload
1315
from models import Alert, Issue
1416

1517
_logger = logging.getLogger("request_handler")
@@ -31,9 +33,9 @@
3133
)
3234

3335

34-
async def alert_acknowledge(message_payload: dict[Any, Any]) -> None:
36+
async def alert_acknowledge(message_payload: RequestPayload) -> None:
3537
"""Acknowledge an alert"""
36-
alert_id = message_payload["target_id"]
38+
alert_id = message_payload.params["target_id"]
3739
alert = await Alert.get_by_id(alert_id)
3840
if alert is None:
3941
_logger.info(f"Alert '{alert_id}' not found")
@@ -42,9 +44,9 @@ async def alert_acknowledge(message_payload: dict[Any, Any]) -> None:
4244
await alert.acknowledge()
4345

4446

45-
async def alert_lock(message_payload: dict[Any, Any]) -> None:
47+
async def alert_lock(message_payload: RequestPayload) -> None:
4648
"""Lock an alert"""
47-
alert_id = message_payload["target_id"]
49+
alert_id = message_payload.params["target_id"]
4850
alert = await Alert.get_by_id(alert_id)
4951
if alert is None:
5052
_logger.info(f"Alert '{alert_id}' not found")
@@ -53,9 +55,9 @@ async def alert_lock(message_payload: dict[Any, Any]) -> None:
5355
await alert.lock()
5456

5557

56-
async def alert_solve(message_payload: dict[Any, Any]) -> None:
58+
async def alert_solve(message_payload: RequestPayload) -> None:
5759
"""Solve all alert's issues"""
58-
alert_id = message_payload["target_id"]
60+
alert_id = message_payload.params["target_id"]
5961
alert = await Alert.get_by_id(alert_id)
6062
if alert is None:
6163
_logger.info(f"Alert '{alert_id}' not found")
@@ -64,9 +66,9 @@ async def alert_solve(message_payload: dict[Any, Any]) -> None:
6466
await alert.solve_issues()
6567

6668

67-
async def issue_drop(message_payload: dict[Any, Any]) -> None:
69+
async def issue_drop(message_payload: RequestPayload) -> None:
6870
"""Drop an issue"""
69-
issue_id = message_payload["target_id"]
71+
issue_id = message_payload.params["target_id"]
7072
issue = await Issue.get_by_id(issue_id)
7173
if issue is None:
7274
_logger.info(f"Issue '{issue_id}' not found")
@@ -83,7 +85,7 @@ async def issue_drop(message_payload: dict[Any, Any]) -> None:
8385
}
8486

8587

86-
def get_action(action_name: str) -> Callable[[dict[Any, Any]], Coroutine[Any, Any, None]] | None:
88+
def get_action(action_name: str) -> Callable[[RequestPayload], Coroutine[Any, Any, None]] | None:
8789
"""Get the action function by its name, checking if it is a plugin action"""
8890
if action_name.startswith("plugin."):
8991
plugin_name, action_name = action_name.split(".")[1:3]
@@ -103,31 +105,41 @@ def get_action(action_name: str) -> Callable[[dict[Any, Any]], Coroutine[Any, An
103105
_logger.warning(f"Action '{plugin_name}.{action_name}' unknown")
104106
return None
105107

106-
return cast(Callable[[dict[Any, Any]], Coroutine[Any, Any, None]], action)
108+
return cast(Callable[[RequestPayload], Coroutine[Any, Any, None]], action)
107109

108110
return actions.get(action_name)
109111

110112

111113
async def run(message: dict[Any, Any]) -> None:
112114
"""Process a received request"""
113-
message_payload = message["payload"]
114-
action_name = message_payload["action"]
115+
try:
116+
message_payload = RequestPayload(**message["payload"])
117+
except KeyError:
118+
_logger.error(f"Message '{json.dumps(message)}' missing 'payload' field")
119+
return
120+
except ValidationError as e:
121+
_logger.error(f"Invalid payload: {e}")
122+
return
123+
124+
action_name = message_payload.action
115125

116126
action = get_action(action_name)
117127

118128
if action is None:
119-
_logger.warning(f"Got request with unknown action '{json.dumps(message_payload)}'")
129+
_logger.warning(
130+
f"Got request with unknown action '{json.dumps(message_payload.to_dict())}'"
131+
)
120132
return
121133

122134
try:
123135
with prometheus_request_execution_time.labels(action_name=action_name).time():
124136
await asyncio.wait_for(action(message_payload), configs.executor_request_timeout)
125137
except asyncio.TimeoutError:
126138
prometheus_request_timeout_count.labels(action_name=action_name).inc()
127-
_logger.error(f"Timed out executing request '{json.dumps(message_payload)}'")
139+
_logger.error(f"Timed out executing request '{json.dumps(message_payload.to_dict())}'")
128140
except BaseSentinelaException as e:
129141
raise e
130142
except Exception:
131143
prometheus_request_error_count.labels(action_name=action_name).inc()
132-
_logger.error(f"Error executing request '{json.dumps(message_payload)}'")
144+
_logger.error(f"Error executing request '{json.dumps(message_payload.to_dict())}'")
133145
_logger.error(traceback.format_exc().strip())
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .process_monitor_payload import ProcessMonitorPayload
2+
3+
__all__ = ["ProcessMonitorPayload"]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Literal
2+
3+
from pydantic.dataclasses import dataclass
4+
5+
6+
@dataclass
7+
class ProcessMonitorPayload:
8+
monitor_id: int
9+
tasks: list[Literal["search", "update"]]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .request_payload import RequestPayload
2+
3+
__all__ = ["RequestPayload"]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Any
2+
3+
from pydantic.dataclasses import dataclass
4+
5+
6+
@dataclass
7+
class RequestPayload:
8+
action: str
9+
params: dict[str, Any]
10+
11+
def to_dict(self) -> dict[str, Any]:
12+
return {
13+
field: value
14+
for field in self.__dataclass_fields__
15+
if (value := getattr(self, field)) is not None
16+
}

src/plugins/slack/services/pattern_match.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def resend_notifications(
5353
type="request",
5454
payload={
5555
"action": "plugin.slack.resend_notifications",
56-
"slack_channel": context["channel"],
56+
"params": {"slack_channel": context["channel"]},
5757
},
5858
)
5959

0 commit comments

Comments
 (0)