-
Notifications
You must be signed in to change notification settings - Fork 630
Expand file tree
/
Copy pathobservability_middleware.py
More file actions
272 lines (232 loc) · 11.5 KB
/
observability_middleware.py
File metadata and controls
272 lines (232 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/middleware/observability_middleware.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti
Observability Middleware for automatic request/response tracing.
This middleware automatically captures HTTP requests and responses as observability traces,
providing comprehensive visibility into all gateway operations.
Examples:
>>> from mcpgateway.middleware.observability_middleware import ObservabilityMiddleware # doctest: +SKIP
>>> app.add_middleware(ObservabilityMiddleware) # doctest: +SKIP
"""
# Standard
import logging
import time
import traceback
from typing import Callable, Optional
# Third-Party
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
# First-Party
from mcpgateway.config import settings
from mcpgateway.db import SessionLocal
from mcpgateway.instrumentation.sqlalchemy import attach_trace_to_session
from mcpgateway.middleware.path_filter import should_skip_observability
from mcpgateway.plugins.framework.observability import current_trace_id as plugins_trace_id
from mcpgateway.services.observability_service import current_trace_id, ObservabilityService, parse_traceparent
logger = logging.getLogger(__name__)
class ObservabilityMiddleware(BaseHTTPMiddleware):
    """Middleware for automatic HTTP request/response tracing.

    Captures every HTTP request as a trace with timing, status codes,
    and user context. Automatically creates spans for the request lifecycle.

    This middleware is disabled by default and can be enabled via the
    MCPGATEWAY_OBSERVABILITY_ENABLED environment variable.
    """

    def __init__(self, app, enabled: Optional[bool] = None, service: Optional[ObservabilityService] = None):
        """Initialize the observability middleware.

        Args:
            app: ASGI application
            enabled: Whether observability is enabled. When ``None`` the value of
                ``settings.observability_enabled`` is used (``False`` if unset).
            service: Optional ObservabilityService instance; a new one is
                created when omitted.
        """
        super().__init__(app)
        self.enabled = enabled if enabled is not None else getattr(settings, "observability_enabled", False)
        self.service = service or ObservabilityService()
        logger.info(f"Observability middleware initialized (enabled={self.enabled})")

    @staticmethod
    def _extract_trace_context(request: Request) -> tuple:
        """Read the W3C ``traceparent`` header, if present and parseable.

        Args:
            request: Incoming HTTP request

        Returns:
            ``(trace_id, parent_span_id)``; both are ``None`` when the header is
            missing or ``parse_traceparent`` returns a falsy result.
        """
        traceparent_header = request.headers.get("traceparent")
        if not traceparent_header:
            return None, None
        parsed = parse_traceparent(traceparent_header)
        if not parsed:
            return None, None
        external_trace_id, external_parent_span_id, _flags = parsed
        logger.debug(f"Extracted W3C trace context: trace_id={external_trace_id}, parent_span_id={external_parent_span_id}")
        return external_trace_id, external_parent_span_id

    @staticmethod
    def _discard_session(db, request: Request) -> None:
        """Best-effort teardown of a session whose trace setup failed.

        Every step is individually guarded so that a broken connection cannot
        prevent the request from being served without tracing.

        Args:
            db: Session created by this middleware
            request: Request whose ``state.db`` reference must be removed
        """
        try:
            db.rollback()  # Error path - rollback any partial transaction
        except Exception as rollback_error:
            logger.debug(f"Failed to rollback during cleanup: {rollback_error}")
            # Connection is broken - invalidate to remove from pool
            try:
                db.invalidate()
            except Exception:
                pass  # nosec B110
        try:
            db.close()
        except Exception as close_error:
            logger.debug(f"Failed to close database session during cleanup: {close_error}")
        # Clean up request.state.db to prevent get_db() from reusing a closed session
        if hasattr(request.state, "db"):
            delattr(request.state, "db")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Process request and create observability trace.

        When tracing is disabled, the path is filtered out, or trace setup
        fails, the request is passed straight to ``call_next`` untraced.
        Otherwise a trace and an ``http.request`` span are opened before the
        handler runs and closed afterwards, and the request-scoped database
        session stored in ``request.state.db`` is committed on success or
        rolled back on error, then always closed.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response

        Raises:
            Exception: Re-raises any exception from request processing after logging
        """
        # Skip if observability is disabled
        if not self.enabled:
            return await call_next(request)

        # Skip health checks and static files to reduce noise
        if should_skip_observability(request.url.path):
            return await call_next(request)

        # Extract request context
        http_method = request.method
        http_url = str(request.url)
        ip_address = request.client.host if request.client else None
        user_agent = request.headers.get("user-agent")

        # Try to extract user from request state (set by auth middleware)
        user_email = None
        if hasattr(request.state, "user") and hasattr(request.state.user, "email"):
            user_email = request.state.user.email

        # Extract W3C Trace Context from headers (for distributed tracing)
        external_trace_id, external_parent_span_id = self._extract_trace_context(request)

        db = None
        trace_id = None
        span_id = None
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted mid-request.
        start_time = time.perf_counter()
        session_owned_by_middleware = False
        tracing_ready = False

        try:
            # Create request-scoped database session and store in request.state
            # This session will be reused by route handlers via get_db() dependency,
            # eliminating duplicate session creation (Issue #3467)
            db = SessionLocal()
            logger.debug(f"[OBSERVABILITY] DB session created: {id(db)}")
            request.state.db = db
            session_owned_by_middleware = True

            # Start trace (use external trace_id if provided for distributed tracing)
            trace_id = self.service.start_trace(
                db=db,
                name=f"{http_method} {request.url.path}",
                trace_id=external_trace_id,  # Use external trace ID if provided
                parent_span_id=external_parent_span_id,  # Track parent span from upstream
                http_method=http_method,
                http_url=http_url,
                user_email=user_email,
                user_agent=user_agent,
                ip_address=ip_address,
                attributes={
                    "http.route": request.url.path,
                    "http.query": str(request.url.query) if request.url.query else None,
                },
                resource_attributes={
                    "service.name": "mcp-gateway",
                    "service.version": getattr(settings, "version", "unknown"),
                },
            )

            # Store trace_id in request state for use in route handlers
            request.state.trace_id = trace_id

            # Set trace_id in context variable for access throughout async call stack
            current_trace_id.set(trace_id)
            # Bridge: also set the framework's ContextVar so the plugin executor sees it
            plugins_trace_id.set(trace_id)

            # Attach trace_id to database session for SQL query instrumentation
            attach_trace_to_session(db, trace_id)

            # Start request span
            span_id = self.service.start_span(db=db, trace_id=trace_id, name="http.request", kind="server", attributes={"http.method": http_method, "http.url": http_url})
            tracing_ready = True
        except Exception as e:
            # If trace setup failed, log and continue without tracing
            logger.warning(f"Failed to setup observability trace: {e}")
            # Close db if it was created
            if db is not None:
                self._discard_session(db, request)

        if not tracing_ready:
            # Continue without tracing. This runs outside the except block on
            # purpose: an error raised by the application must not be chained
            # to the unrelated setup failure handled above.
            return await call_next(request)

        # Process request (trace is set up at this point)
        # Route handlers will reuse request.state.db via get_db() dependency
        try:
            response = await call_next(request)
            status_code = response.status_code
            outcome = "ok" if status_code < 400 else "error"

            # End span successfully
            if span_id:
                try:
                    self.service.end_span(
                        db,
                        span_id,
                        status=outcome,
                        attributes={"http.status_code": status_code, "http.response_size": response.headers.get("content-length")},
                    )
                except Exception as end_span_error:
                    logger.warning(f"Failed to end span {span_id}: {end_span_error}")

            # End trace
            if trace_id:
                duration_ms = (time.perf_counter() - start_time) * 1000
                try:
                    self.service.end_trace(
                        db,
                        trace_id,
                        status=outcome,
                        http_status_code=status_code,
                        attributes={"response_time_ms": duration_ms},
                    )
                except Exception as end_trace_error:
                    logger.warning(f"Failed to end trace {trace_id}: {end_trace_error}")

            # Commit the shared session (used by both observability and route handler)
            # Note: Some route handlers may have already committed. The is_active check
            # ensures we only commit if the transaction is still open. Services that
            # explicitly commit will have already closed their transaction.
            # Only commit if the transaction is still active AND has uncommitted changes
            if db.is_active and db.in_transaction():
                db.commit()

            return response
        except Exception as e:
            error_type = type(e).__name__
            error_message = str(e)

            # Log exception in span
            if span_id:
                try:
                    self.service.end_span(db, span_id, status="error", status_message=error_message, attributes={"exception.type": error_type, "exception.message": error_message})
                    # Add exception event
                    self.service.add_event(
                        db,
                        span_id,
                        name="exception",
                        severity="error",
                        message=error_message,
                        exception_type=error_type,
                        exception_message=error_message,
                        exception_stacktrace=traceback.format_exc(),
                    )
                except Exception as log_error:
                    logger.warning(f"Failed to log exception in span: {log_error}")

            # End trace with error
            if trace_id:
                try:
                    self.service.end_trace(db, trace_id, status="error", status_message=error_message, http_status_code=500)
                except Exception as trace_error:
                    logger.warning(f"Failed to end trace: {trace_error}")

            # Rollback the shared session on error
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.warning(f"Failed to rollback database session: {rollback_error}")
                # Connection is broken - invalidate to remove from pool
                # This handles cases like PgBouncer query_wait_timeout where
                # the connection is dead and rollback itself fails
                try:
                    db.invalidate()
                except Exception:
                    pass  # nosec B110

            # Re-raise the original exception
            raise
        finally:
            # Always close database session and clean up request state
            if db is not None and session_owned_by_middleware:
                try:
                    db.close()
                except Exception as close_error:
                    logger.warning(f"Failed to close database session: {close_error}")
                # Clean up request.state.db to prevent stale references
                if hasattr(request.state, "db"):
                    delattr(request.state, "db")