Skip to content

Commit 3e1e7be

Browse files
committed
Production hardening: fix race conditions, retry logic, metrics, and CI pipelines
- Add asyncio.Lock to circuit breaker state transitions and indexer initialization - Narrow retry scope to transient errors only (5xx, connection/timeout) - Fix circuit breaker monitoring always returning "unknown" state - Bound Prometheus endpoint labels to prevent cardinality explosion - Add JSONDecodeError handling on all response.json() call sites - Replace non-deterministic hash() cache keys with sorted() representation - Add SSE keepalive loop cancellation on client disconnect - Add MAX_BATCH_SIZE=100 limit on batch JSON-RPC requests - Throttle session cleanup to every 60s instead of every request - Remove premature REQUEST_COUNT increment (was always counting 200) - Remove || true from release and security CI workflows - Add fastmcp to pyproject.toml dependencies
1 parent 924fa92 commit 3e1e7be

File tree

8 files changed

+150
-72
lines changed

8 files changed

+150
-72
lines changed

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- name: Run tests
4545
run: |
4646
pip install pytest pytest-asyncio
47-
pytest tests/ -v || true
47+
pytest tests/ -v
4848
4949
- name: Build package
5050
run: python -m build

.github/workflows/security.yml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,20 @@ jobs:
3535
run: pip install -e .
3636

3737
- name: Run safety check
38+
continue-on-error: true
3839
run: |
39-
safety check --json --output safety-report.json || true
40+
safety check --json --output safety-report.json
4041
if [ -f safety-report.json ]; then
4142
echo "### Safety Report" >> $GITHUB_STEP_SUMMARY
4243
echo '```json' >> $GITHUB_STEP_SUMMARY
4344
cat safety-report.json >> $GITHUB_STEP_SUMMARY
4445
echo '```' >> $GITHUB_STEP_SUMMARY
4546
fi
46-
47+
4748
- name: Run pip-audit
49+
continue-on-error: true
4850
run: |
49-
pip-audit --format json --output pip-audit-report.json || true
51+
pip-audit --format json --output pip-audit-report.json
5052
if [ -f pip-audit-report.json ]; then
5153
echo "### Pip Audit Report" >> $GITHUB_STEP_SUMMARY
5254
echo '```json' >> $GITHUB_STEP_SUMMARY
@@ -76,12 +78,13 @@ jobs:
7678
python-version: '3.13'
7779

7880
- name: Run Bandit
81+
continue-on-error: true
7982
run: |
8083
pip install bandit
81-
bandit -r src/ -f json -o bandit-report.json || true
84+
bandit -r src/ -f json -o bandit-report.json
8285
echo "### Bandit Security Report" >> $GITHUB_STEP_SUMMARY
8386
echo '```' >> $GITHUB_STEP_SUMMARY
84-
bandit -r src/ -f txt || true
87+
bandit -r src/ -f txt
8588
echo '```' >> $GITHUB_STEP_SUMMARY
8689
8790
- name: Run Semgrep
@@ -158,7 +161,7 @@ jobs:
158161
# Install gitleaks CLI (free version)
159162
wget -q https://github.com/gitleaks/gitleaks/releases/download/v8.18.4/gitleaks_8.18.4_linux_x64.tar.gz
160163
tar -xzf gitleaks_8.18.4_linux_x64.tar.gz
161-
./gitleaks detect --source . --verbose --report-path gitleaks-report.json || true
164+
./gitleaks detect --source . --verbose --report-path gitleaks-report.json
162165
echo "### Gitleaks Report" >> $GITHUB_STEP_SUMMARY
163166
if [ -f gitleaks-report.json ]; then
164167
echo '```json' >> $GITHUB_STEP_SUMMARY

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ dependencies = [
4646
# Utilities
4747
"python-dotenv>=1.0.0",
4848
"aiofiles>=24.0.0",
49+
# MCP Framework
50+
"fastmcp>=2.14.0",
4951
]
5052

5153
[project.optional-dependencies]

