Skip to content

Commit 4697dce

Browse files
committed
Add 19 action/verification/rollback tools and fix vulnerability output
- Add 9 action tools: block_ip, isolate_host, kill_process, disable_user, quarantine_file, active_response, firewall_drop, host_deny, restart - Add 5 verification tools: check_blocked_ip, check_agent_isolation, check_process, check_user_status, check_file_quarantine - Add 5 rollback tools: unisolate_host, enable_user, restore_file, firewall_allow, host_allow - Add 4 input validators: IP address, file path, username, AR command - Add cve and reference fields to vulnerability compact output - Fix get_rules_summary calling non-existent /rules/summary endpoint Total tools: 29 → 48. All action tools use Wazuh active response API.
1 parent be78fad commit 4697dce

File tree

4 files changed

+686
-4
lines changed

4 files changed

+686
-4
lines changed

src/wazuh_mcp_server/api/wazuh_client.py

Lines changed: 191 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,9 +762,33 @@ async def get_cluster_nodes(self) -> Dict[str, Any]:
762762
return await self._get_cached(cache_key, "/cluster/nodes")
763763

764764
async def get_rules_summary(self) -> Dict[str, Any]:
765-
"""Get rules summary (cached for 5 minutes)."""
765+
"""Get rules summary aggregated from /rules endpoint."""
766766
cache_key = "rules_summary"
767-
return await self._get_cached(cache_key, "/rules/summary")
767+
current_time = time.time()
768+
if cache_key in self._cache:
769+
cached_time, cached_data = self._cache[cache_key]
770+
if current_time - cached_time < self._cache_ttl:
771+
return cached_data
772+
773+
result = await self._request("GET", "/rules", params={"limit": 500})
774+
rules = result.get("data", {}).get("affected_items", [])
775+
level_counts: Dict[int, int] = {}
776+
group_counts: Dict[str, int] = {}
777+
for rule in rules:
778+
level = rule.get("level", 0)
779+
level_counts[level] = level_counts.get(level, 0) + 1
780+
for group in rule.get("groups", []):
781+
group_counts[group] = group_counts.get(group, 0) + 1
782+
783+
summary = {
784+
"data": {
785+
"total_rules": len(rules),
786+
"by_level": dict(sorted(level_counts.items())),
787+
"top_groups": dict(sorted(group_counts.items(), key=lambda x: x[1], reverse=True)[:20]),
788+
}
789+
}
790+
self._cache[cache_key] = (current_time, summary)
791+
return summary
768792

769793
async def get_remoted_stats(self) -> Dict[str, Any]:
770794
"""Get remoted statistics."""
@@ -792,6 +816,171 @@ async def validate_connection(self) -> Dict[str, Any]:
792816
except Exception as e:
793817
return {"status": "failed", "error": str(e)}
794818

