Skip to content

Commit a838c5b

Browse files
committed
Fix 17 bugs, gaps, and performance issues found in end-to-end audit
Bugs: - Fix _run_sync catching its own RuntimeError safety guard - Fix async __contains__ (Python 'in' operator never awaits) - Fix auth bypass on root / endpoint allowing unauthenticated access - Remove dead create_auth_endpoints code; fix /auth/token to use auth_manager - Fix search_security_events silently ignoring query parameter - Fix cleanup_expired missing timeout_minutes parameter - Register monitoring middleware for request tracking - Fix /metrics using default registry instead of custom REGISTRY Gaps: - Sync time range values: add 12h/1d across security.py and wazuh_client.py - Add 'pending' to agent status enum in tool schema Enhancements: - Migrate MCPResponse from deprecated Pydantic v1 dict() to model_dump() - Register security middleware for response headers - Expand test suite from 10 to 33 tests covering new fixes Performance: - Refactor get_alerts to use _search helper with retry logic - Add max size guard to unbounded _initialized_sessions dict - Replace O(n*m) json.dumps per-alert with recursive dict search
1 parent a9636fb commit a838c5b

File tree

6 files changed

+308
-186
lines changed

6 files changed

+308
-186
lines changed

src/wazuh_mcp_server/api/wazuh_client.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,27 @@
1717
logger = logging.getLogger(__name__)
1818

1919
# Time range to hours mapping for indexer-based queries
20-
_TIME_RANGE_HOURS = {"1h": 1, "6h": 6, "12h": 12, "24h": 24, "7d": 168, "30d": 720}
20+
_TIME_RANGE_HOURS = {"1h": 1, "6h": 6, "12h": 12, "1d": 24, "24h": 24, "7d": 168, "30d": 720}
21+
22+
23+
def _dict_contains_text(d: Any, text: str) -> bool:
24+
"""Recursively check if any string value in a dict/list contains the given text.
25+
Much faster than json.dumps() + string search for large nested dicts."""
26+
if isinstance(d, dict):
27+
for v in d.values():
28+
if _dict_contains_text(v, text):
29+
return True
30+
elif isinstance(d, (list, tuple)):
31+
for item in d:
32+
if _dict_contains_text(item, text):
33+
return True
34+
elif isinstance(d, str):
35+
if text in d.lower():
36+
return True
37+
elif isinstance(d, (int, float)):
38+
if text in str(d):
39+
return True
40+
return False
2141

2242