src/wazuh_mcp_server/api/wazuh_client.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ async def _authenticate(self):
8484
response = await self.client.post(auth_url, auth=(self.config.wazuh_user, self.config.wazuh_pass))
8585
response.raise_for_status()
8686

87-
data = response.json()
87+
try:
88+
data = response.json()
89+
except (json.JSONDecodeError, ValueError):
90+
raise ValueError("Invalid JSON in authentication response from Wazuh API")
8891
if "data" not in data or "token" not in data["data"]:
8992
raise ValueError("Invalid authentication response from Wazuh API")
9093

@@ -224,7 +227,7 @@ async def _get_cached(self, cache_key: str, endpoint: str, **kwargs) -> Dict[str
224227
async def get_rules(self, **params) -> Dict[str, Any]:
225228
"""Get Wazuh detection rules (cached for 5 minutes)."""
226229
# Use caching for rules as they rarely change
227-
cache_key = f"rules:{hash(frozenset(params.items()) if params else 'all')}"
230+
cache_key = f"rules:{sorted(params.items()) if params else 'all'}"
228231
return await self._get_cached(cache_key, "/rules", params=params)
229232

230233
async def get_rule_info(self, rule_id: str) -> Dict[str, Any]:
@@ -234,7 +237,7 @@ async def get_rule_info(self, rule_id: str) -> Dict[str, Any]:
234237
async def get_decoders(self, **params) -> Dict[str, Any]:
235238
"""Get Wazuh log decoders (cached for 5 minutes)."""
236239
# Use caching for decoders as they rarely change
237-
cache_key = f"decoders:{hash(frozenset(params.items()) if params else 'all')}"
240+
cache_key = f"decoders:{sorted(params.items()) if params else 'all'}"
238241
return await self._get_cached(cache_key, "/decoders", params=params)
239242

240243
async def execute_active_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -370,7 +373,10 @@ async def _execute_request(self, method: str, endpoint: str, **kwargs) -> Dict[s
370373
response = await self.client.request(method, url, headers=headers, **kwargs)
371374
response.raise_for_status()
372375

373-
data = response.json()
376+
try:
377+
data = response.json()
378+
except (json.JSONDecodeError, ValueError):
379+
raise ValueError(f"Invalid JSON response from Wazuh API: {endpoint}")
374380

375381
# Validate response structure
376382
if "data" not in data:
@@ -391,7 +397,10 @@ async def _execute_request(self, method: str, endpoint: str, **kwargs) -> Dict[s
391397
headers = {"Authorization": f"Bearer {self.token}"}
392398
response = await self.client.request(method, url, headers=headers, **kwargs)
393399
response.raise_for_status()
394-
return response.json()
400+
try:
401+
return response.json()
402+
except (json.JSONDecodeError, ValueError):
403+
raise ValueError(f"Invalid JSON response from Wazuh API after re-auth: {endpoint}")
395404
else:
396405
logger.error(f"Wazuh API request failed: {e.response.status_code} - {e.response.text}")
397406
raise ValueError(f"Wazuh API request failed: {e.response.status_code} - {e.response.text}")

src/wazuh_mcp_server/api/wazuh_indexer.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
- wazuh-states-vulnerabilities-* — vulnerability data (removed from Manager API in 4.8.0)
88
"""
99

10+
import asyncio
11+
import json
1012
import logging
1113
from typing import Any, Dict, Optional
1214

@@ -45,6 +47,7 @@ def __init__(
4547
self.timeout = timeout
4648
self.client: Optional[httpx.AsyncClient] = None
4749
self._initialized = False
50+
self._init_lock = asyncio.Lock()
4851

4952
@staticmethod
5053
def _normalize_host(host: str) -> str:
@@ -79,9 +82,12 @@ async def close(self):
7982
self._initialized = False
8083

8184
async def _ensure_initialized(self):
82-
"""Ensure client is initialized."""
83-
if not self._initialized:
84-
await self.initialize()
85+
"""Ensure client is initialized (thread-safe)."""
86+
if self._initialized:
87+
return
88+
async with self._init_lock:
89+
if not self._initialized:
90+
await self.initialize()
8591

8692
@retry(
8793
stop=stop_after_attempt(3),
@@ -109,7 +115,10 @@ async def _search(self, index: str, query: Dict[str, Any], size: int = 100) -> D
109115
try:
110116
response = await self.client.post(url, json=body, headers={"Content-Type": "application/json"})
111117
response.raise_for_status()
112-
return response.json()
118+
try:
119+
return response.json()
120+
except (json.JSONDecodeError, ValueError):
121+
raise ValueError(f"Invalid JSON response from Wazuh Indexer: {index}")
113122

114123
except httpx.HTTPStatusError as e:
115124
logger.error(f"Indexer search failed: {e.response.status_code} - {e.response.text}")
@@ -185,7 +194,10 @@ async def get_alerts(
185194
try:
186195
response = await self.client.post(url, json=body, headers={"Content-Type": "application/json"})
187196
response.raise_for_status()
188-
result = response.json()
197+
try:
198+
result = response.json()
199+
except (json.JSONDecodeError, ValueError):
200+
raise ValueError("Invalid JSON response from Wazuh Indexer alerts query")
189201
except httpx.HTTPStatusError as e:
190202
logger.error(f"Alerts query failed: {e.response.status_code} - {e.response.text}")
191203
raise ValueError(f"Alerts query failed: {e.response.status_code}")

src/wazuh_mcp_server/monitoring.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ def debug(self, message: str, **extra: Any) -> None:
7474
# Prometheus metrics registry
7575
REGISTRY = CollectorRegistry()
7676

77+
# Known endpoints for metric label normalization (prevents unbounded cardinality)
78+
_KNOWN_ENDPOINTS = {"/", "/mcp", "/health", "/metrics"}
79+
80+
81+
def _normalize_endpoint(path: str) -> str:
82+
"""Collapse unknown paths to 'other' to prevent label explosion."""
83+
if path in _KNOWN_ENDPOINTS:
84+
return path
85+
return "other"
86+
87+
7788
# Core metrics
7889
REQUEST_COUNT = Counter(
7990
"wazuh_mcp_requests_total", "Total number of requests", ["method", "endpoint", "status_code"], registry=REGISTRY
@@ -495,7 +506,7 @@ async def check_circuit_breaker() -> Dict[str, Any]:
495506
client = await get_wazuh_client()
496507
if hasattr(client, "_circuit_breaker"):
497508
cb = client._circuit_breaker
498-
state = cb._state if hasattr(cb, "_state") else "unknown"
509+
state = cb.state.value if hasattr(cb, "state") else "unknown"
499510
failure_count = cb.failure_count if hasattr(cb, "failure_count") else 0
500511

501512
# Update Prometheus metric
@@ -608,8 +619,9 @@ async def monitoring_middleware(request: Request, call_next):
608619
# Record metrics
609620
duration = time.time() - start_time
610621

611-
REQUEST_COUNT.labels(method=method, endpoint=path, status_code=status_code).inc()
612-
REQUEST_DURATION.labels(method=method, endpoint=path).observe(duration)
622+
normalized = _normalize_endpoint(path)
623+
REQUEST_COUNT.labels(method=method, endpoint=normalized, status_code=status_code).inc()
624+
REQUEST_DURATION.labels(method=method, endpoint=normalized).observe(duration)
613625

614626
# Record slow requests with correlation ID
615627
if duration > performance_profiler.slow_threshold:

src/wazuh_mcp_server/resilience.py

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import httpx
1616
from fastapi import HTTPException
17-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
17+
from tenacity import retry, retry_if_exception, retry_if_exception_type, stop_after_attempt, wait_exponential
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -53,6 +53,7 @@ def __init__(self, config: CircuitBreakerConfig):
5353
self.failure_count = 0
5454
self.last_failure_time: Optional[float] = None
5555
self.next_retry_time: Optional[float] = None
56+
self._lock = asyncio.Lock()
5657

5758
def __call__(self, func: Callable) -> Callable:
5859
"""Decorator to apply circuit breaker to function."""
@@ -66,19 +67,20 @@ async def wrapper(*args, **kwargs):
6667
async def _call(self, func: Callable, *args, **kwargs) -> Any:
6768
"""Execute function with circuit breaker logic."""
6869

69-
# Check if circuit is open
70-
if self.state == CircuitBreakerState.OPEN:
71-
if self._should_attempt_reset():
72-
self.state = CircuitBreakerState.HALF_OPEN
73-
logger.info(f"Circuit breaker {func.__name__} moved to HALF_OPEN")
74-
else:
75-
if self.config.fallback_function:
76-
logger.warning(f"Circuit breaker {func.__name__} OPEN, using fallback")
77-
return await self.config.fallback_function(*args, **kwargs)
70+
# Check if circuit is open (under lock to prevent race conditions)
71+
async with self._lock:
72+
if self.state == CircuitBreakerState.OPEN:
73+
if self._should_attempt_reset():
74+
self.state = CircuitBreakerState.HALF_OPEN
75+
logger.info(f"Circuit breaker {func.__name__} moved to HALF_OPEN")
7876
else:
79-
raise HTTPException(
80-
status_code=503, detail="Service temporarily unavailable - circuit breaker open"
81-
)
77+
if self.config.fallback_function:
78+
logger.warning(f"Circuit breaker {func.__name__} OPEN, using fallback")
79+
return await self.config.fallback_function(*args, **kwargs)
80+
else:
81+
raise HTTPException(
82+
status_code=503, detail="Service temporarily unavailable - circuit breaker open"
83+
)
8284

8385
try:
8486
result = await func(*args, **kwargs)
@@ -101,23 +103,24 @@ def _should_attempt_reset(self) -> bool:
101103

102104
async def _on_success(self, func_name: str):
103105
"""Handle successful execution."""
104-
if self.state == CircuitBreakerState.HALF_OPEN:
105-
self.state = CircuitBreakerState.CLOSED
106-
logger.info(f"Circuit breaker {func_name} reset to CLOSED")
107-
108-
self.failure_count = 0
109-
self.last_failure_time = None
106+
async with self._lock:
107+
if self.state == CircuitBreakerState.HALF_OPEN:
108+
self.state = CircuitBreakerState.CLOSED
109+
logger.info(f"Circuit breaker {func_name} reset to CLOSED")
110+
self.failure_count = 0
111+
self.last_failure_time = None
110112

111113
async def _on_failure(self, func_name: str, exception: Exception):
112114
"""Handle failed execution."""
113-
self.failure_count += 1
114-
self.last_failure_time = time.time()
115-
116-
if self.failure_count >= self.config.failure_threshold:
117-
self.state = CircuitBreakerState.OPEN
118-
logger.warning(
119-
f"Circuit breaker {func_name} opened after {self.failure_count} failures. " f"Last error: {exception}"
120-
)
115+
async with self._lock:
116+
self.failure_count += 1
117+
self.last_failure_time = time.time()
118+
if self.failure_count >= self.config.failure_threshold:
119+
self.state = CircuitBreakerState.OPEN
120+
logger.warning(
121+
f"Circuit breaker {func_name} opened after {self.failure_count} failures. "
122+
f"Last error: {exception}"
123+
)
121124

122125

123126
class TimeoutManager:
@@ -155,13 +158,22 @@ async def wrapper(*args, **kwargs):
155158
return decorator
156159

157160

161+
def _is_retryable(exception):
162+
"""Only retry on transient errors (5xx, connection, timeout)."""
163+
if isinstance(exception, httpx.RequestError):
164+
return True # Connection/timeout errors
165+
if isinstance(exception, httpx.HTTPStatusError):
166+
return exception.response.status_code >= 500
167+
return False
168+
169+
158170
class RetryConfig:
159171
"""Retry configuration."""
160172

161173
WAZUH_API_RETRY = retry(
162174
stop=stop_after_attempt(3),
163175
wait=wait_exponential(multiplier=1, min=1, max=10),
164-
retry=retry_if_exception_type((httpx.RequestError, httpx.HTTPStatusError)),
176+
retry=retry_if_exception(_is_retryable),
165177
reraise=True,
166178
)
167179

0 commit comments

Comments
 (0)