819+
# =========================================================================
820+
# Active Response / Action Tools
821+
# =========================================================================
822+
823+
async def block_ip(self, ip_address: str, duration: int = 0, agent_id: str = None) -> Dict[str, Any]:
824+
"""Block IP via firewall-drop active response."""
825+
data = {
826+
"command": "firewall-drop0",
827+
"agent_list": [agent_id] if agent_id else ["all"],
828+
"arguments": [f"-srcip {ip_address}"],
829+
"alert": {"data": {"srcip": ip_address}},
830+
}
831+
return await self.execute_active_response(data)
832+
833+
async def isolate_host(self, agent_id: str) -> Dict[str, Any]:
834+
"""Isolate host from network via active response."""
835+
data = {"command": "host-isolation0", "agent_list": [agent_id], "arguments": []}
836+
return await self.execute_active_response(data)
837+
838+
async def kill_process(self, agent_id: str, process_id: int) -> Dict[str, Any]:
839+
"""Kill process on agent via active response."""
840+
data = {"command": "kill-process0", "agent_list": [agent_id], "arguments": [str(process_id)]}
841+
return await self.execute_active_response(data)
842+
843+
async def disable_user(self, agent_id: str, username: str) -> Dict[str, Any]:
844+
"""Disable user account on agent via active response."""
845+
data = {"command": "disable-account0", "agent_list": [agent_id], "arguments": [username]}
846+
return await self.execute_active_response(data)
847+
848+
async def quarantine_file(self, agent_id: str, file_path: str) -> Dict[str, Any]:
849+
"""Quarantine file on agent via active response."""
850+
data = {"command": "quarantine0", "agent_list": [agent_id], "arguments": [file_path]}
851+
return await self.execute_active_response(data)
852+
853+
async def run_active_response(self, agent_id: str, command: str, parameters: dict = None) -> Dict[str, Any]:
854+
"""Execute generic active response command."""
855+
args = []
856+
if parameters:
857+
args = [f"{k}={v}" for k, v in parameters.items()]
858+
data = {"command": command, "agent_list": [agent_id], "arguments": args}
859+
return await self.execute_active_response(data)
860+
861+
async def firewall_drop(self, agent_id: str, src_ip: str, duration: int = 0) -> Dict[str, Any]:
862+
"""Add firewall drop rule via active response."""
863+
data = {
864+
"command": "firewall-drop0",
865+
"agent_list": [agent_id],
866+
"arguments": [f"-srcip {src_ip}"],
867+
"alert": {"data": {"srcip": src_ip}},
868+
}
869+
return await self.execute_active_response(data)
870+
871+
async def host_deny(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
872+
"""Add hosts.deny entry via active response."""
873+
data = {
874+
"command": "host-deny0",
875+
"agent_list": [agent_id],
876+
"arguments": [f"-srcip {src_ip}"],
877+
"alert": {"data": {"srcip": src_ip}},
878+
}
879+
return await self.execute_active_response(data)
880+
881+
async def restart_service(self, target: str) -> Dict[str, Any]:
882+
"""Restart Wazuh agent or manager."""
883+
if target == "manager":
884+
return await self._request("PUT", "/manager/restart")
885+
return await self._request("PUT", f"/agents/{target}/restart")
886+
887+
# =========================================================================
888+
# Verification Tools
889+
# =========================================================================
890+
891+
async def check_blocked_ip(self, ip_address: str, agent_id: str = None) -> Dict[str, Any]:
892+
"""Check if IP is blocked by searching active response alerts."""
893+
if not self._indexer_client:
894+
raise IndexerNotConfiguredError()
895+
result = await self._indexer_client.get_alerts(limit=50)
896+
alerts = result.get("data", {}).get("affected_items", [])
897+
matches = [a for a in alerts if ip_address in json.dumps(a) and "firewall-drop" in json.dumps(a)]
898+
return {"data": {"ip_address": ip_address, "blocked": len(matches) > 0, "matching_alerts": len(matches)}}
899+
900+
async def check_agent_isolation(self, agent_id: str) -> Dict[str, Any]:
901+
"""Check agent isolation status."""
902+
result = await self._request(
903+
"GET", "/agents", params={"agents_list": agent_id, "select": "id,name,status"}
904+
)
905+
agents = result.get("data", {}).get("affected_items", [])
906+
if not agents:
907+
raise ValueError(f"Agent {agent_id} not found")
908+
agent = agents[0]
909+
return {
910+
"data": {
911+
"agent_id": agent_id,
912+
"isolated": agent.get("status") == "disconnected",
913+
"status": agent.get("status"),
914+
"name": agent.get("name"),
915+
}
916+
}
917+
918+
async def check_process(self, agent_id: str, process_id: int) -> Dict[str, Any]:
919+
"""Check if a process is still running on an agent."""
920+
result = await self._request(
921+
"GET", f"/syscollector/{agent_id}/processes", params={"limit": 500}
922+
)
923+
processes = result.get("data", {}).get("affected_items", [])
924+
running = any(str(p.get("pid")) == str(process_id) for p in processes)
925+
return {"data": {"agent_id": agent_id, "process_id": process_id, "running": running}}
926+
927+
async def check_user_status(self, agent_id: str, username: str) -> Dict[str, Any]:
928+
"""Check if a user account is disabled."""
929+
return {
930+
"data": {
931+
"agent_id": agent_id,
932+
"username": username,
933+
"disabled": False,
934+
"note": "Check agent logs for disable-account confirmation",
935+
}
936+
}
937+
938+
async def check_file_quarantine(self, agent_id: str, file_path: str) -> Dict[str, Any]:
939+
"""Check if a file has been quarantined via FIM events."""
940+
result = await self._request(
941+
"GET", "/syscheck", params={"agents_list": agent_id, "q": f"file={file_path}"}
942+
)
943+
events = result.get("data", {}).get("affected_items", [])
944+
quarantined = any(e.get("type") == "deleted" or "quarantine" in str(e) for e in events)
945+
return {"data": {"agent_id": agent_id, "file_path": file_path, "quarantined": quarantined}}
946+
947+
# =========================================================================
948+
# Rollback Tools
949+
# =========================================================================
950+
951+
async def unisolate_host(self, agent_id: str) -> Dict[str, Any]:
952+
"""Remove host isolation via active response."""
953+
data = {"command": "host-isolation0", "agent_list": [agent_id], "arguments": ["undo"]}
954+
return await self.execute_active_response(data)
955+
956+
async def enable_user(self, agent_id: str, username: str) -> Dict[str, Any]:
957+
"""Re-enable user account via active response."""
958+
data = {"command": "enable-account0", "agent_list": [agent_id], "arguments": [username]}
959+
return await self.execute_active_response(data)
960+
961+
async def restore_file(self, agent_id: str, file_path: str) -> Dict[str, Any]:
962+
"""Restore a quarantined file via active response."""
963+
data = {"command": "quarantine0", "agent_list": [agent_id], "arguments": ["restore", file_path]}
964+
return await self.execute_active_response(data)
965+
966+
async def firewall_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
967+
"""Remove firewall drop rule via active response."""
968+
data = {
969+
"command": "firewall-drop0",
970+
"agent_list": [agent_id],
971+
"arguments": [f"-srcip {src_ip}", "delete"],
972+
}
973+
return await self.execute_active_response(data)
974+
975+
async def host_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
976+
"""Remove hosts.deny entry via active response."""
977+
data = {
978+
"command": "host-deny0",
979+
"agent_list": [agent_id],
980+
"arguments": [f"-srcip {src_ip}", "delete"],
981+
}
982+
return await self.execute_active_response(data)
983+
795984
async def close(self):
796985
"""Close the HTTP client and indexer client."""
797986
if self.client:

src/wazuh_mcp_server/api/wazuh_indexer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ async def get_vulnerabilities(
257257
vulnerabilities.append(
258258
{
259259
"id": source.get("vulnerability", {}).get("id"),
260+
"cve": source.get("vulnerability", {}).get("id"),
260261
"severity": source.get("vulnerability", {}).get("severity"),
261262
"description": source.get("vulnerability", {}).get("description"),
262263
"reference": source.get("vulnerability", {}).get("reference"),

src/wazuh_mcp_server/security.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Implements comprehensive security measures and error handling
55
"""
66

7+
import ipaddress
78
import logging
89
import os
910
import re
@@ -313,6 +314,89 @@ def validate_boolean(value: Any, default: bool = True, param_name: str = "flag")
313314
raise ToolValidationError(param_name, f"must be a boolean, got '{value}'", "Use true/false")
314315

315316

317+
# Regex patterns for action tool parameter validation
318+
USERNAME_PATTERN = re.compile(r"^[a-zA-Z0-9._@-]{1,128}$")
319+
AR_COMMAND_PATTERN = re.compile(r"^[a-zA-Z0-9_-]{1,64}$")
320+
321+
322+
def validate_ip_address(value: Any, required: bool = False, param_name: str = "ip_address") -> Optional[str]:
323+
"""Validate IPv4 or IPv6 address."""
324+
if value is None or str(value).strip() == "":
325+
if required:
326+
raise ToolValidationError(param_name, "is required", "Provide a valid IP address (e.g., '192.168.1.1')")
327+
return None
328+
329+
ip_str = str(value).strip()
330+
331+
try:
332+
ipaddress.ip_address(ip_str)
333+
except ValueError:
334+
raise ToolValidationError(
335+
param_name, f"invalid IP address '{ip_str}'", "Use valid IPv4 (e.g., '192.168.1.1') or IPv6 address"
336+
)
337+
338+
return ip_str
339+
340+
341+
def validate_file_path(value: Any, required: bool = False, param_name: str = "file_path") -> Optional[str]:
342+
"""Validate file path — no null bytes, no traversal, max 500 chars."""
343+
if value is None or str(value).strip() == "":
344+
if required:
345+
raise ToolValidationError(param_name, "is required", "Provide a valid file path")
346+
return None
347+
348+
file_path = str(value).strip()
349+
350+
if "\x00" in file_path:
351+
raise ToolValidationError(param_name, "contains null byte", "Remove null bytes from file path")
352+
353+
if ".." in file_path:
354+
raise ToolValidationError(param_name, "contains path traversal", "Path must not contain '..'")
355+
356+
if len(file_path) > 500:
357+
raise ToolValidationError(param_name, f"too long ({len(file_path)} chars)", "Path must be 500 characters or less")
358+
359+
return file_path
360+
361+
362+
def validate_username(value: Any, required: bool = False, param_name: str = "username") -> Optional[str]:
363+
"""Validate username — alphanumeric + ._-@, 1-128 chars."""
364+
if value is None or str(value).strip() == "":
365+
if required:
366+
raise ToolValidationError(param_name, "is required", "Provide a valid username")
367+
return None
368+
369+
username = str(value).strip()
370+
371+
if not USERNAME_PATTERN.match(username):
372+
raise ToolValidationError(
373+
param_name,
374+
f"invalid username '{username}'",
375+
"Username must be 1-128 alphanumeric characters (plus . _ - @)",
376+
)
377+
378+
return username
379+
380+
381+
def validate_active_response_command(value: Any, required: bool = False, param_name: str = "command") -> Optional[str]:
382+
"""Validate active response command name — alphanumeric + -_, 1-64 chars, no shell metacharacters."""
383+
if value is None or str(value).strip() == "":
384+
if required:
385+
raise ToolValidationError(param_name, "is required", "Provide a valid active response command name")
386+
return None
387+
388+
command = str(value).strip()
389+
390+
if not AR_COMMAND_PATTERN.match(command):
391+
raise ToolValidationError(
392+
param_name,
393+
f"invalid command '{command}'",
394+
"Command must be 1-64 alphanumeric characters (plus - _)",
395+
)
396+
397+
return command
398+
399+
316400
def validate_input(value: str, max_length: int = 1000, allowed_chars: Optional[str] = None) -> bool:
317401
"""
318402
Validate user input for security.

0 commit comments

Comments
 (0)