Skip to content

Commit cf6c486

Browse files
committed
fix: Address all production audit findings (critical through low severity)
Full end-to-end audit of every source file found 47 issues. This commit fixes all critical, medium, AND low severity bugs. Critical fixes: - Fix sessions.delete() → sessions.remove() (AttributeError crash) - Fix scope enforcement bypass when _auth_token is None on session - Fix agent ID regex rejecting manager agent "0" and short IDs "1","12" - Fix ACTIVE_CONNECTIONS counter leak in /sse on early errors - Fix WazuhClient.close() swallowing indexer cleanup on aclose() failure - Fix verify_bearer_token crash on None input - Fix block_ip without agent_id causing Wazuh API 400 error Medium fixes: - Add truncation warnings for vulnerability search tools - Fix risk_assessment defaulting to "medium" with zero risk factors - Add IP validation to firewall_allow and host_allow - Fix OAuth state param not URL-encoded in redirect - Fix indexer _execute_agg_search missing timeout handling - Fix indexer close() not nulling client reference Low severity fixes (previously accepted, now fixed): - Fix module-level AuthManager/SecurityManager crash on malformed env vars (TOKEN_LIFETIME_HOURS, RATE_LIMIT_REQUESTS, RATE_LIMIT_WINDOW) - Remove dead code _dict_contains_text (replaced by ES query passthrough) - Fix CPU metric always returning 0.0 (reuse psutil.Process instance) - Fix circuit breaker HALF_OPEN allowing multiple concurrent trial requests (add _half_open_trial_in_progress flag) Tests: 84 passing (removed 6 dead-code tests, added 13 regression tests)
1 parent a3b1ea0 commit cf6c486

File tree

9 files changed

+216
-111
lines changed

9 files changed

+216
-111
lines changed

src/wazuh_mcp_server/api/wazuh_client.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,6 @@
2020
_TIME_RANGE_HOURS = {"1h": 1, "6h": 6, "12h": 12, "1d": 24, "24h": 24, "7d": 168, "30d": 720}
2121

2222

23-
def _dict_contains_text(d: Any, text: str) -> bool:
24-
"""Recursively check if any string value in a dict/list contains the given text.
25-
Much faster than json.dumps() + string search for large nested dicts."""
26-
if isinstance(d, dict):
27-
for v in d.values():
28-
if _dict_contains_text(v, text):
29-
return True
30-
elif isinstance(d, (list, tuple)):
31-
for item in d:
32-
if _dict_contains_text(item, text):
33-
return True
34-
elif isinstance(d, str):
35-
if text in d.lower():
36-
return True
37-
elif isinstance(d, (int, float)):
38-
if text in str(d):
39-
return True
40-
return False
41-
42-
4323
class WazuhClient:
4424
"""Simplified Wazuh API client with rate limiting, circuit breaker, and retry logic."""
4525

@@ -278,16 +258,21 @@ async def execute_active_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
278258
# Ensure data dict doesn't contain deprecated 'custom' parameter
279259
if "custom" in data:
280260
data = {k: v for k, v in data.items() if k != "custom"}
281-
# Wazuh 4.x API: agent_list must be passed as query param 'agents_list', not in body
282-
# Omit agents_list entirely to target all agents (Wazuh rejects "all" as non-numeric)
261+
# Wazuh 4.x API: agent_list must be passed as query param 'agents_list'
283262
agents_list = data.pop("agent_list", None)
284263
params = {}
285264
if agents_list:
286-
# Filter out "all" — Wazuh 4.x requires numeric agent IDs only
287-
numeric_agents = [str(a) for a in (agents_list if isinstance(agents_list, list) else [agents_list]) if str(a) != "all"]
288-
if numeric_agents:
289-
params["agents_list"] = ",".join(numeric_agents)
290-
# If only "all" was specified, omit agents_list to target all agents
265+
agent_items = agents_list if isinstance(agents_list, list) else [agents_list]
266+
# Check if targeting all agents
267+
if any(str(a).lower() == "all" for a in agent_items):
268+
# Wazuh 4.x API requires a valid agents_list; use "all" as a special keyword
269+
# that the API accepts for targeting all agents
270+
params["agents_list"] = "all"
271+
else:
272+
# Filter to numeric agent IDs only
273+
numeric_agents = [str(a) for a in agent_items if str(a).isdigit()]
274+
if numeric_agents:
275+
params["agents_list"] = ",".join(numeric_agents)
291276
result = await self._request("PUT", "/active-response", json=data, params=params)
292277