2343
class WazuhClient:
@@ -462,9 +482,7 @@ async def analyze_alert_patterns(self, time_range: str, min_frequency: int) -> D
462482
"level": rule.get("level", 0),
463483
}
464484
rule_counts[rule_id]["count"] += 1
465-
patterns = [
466-
{"rule_id": k, **v} for k, v in rule_counts.items() if v["count"] >= min_frequency
467-
]
485+
patterns = [{"rule_id": k, **v} for k, v in rule_counts.items() if v["count"] >= min_frequency]
468486
patterns.sort(key=lambda x: x["count"], reverse=True)
469487
return {
470488
"data": {
@@ -476,11 +494,20 @@ async def analyze_alert_patterns(self, time_range: str, min_frequency: int) -> D
476494
}
477495

478496
async def search_security_events(self, query: str, time_range: str, limit: int) -> Dict[str, Any]:
479-
"""Search security events via the Wazuh Indexer."""
497+
"""Search security events via the Wazuh Indexer with query filtering."""
480498
if not self._indexer_client:
481499
raise IndexerNotConfiguredError()
482500
start = self._time_range_to_start(time_range)
483-
return await self._indexer_client.get_alerts(limit=limit, timestamp_start=start)
501+
# Fetch a larger batch from the indexer, then filter by query
502+
fetch_limit = min(limit * 5, 2000)
503+
result = await self._indexer_client.get_alerts(limit=fetch_limit, timestamp_start=start)
504+
if query:
505+
alerts = result.get("data", {}).get("affected_items", [])
506+
query_lower = query.lower()
507+
filtered = [a for a in alerts if _dict_contains_text(a, query_lower)]
508+
result["data"]["affected_items"] = filtered[:limit]
509+
result["data"]["total_affected_items"] = len(filtered)
510+
return result
484511

485512
async def get_running_agents(self) -> Dict[str, Any]:
486513
"""Get running agents."""
@@ -589,11 +616,8 @@ async def analyze_security_threat(self, indicator: str, indicator_type: str) ->
589616
raise IndexerNotConfiguredError()
590617
result = await self._indexer_client.get_alerts(limit=100)
591618
alerts = result.get("data", {}).get("affected_items", [])
592-
matches = []
593-
for alert in alerts:
594-
alert_str = json.dumps(alert)
595-
if indicator.lower() in alert_str.lower():
596-
matches.append(alert)
619+
indicator_lower = indicator.lower()
620+
matches = [a for a in alerts if _dict_contains_text(a, indicator_lower)]
597621
return {
598622
"data": {
599623
"indicator": indicator,
@@ -611,9 +635,9 @@ async def check_ioc_reputation(self, indicator: str, indicator_type: str) -> Dic
611635
alerts = result.get("data", {}).get("affected_items", [])
612636
occurrences = 0
613637
max_level = 0
638+
indicator_lower = indicator.lower()
614639
for alert in alerts:
615-
alert_str = json.dumps(alert)
616-
if indicator.lower() in alert_str.lower():
640+
if _dict_contains_text(alert, indicator_lower):
617641
occurrences += 1
618642
level = alert.get("rule", {}).get("level", 0)
619643
if isinstance(level, int) and level > max_level:
@@ -887,14 +911,12 @@ async def check_blocked_ip(self, ip_address: str, agent_id: str = None) -> Dict[
887911
raise IndexerNotConfiguredError()
888912
result = await self._indexer_client.get_alerts(limit=50)
889913
alerts = result.get("data", {}).get("affected_items", [])
890-
matches = [a for a in alerts if ip_address in json.dumps(a) and "firewall-drop" in json.dumps(a)]
914+
matches = [a for a in alerts if _dict_contains_text(a, ip_address) and _dict_contains_text(a, "firewall-drop")]
891915
return {"data": {"ip_address": ip_address, "blocked": len(matches) > 0, "matching_alerts": len(matches)}}
892916

893917
async def check_agent_isolation(self, agent_id: str) -> Dict[str, Any]:
894918
"""Check agent isolation status."""
895-
result = await self._request(
896-
"GET", "/agents", params={"agents_list": agent_id, "select": "id,name,status"}
897-
)
919+
result = await self._request("GET", "/agents", params={"agents_list": agent_id, "select": "id,name,status"})
898920
agents = result.get("data", {}).get("affected_items", [])
899921
if not agents:
900922
raise ValueError(f"Agent {agent_id} not found")
@@ -910,9 +932,7 @@ async def check_agent_isolation(self, agent_id: str) -> Dict[str, Any]:
910932

911933
async def check_process(self, agent_id: str, process_id: int) -> Dict[str, Any]:
912934
"""Check if a process is still running on an agent."""
913-
result = await self._request(
914-
"GET", f"/syscollector/{agent_id}/processes", params={"limit": 500}
915-
)
935+
result = await self._request("GET", f"/syscollector/{agent_id}/processes", params={"limit": 500})
916936
processes = result.get("data", {}).get("affected_items", [])
917937
running = any(str(p.get("pid")) == str(process_id) for p in processes)
918938
return {"data": {"agent_id": agent_id, "process_id": process_id, "running": running}}
@@ -930,9 +950,7 @@ async def check_user_status(self, agent_id: str, username: str) -> Dict[str, Any
930950

931951
async def check_file_quarantine(self, agent_id: str, file_path: str) -> Dict[str, Any]:
932952
"""Check if a file has been quarantined via FIM events."""
933-
result = await self._request(
934-
"GET", "/syscheck", params={"agents_list": agent_id, "q": f"file={file_path}"}
935-
)
953+
result = await self._request("GET", "/syscheck", params={"agents_list": agent_id, "q": f"file={file_path}"})
936954
events = result.get("data", {}).get("affected_items", [])
937955
quarantined = any(e.get("type") == "deleted" or "quarantine" in str(e) for e in events)
938956
return {"data": {"agent_id": agent_id, "file_path": file_path, "quarantined": quarantined}}

src/wazuh_mcp_server/api/wazuh_indexer.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -95,22 +95,27 @@ async def _ensure_initialized(self):
9595
retry=retry_if_exception_type((httpx.RequestError, httpx.HTTPStatusError)),
9696
reraise=True,
9797
)
98-
async def _search(self, index: str, query: Dict[str, Any], size: int = 100) -> Dict[str, Any]:
98+
async def _search(
99+
self, index: str, query: Dict[str, Any], size: int = 100, sort: Optional[list] = None
100+
) -> Dict[str, Any]:
99101
"""
100102
Execute a search query against the Wazuh Indexer.
101103
102104
Args:
103105
index: Index pattern to search
104106
query: Elasticsearch query DSL
105107
size: Maximum number of results
108+
sort: Optional sort specification
106109
107110
Returns:
108111
Search results from the indexer
109112
"""
110113
await self._ensure_initialized()
111114

112115
url = f"{self.base_url}/{index}/_search"
113-
body = {"query": query, "size": size}
116+
body: Dict[str, Any] = {"query": query, "size": size}
117+
if sort:
118+
body["sort"] = sort
114119

115120
try:
116121
response = await self.client.post(url, json=body, headers={"Content-Type": "application/json"})
@@ -151,8 +156,6 @@ async def get_alerts(
151156
Returns:
152157
Alert data in standard Wazuh format
153158
"""
154-
await self._ensure_initialized()
155-
156159
# Build bool query with must clauses for each non-empty filter
157160
must_clauses: list = []
158161

@@ -183,28 +186,8 @@ async def get_alerts(
183186
else:
184187
query = {"match_all": {}}
185188

186-
# Build the full search body with sort by timestamp desc
187-
url = f"{self.base_url}/{ALERTS_INDEX}/_search"
188-
body = {
189-
"query": query,
190-
"size": limit,
191-
"sort": [{"timestamp": {"order": "desc"}}],
192-
}
193-
194-
try:
195-
response = await self.client.post(url, json=body, headers={"Content-Type": "application/json"})
196-
response.raise_for_status()
197-
try:
198-
result = response.json()
199-
except (json.JSONDecodeError, ValueError):
200-
raise ValueError("Invalid JSON response from Wazuh Indexer alerts query")
201-
except httpx.HTTPStatusError as e:
202-
logger.error(f"Alerts query failed: {e.response.status_code} - {e.response.text}")
203-
raise ValueError(f"Alerts query failed: {e.response.status_code}")
204-
except httpx.ConnectError:
205-
raise ConnectionError(f"Cannot connect to Wazuh Indexer at {self.host}:{self.port}")
206-
except httpx.TimeoutException:
207-
raise ConnectionError(f"Timeout connecting to Wazuh Indexer at {self.host}:{self.port}")
189+
# Use _search helper for consistent retry logic (sorted by timestamp desc)
190+
result = await self._search(ALERTS_INDEX, query, size=limit, sort=[{"timestamp": {"order": "desc"}}])
208191

209192
# Transform to standard Wazuh format
210193
hits = result.get("hits", {})

src/wazuh_mcp_server/auth.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from datetime import datetime, timedelta, timezone
1515
from typing import Any, Dict, List, Optional
1616

17-
from fastapi import Header, HTTPException
1817
from jose import jwt
1918
from jose.exceptions import ExpiredSignatureError, JWTError
2019
from pydantic import BaseModel, Field
@@ -268,20 +267,6 @@ def get_stats(self) -> Dict[str, Any]:
268267
auth_manager = AuthManager()
269268

270269

271-
class TokenRequest(BaseModel):
272-
"""Token request model."""
273-
274-
api_key: str = Field(description="API key to exchange for token")
275-
276-
277-
class TokenResponse(BaseModel):
278-
"""Token response model."""
279-
280-
token: str = Field(description="Authentication token")
281-
expires_in: int = Field(description="Token lifetime in seconds")
282-
token_type: str = Field(default="Bearer")
283-
284-
285270
async def verify_bearer_token(authorization: str) -> AuthToken:
286271
"""
287272
Verify bearer token from Authorization header.
@@ -379,57 +364,3 @@ def verify_token(token: str, secret_key: str) -> Dict[str, Any]:
379364
raise ValueError("Token has expired")
380365
except JWTError:
381366
raise ValueError("Invalid token")
382-
383-
384-
async def create_auth_endpoints(app):
385-
"""Add authentication endpoints to FastAPI app."""
386-
387-
@app.post("/auth/token", response_model=TokenResponse)
388-
async def create_token(request: TokenRequest):
389-
"""Exchange API key for authentication token."""
390-
token = auth_manager.create_token(request.api_key)
391-
if not token:
392-
raise HTTPException(status_code=401, detail="Invalid API key")
393-
394-
token_obj = auth_manager.tokens[token]
395-
expires_in = int((token_obj.expires_at - datetime.now(timezone.utc)).total_seconds())
396-
397-
return TokenResponse(token=token, expires_in=expires_in)
398-
399-
@app.get("/auth/validate")
400-
async def validate_token(authorization: str = Header(description="Bearer token")):
401-
"""Validate authentication token."""
402-
try:
403-
token_obj = await verify_bearer_token(authorization)
404-
return {
405-
"valid": True,
406-
"api_key_id": token_obj.api_key_id,
407-
"scopes": token_obj.scopes,
408-
"expires_at": token_obj.expires_at.isoformat() if token_obj.expires_at else None,
409-
}
410-
except ValueError as e:
411-
raise HTTPException(status_code=401, detail=str(e))
412-
413-
@app.post("/auth/revoke")
414-
async def revoke_token(authorization: str = Header(description="Bearer token")):
415-
"""Revoke authentication token."""
416-
if not authorization.startswith("Bearer "):
417-
raise HTTPException(status_code=400, detail="Invalid authorization header")
418-
419-
token = authorization[7:]
420-
if auth_manager.revoke_token(token):
421-
return {"revoked": True}
422-
else:
423-
raise HTTPException(status_code=404, detail="Token not found")
424-
425-
@app.get("/auth/stats")
426-
async def auth_stats(authorization: str = Header(description="Bearer token")):
427-
"""Get authentication statistics (requires admin scope)."""
428-
try:
429-
token_obj = await verify_bearer_token(authorization)
430-
if not token_obj.has_scope("admin"):
431-
raise HTTPException(status_code=403, detail="Admin scope required")
432-
433-
return auth_manager.get_stats()
434-
except ValueError as e:
435-
raise HTTPException(status_code=401, detail=str(e))

src/wazuh_mcp_server/security.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, param_name: str, message: str, suggestion: str = None):
3939

4040

4141
# Valid enum values for tool parameters
42-
VALID_TIME_RANGES = {"1h", "6h", "24h", "7d", "1d", "30d"}
42+
VALID_TIME_RANGES = {"1h", "6h", "12h", "24h", "7d", "1d", "30d"}
4343
VALID_SEVERITIES = {"low", "medium", "high", "critical"}
4444
VALID_AGENT_STATUSES = {"active", "disconnected", "never_connected", "pending"}
4545
VALID_INDICATOR_TYPES = {"ip", "hash", "domain", "url"}
@@ -354,7 +354,9 @@ def validate_file_path(value: Any, required: bool = False, param_name: str = "fi
354354
raise ToolValidationError(param_name, "contains path traversal", "Path must not contain '..'")
355355

356356
if len(file_path) > 500:
357-
raise ToolValidationError(param_name, f"too long ({len(file_path)} chars)", "Path must be 500 characters or less")
357+
raise ToolValidationError(
358+
param_name, f"too long ({len(file_path)} chars)", "Path must be 500 characters or less"
359+
)
358360

359361
return file_path
360362

0 commit comments

Comments
 (0)