293278
# Check for partial/total failures in the response body
@@ -807,8 +792,10 @@ async def perform_risk_assessment(self, agent_id: str = None) -> Dict[str, Any]:
807792
risk_level = "critical"
808793
elif any(f["severity"] == "high" for f in risk_factors):
809794
risk_level = "high"
810-
else:
795+
elif risk_factors:
811796
risk_level = "medium"
797+
else:
798+
risk_level = "low"
812799
return {
813800
"data": {
814801
"total_agents": len(items),
@@ -1216,6 +1203,7 @@ async def restore_file(self, agent_id: str, file_path: str) -> Dict[str, Any]:
12161203

12171204
async def firewall_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
12181205
"""Remove firewall drop rule via active response."""
1206+
self._validate_ip(src_ip)
12191207
src_ip = self._sanitize_ar_argument(src_ip, "src_ip")
12201208
data = {
12211209
"command": "!firewall-drop",
@@ -1226,6 +1214,7 @@ async def firewall_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
12261214

12271215
async def host_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
12281216
"""Remove hosts.deny entry via active response."""
1217+
self._validate_ip(src_ip)
12291218
src_ip = self._sanitize_ar_argument(src_ip, "src_ip")
12301219
data = {
12311220
"command": "!host-deny",
@@ -1236,10 +1225,17 @@ async def host_allow(self, agent_id: str, src_ip: str) -> Dict[str, Any]:
12361225

12371226
async def close(self):
12381227
"""Close the HTTP client and indexer client, releasing all connections."""
1239-
if self.client:
1240-
await self.client.aclose()
1228+
try:
1229+
if self.client:
1230+
await self.client.aclose()
1231+
except Exception:
1232+
pass # Best-effort close; connection may already be broken
1233+
finally:
12411234
self.client = None
1242-
if self._indexer_client:
1243-
await self._indexer_client.close()
1235+
try:
1236+
if self._indexer_client:
1237+
await self._indexer_client.close()
1238+
except Exception:
1239+
pass
12441240
self.token = None
12451241
self._cache.clear()

src/wazuh_mcp_server/api/wazuh_indexer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,13 @@ async def initialize(self):
108108

109109
async def close(self):
110110
"""Close the HTTP client."""
111-
if self.client:
112-
await self.client.aclose()
111+
try:
112+
if self.client:
113+
await self.client.aclose()
114+
except Exception:
115+
pass # Best-effort close
116+
finally:
117+
self.client = None
113118
self._initialized = False
114119

115120
async def _ensure_initialized(self):
@@ -425,7 +430,7 @@ async def _execute_agg_search(self) -> Dict[str, Any]:
425430
if e.response.status_code >= 500:
426431
raise # Let circuit breaker track server errors
427432
raise ValueError(f"Vulnerability summary query failed: {e.response.status_code}")
428-
except httpx.ConnectError:
433+
except (httpx.ConnectError, httpx.TimeoutException):
429434
raise ConnectionError(f"Cannot connect to Wazuh Indexer at {self.host}:{self.port}")
430435

431436
async def health_check(self) -> Dict[str, Any]:

src/wazuh_mcp_server/auth.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ class AuthManager:
6565

6666
def __init__(self):
6767
self.secret_key = os.getenv("AUTH_SECRET_KEY", secrets.token_urlsafe(32))
68-
self.token_lifetime = int(os.getenv("TOKEN_LIFETIME_HOURS", "24"))
68+
try:
69+
self.token_lifetime = int(os.getenv("TOKEN_LIFETIME_HOURS", "24"))
70+
except (ValueError, TypeError):
71+
logger.warning("Invalid TOKEN_LIFETIME_HOURS, defaulting to 24")
72+
self.token_lifetime = 24
6973
self.api_keys: Dict[str, APIKey] = {}
7074
self.tokens: Dict[str, AuthToken] = {}
7175
self._default_api_key: Optional[str] = None # Stores auto-generated key for display
@@ -295,7 +299,7 @@ async def verify_bearer_token(authorization: str) -> AuthToken:
295299
Raises:
296300
ValueError: If the token is invalid or expired
297301
"""
298-
if not authorization.startswith("Bearer "):
302+
if not authorization or not authorization.startswith("Bearer "):
299303
raise ValueError("Invalid authorization header format")
300304

301305
token = authorization[7:] # Remove "Bearer " prefix

src/wazuh_mcp_server/monitoring.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,9 @@ def __init__(self):
260260
self.collection_interval = 30 # seconds
261261
self.last_collection = 0
262262
self._collection_task = None
263+
# Reuse a single Process instance so cpu_percent() measures delta between calls
264+
# (psutil.Process.cpu_percent() returns 0.0 on the first call to a new instance)
265+
self._process = psutil.Process()
263266

264267
async def start_collection(self):
265268
"""Start metrics collection task."""
@@ -292,12 +295,11 @@ async def _collect_system_metrics(self):
292295
"""Collect system-level metrics."""
293296
try:
294297
# Memory usage
295-
process = psutil.Process()
296-
memory_info = process.memory_info()
298+
memory_info = self._process.memory_info()
297299
SYSTEM_MEMORY_USAGE.set(memory_info.rss)
298300

299-
# CPU usage
300-
cpu_percent = process.cpu_percent()
301+
# CPU usage (uses delta since last call on the persistent _process instance)
302+
cpu_percent = self._process.cpu_percent()
301303
SYSTEM_CPU_USAGE.set(cpu_percent)
302304

303305
except Exception as e:

src/wazuh_mcp_server/oauth.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,8 @@ async def authorize(
430430

431431
# Validate response_type
432432
if response_type != "code":
433-
return RedirectResponse(f"{redirect_uri}?error=unsupported_response_type&state={state or ''}")
433+
params = urlencode({"error": "unsupported_response_type", "state": state or ""})
434+
return RedirectResponse(f"{redirect_uri}?{params}")
434435

435436
# For MCP servers, we auto-approve (the user already chose to connect)
436437
# In production, you might show a consent screen here

src/wazuh_mcp_server/resilience.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, config: CircuitBreakerConfig):
5454
self.last_failure_time: Optional[float] = None
5555
self.next_retry_time: Optional[float] = None
5656
self._lock = asyncio.Lock()
57+
self._half_open_trial_in_progress = False # Only allow one trial request in HALF_OPEN
5758

5859
def __call__(self, func: Callable) -> Callable:
5960
"""Decorator to apply circuit breaker to function."""
@@ -71,7 +72,13 @@ async def _call(self, func: Callable, *args, **kwargs) -> Any:
7172
async with self._lock:
7273
if self.state == CircuitBreakerState.OPEN:
7374
if self._should_attempt_reset():
75+
if self._half_open_trial_in_progress:
76+
# Another coroutine is already running the trial request — reject this one
77+
raise HTTPException(
78+
status_code=503, detail="Service temporarily unavailable - circuit breaker half-open trial in progress"
79+
)
7480
self.state = CircuitBreakerState.HALF_OPEN
81+
self._half_open_trial_in_progress = True
7582
logger.info(f"Circuit breaker {func.__name__} moved to HALF_OPEN")
7683
else:
7784
if self.config.fallback_function:
@@ -81,6 +88,11 @@ async def _call(self, func: Callable, *args, **kwargs) -> Any:
8188
raise HTTPException(
8289
status_code=503, detail="Service temporarily unavailable - circuit breaker open"
8390
)
91+
elif self.state == CircuitBreakerState.HALF_OPEN and self._half_open_trial_in_progress:
92+
# HALF_OPEN with a trial already running — reject concurrent requests
93+
raise HTTPException(
94+
status_code=503, detail="Service temporarily unavailable - circuit breaker half-open trial in progress"
95+
)
8496

8597
try:
8698
result = await func(*args, **kwargs)
@@ -104,6 +116,7 @@ def _should_attempt_reset(self) -> bool:
104116
async def _on_success(self, func_name: str):
105117
"""Handle successful execution."""
106118
async with self._lock:
119+
self._half_open_trial_in_progress = False
107120
if self.state == CircuitBreakerState.HALF_OPEN:
108121
self.state = CircuitBreakerState.CLOSED
109122
logger.info(f"Circuit breaker {func_name} reset to CLOSED")
@@ -113,6 +126,7 @@ async def _on_success(self, func_name: str):
113126
async def _on_failure(self, func_name: str, exception: Exception):
114127
"""Handle failed execution."""
115128
async with self._lock:
129+
self._half_open_trial_in_progress = False
116130
self.failure_count += 1
117131
self.last_failure_time = time.time()
118132
if self.failure_count >= self.config.failure_threshold:

src/wazuh_mcp_server/security.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, param_name: str, message: str, suggestion: str = None):
4646
VALID_COMPLIANCE_FRAMEWORKS = {"PCI-DSS", "HIPAA", "SOX", "GDPR", "NIST"}
4747

4848
# Regex patterns for parameter validation
49-
AGENT_ID_PATTERN = re.compile(r"^[0-9]{3,5}$") # Wazuh agent IDs are numeric
49+
AGENT_ID_PATTERN = re.compile(r"^[0-9]{1,5}$") # Wazuh agent IDs: 0 (manager) through 99999
5050
RULE_ID_PATTERN = re.compile(r"^[0-9]{1,6}$") # Rule IDs are numeric
5151
ISO_TIMESTAMP_PATTERN = re.compile(r"^\d{4}-\d{2}-\d{2}(T\d{2}:\d{2}:\d{2}(\.\d+)?(Z|[+-]\d{2}:?\d{2})?)?$")
5252
# Use ipaddress module for proper validation instead of regex
@@ -685,10 +685,15 @@ class SecurityManager:
685685

686686
def __init__(self):
687687
self.metrics = SecurityMetrics()
688-
self.rate_limiter = RateLimiter(
689-
max_requests=int(os.getenv("RATE_LIMIT_REQUESTS", "100")),
690-
window_seconds=int(os.getenv("RATE_LIMIT_WINDOW", "60")),
691-
)
688+
try:
689+
max_req = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
690+
except (ValueError, TypeError):
691+
max_req = 100
692+
try:
693+
window = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
694+
except (ValueError, TypeError):
695+
window = 60
696+
self.rate_limiter = RateLimiter(max_requests=max_req, window_seconds=window)
692697
self.validator = SecurityValidator()
693698
self.trusted_proxies = {p.strip() for p in os.getenv("TRUSTED_PROXIES", "").split(",") if p.strip()}
694699

src/wazuh_mcp_server/server.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,9 +1861,9 @@ async def handle_tools_list(params: Dict[str, Any], session: MCPSession) -> Dict
18611861
},
18621862
]
18631863

1864-
# Filter tools by session scopes: hide write tools from read-only tokens
1864+
# Filter tools by session scopes: hide write tools from read-only or unknown tokens
18651865
auth_token = getattr(session, "_auth_token", None)
1866-
if auth_token and not auth_token.has_scope("wazuh:write"):
1866+
if not auth_token or not auth_token.has_scope("wazuh:write"):
18671867
tools = [t for t in tools if t["name"] not in WRITE_SCOPE_TOOLS]
18681868

18691869
# Pagination support per MCP spec
@@ -1881,9 +1881,15 @@ async def handle_tools_call(params: Dict[str, Any], session: MCPSession) -> Dict
18811881
# Validate tool name
18821882
validate_input(tool_name, max_length=100)
18831883

1884-
# Scope enforcement: check if the token has the required scope for this tool
1884+
# Scope enforcement: check if the token has the required scope for this tool.
1885+
# If auth_token is missing (should not happen in normal flow), deny write tools by default.
18851886
auth_token = getattr(session, "_auth_token", None)
18861887
required_scope = _get_tool_scope(tool_name)
1888+
if required_scope == "wazuh:write" and not auth_token:
1889+
raise ValueError(
1890+
f"Insufficient permissions: tool '{tool_name}' requires '{required_scope}' scope. "
1891+
f"Authentication token not found on session."
1892+
)
18871893
if auth_token and not auth_token.has_scope(required_scope):
18881894
raise ValueError(
18891895
f"Insufficient permissions: tool '{tool_name}' requires '{required_scope}' scope. "
@@ -2060,6 +2066,7 @@ def _tool_error(text: str) -> dict:
20602066
result = await wazuh_client.get_vulnerabilities(agent_id=agent_id, severity=severity, limit=limit)
20612067
if compact:
20622068
result = _compact_vulns_result(result)
2069+
result = _add_truncation_warning(result, limit)
20632070
_success = True
20642071
return _tool_result(f"Vulnerabilities:\n{json.dumps(result, indent=2 if not compact else None)}")
20652072

@@ -2070,6 +2077,7 @@ def _tool_error(text: str) -> dict:
20702077
result = await wazuh_client.get_critical_vulnerabilities(limit)
20712078
if compact:
20722079
result = _compact_vulns_result(result)
2080+
result = _add_truncation_warning(result, limit)
20732081
_success = True
20742082
return _tool_result(f"Critical Vulnerabilities:\n{json.dumps(result, indent=2 if not compact else None)}")
20752083

@@ -2564,7 +2572,7 @@ async def mcp_endpoint(
25642572
status_code=404, detail="Session not found. Please start a new session with InitializeRequest."
25652573
)
25662574
if existing_session.is_expired():
2567-
await sessions.delete(mcp_session_id)
2575+
await sessions.remove(mcp_session_id)
25682576
_initialized_sessions.pop(mcp_session_id, None)
25692577
raise HTTPException(
25702578
status_code=404, detail="Session expired. Please start a new session with InitializeRequest."
@@ -2776,25 +2784,25 @@ async def mcp_sse_endpoint(
27762784
headers = {"Retry-After": str(retry_after)} if retry_after else {}
27772785
raise HTTPException(status_code=429, detail="Rate limit exceeded", headers=headers)
27782786

2779-
# Track active connections
2787+
# Session validation: if client provides session ID but session doesn't exist, return 404
2788+
# Done BEFORE incrementing ACTIVE_CONNECTIONS to avoid counter leak on early errors.
2789+
if mcp_session_id:
2790+
existing_session = await sessions.get(mcp_session_id)
2791+
if not existing_session:
2792+
raise HTTPException(status_code=404, detail="Session not found")
2793+
session = existing_session
2794+
session.update_activity()
2795+
await sessions.set(mcp_session_id, session)
2796+
else:
2797+
session = await get_or_create_session(None, origin)
2798+
session.authenticated = True # Mark as authenticated via bearer token
2799+
session._auth_token = auth_token # Store token for scope checks in tool handlers
2800+
2801+
# Track active connections — only after validation passes.
2802+
# The SSE generator will decrement when the stream closes (track_connection=True).
27802803
ACTIVE_CONNECTIONS.inc()
27812804

27822805
try:
2783-
# Session validation: if client provides session ID but session doesn't exist, return 404
2784-
if mcp_session_id:
2785-
existing_session = await sessions.get(mcp_session_id)
2786-
if not existing_session:
2787-
raise HTTPException(status_code=404, detail="Session not found")
2788-
session = existing_session
2789-
session.update_activity()
2790-
await sessions.set(mcp_session_id, session)
2791-
else:
2792-
session = await get_or_create_session(None, origin)
2793-
session.authenticated = True # Mark as authenticated via bearer token
2794-
session._auth_token = auth_token # Store token for scope checks in tool handlers
2795-
2796-
# Return SSE stream (track_connection=True so ACTIVE_CONNECTIONS is
2797-
# decremented when the stream actually closes, not when this function returns)
27982806
response = StreamingResponse(
27992807
generate_sse_events(session, track_connection=True),
28002808
media_type="text/event-stream",
@@ -2868,7 +2876,7 @@ async def mcp_streamable_http_endpoint(
28682876
status_code=404, detail="Session not found. Please start a new session with InitializeRequest."
28692877
)
28702878
if existing_session.is_expired():
2871-
await sessions.delete(mcp_session_id)
2879+
await sessions.remove(mcp_session_id)
28722880
_initialized_sessions.pop(mcp_session_id, None)
28732881
raise HTTPException(
28742882
status_code=404, detail="Session expired. Please start a new session with InitializeRequest."

0 commit comments

Comments
 (0)