From 38c3eb6d4a68b9cc9e0a63cae114f11166d67979 Mon Sep 17 00:00:00 2001 From: ahmedelgyar1 Date: Fri, 27 Mar 2026 09:54:18 +0200 Subject: [PATCH] feat: add per-request DB query tracing with ContextVar - Introduce query tracing using ContextVar for per-request isolation - Track query count, total rows, and execution time - Add SQL truncation (200 chars) to avoid excessive logs - Support DEBUG-level logging for individual queries - Ensure feature is opt-in and no-op when disabled - Add unit tests for enabled/disabled behavior and edge cases --- services/data/postgres_async_db.py | 7 ++ services/data/query_tracing.py | 68 ++++++++++++ services/metadata_service/server.py | 5 + .../tests/unit_tests/query_trace_test.py | 102 ++++++++++++++++++ 4 files changed, 182 insertions(+) create mode 100644 services/data/query_tracing.py create mode 100644 services/metadata_service/tests/unit_tests/query_trace_test.py diff --git a/services/data/postgres_async_db.py b/services/data/postgres_async_db.py index dbde13e60..9230f7ebb 100644 --- a/services/data/postgres_async_db.py +++ b/services/data/postgres_async_db.py @@ -12,6 +12,7 @@ from .db_utils import DBResponse, DBPagination, aiopg_exception_handling, \ get_db_ts_epoch_str, translate_run_key, translate_task_key, new_heartbeat_ts +from .query_tracing import record_query, QUERY_TRACING_ENABLED from .models import FlowRow, RunRow, StepRow, TaskRow, MetadataRow, ArtifactRow from services.utils import DBConfiguration, USE_SEPARATE_READER_POOL @@ -249,6 +250,8 @@ async def execute_sql(self, select_sql: str, values=[], fetch_single=False, expanded=False, limit: int = 0, offset: int = 0, cur: aiopg.Cursor = None, serialize: bool = True) -> Tuple[DBResponse, DBPagination]: async def _execute_on_cursor(_cur): + _trace_start = time.monotonic() if QUERY_TRACING_ENABLED else None + await _cur.execute(select_sql, values) rows = [] @@ -264,6 +267,10 @@ async def _execute_on_cursor(_cur): count = len(rows) + if _trace_start is not None: + record_query(self.table_name, select_sql, count, + (time.monotonic() - _trace_start) * 1000) + # Will raise IndexError in case fetch_single=True and there's no results body = rows[0] if fetch_single else rows pagination = DBPagination( diff --git a/services/data/query_tracing.py b/services/data/query_tracing.py new file mode 100644 index 000000000..b01ae89ac --- /dev/null +++ b/services/data/query_tracing.py @@ -0,0 +1,68 @@ +import os +import time +import logging +from contextvars import ContextVar +from aiohttp import web + +logger = logging.getLogger("QueryTracing") + +QUERY_TRACING_ENABLED = os.environ.get("QUERY_TRACING_ENABLED", "0") == "1" + +_request_trace: ContextVar[dict] = ContextVar("request_trace", default=None) + + +def start_trace(): + if not QUERY_TRACING_ENABLED: + return + _request_trace.set({ + "start_time": time.monotonic(), + "queries": [], + }) + + +def record_query(table_name: str, sql: str, row_count: int, elapsed_ms: float): + trace = _request_trace.get(None) + if trace is None: + return + trace["queries"].append({ + "table": table_name, + "sql": sql[:200], + "rows": row_count, + "time_ms": round(elapsed_ms, 2), + }) + + +def finish_trace(method: str, path: str): + trace = _request_trace.get(None) + if trace is None: + return + + total_time = (time.monotonic() - trace["start_time"]) * 1000 + total_queries = len(trace["queries"]) + total_rows = sum(q["rows"] for q in trace["queries"]) + total_db_time = sum(q["time_ms"] for q in trace["queries"]) + + logger.info( + "[RequestTrace] %s %s | queries=%d total_rows=%d " + "db_time=%.2fms request_time=%.2fms", + method, path, total_queries, total_rows, + total_db_time, total_time + ) + + for i, q in enumerate(trace["queries"], 1): + logger.debug( + " [Query %d/%d] table=%s rows=%d time=%.2fms sql=%s", + i, total_queries, q["table"], q["rows"], q["time_ms"], q["sql"] + ) + + _request_trace.set(None) + + +@web.middleware +async def query_tracing_middleware(request, handler): + start_trace() + try: + response = await handler(request) + return response + finally: + finish_trace(request.method, request.path) diff --git a/services/metadata_service/server.py b/services/metadata_service/server.py index 9530184b5..e84b217d6 100644 --- a/services/metadata_service/server.py +++ b/services/metadata_service/server.py @@ -13,6 +13,7 @@ from .api.metadata import MetadataApi from services.data.postgres_async_db import AsyncPostgresDB +from services.data.query_tracing import query_tracing_middleware, QUERY_TRACING_ENABLED from services.utils import DBConfiguration PATH_PREFIX = os.environ.get("PATH_PREFIX", "") @@ -23,6 +24,10 @@ def app(loop=None, db_conf: DBConfiguration = None, middlewares=None, path_prefi loop = loop or asyncio.get_event_loop() _app = web.Application(loop=loop) + + if QUERY_TRACING_ENABLED: + _app.middlewares.append(query_tracing_middleware) + app = web.Application(loop=loop) if path_prefix else _app async_db = AsyncPostgresDB() loop.run_until_complete(async_db._init(db_conf)) diff --git a/services/metadata_service/tests/unit_tests/query_trace_test.py b/services/metadata_service/tests/unit_tests/query_trace_test.py new file mode 100644 index 000000000..e4dfc8e34 --- /dev/null +++ b/services/metadata_service/tests/unit_tests/query_trace_test.py @@ -0,0 +1,102 @@ +import logging +import pytest +from unittest.mock import patch + +from services.data.query_tracing import ( + start_trace, + record_query, + finish_trace, + _request_trace, +) + + +class TestQueryTracingEnabled: + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_single_query_trace(self, caplog): + start_trace() + record_query("runs_v3", "SELECT * FROM runs_v3 WHERE flow_id = %s", 50, 12.5) + + with caplog.at_level(logging.INFO, logger="QueryTracing"): + finish_trace("GET", "/flows/TestFlow/runs") + + assert "[RequestTrace]" in caplog.text + assert "GET /flows/TestFlow/runs" in caplog.text + assert "queries=1" in caplog.text + assert "total_rows=50" in caplog.text + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_multiple_queries_aggregated(self, caplog): + start_trace() + record_query("tasks_v3", "SELECT * FROM tasks_v3", 30, 5.2) + record_query("runs_v3", "SELECT * FROM runs_v3", 1, 3.1) + + with caplog.at_level(logging.INFO, logger="QueryTracing"): + finish_trace("GET", "/flows/TestFlow/runs/1/steps/train/tasks") + + assert "queries=2" in caplog.text + assert "total_rows=31" in caplog.text + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_sql_truncated_at_200_chars(self): + """Long SQL strings are truncated to 200 characters.""" + start_trace() + long_sql = "SELECT " + "x" * 300 + record_query("runs_v3", long_sql, 10, 5.0) + + trace = _request_trace.get() + assert len(trace["queries"][0]["sql"]) == 200 + + + _request_trace.set(None) + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_trace_cleaned_up_after_finish(self): + start_trace() + record_query("runs_v3", "SELECT *", 10, 1.0) + finish_trace("GET", "/test") + + assert _request_trace.get(None) is None + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_zero_queries_logged(self, caplog): + start_trace() + + with caplog.at_level(logging.INFO, logger="QueryTracing"): + finish_trace("GET", "/healthcheck") + + assert "queries=0" in caplog.text + assert "total_rows=0" in caplog.text + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True) + def test_debug_level_shows_individual_queries(self, caplog): + start_trace() + record_query("runs_v3", "SELECT * FROM runs_v3", 50, 12.5) + record_query("steps_v3", "SELECT * FROM steps_v3", 200, 8.3) + + with caplog.at_level(logging.DEBUG, logger="QueryTracing"): + finish_trace("GET", "/test") + + assert "[Query 1/2]" in caplog.text + assert "[Query 2/2]" in caplog.text + assert "table=runs_v3" in caplog.text + assert "table=steps_v3" in caplog.text + + +class TestQueryTracingDisabled: + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False) + def test_start_trace_noop_when_disabled(self): + start_trace() + assert _request_trace.get(None) is None + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False) + def test_record_query_noop_when_disabled(self): + record_query("runs_v3", "SELECT *", 50, 12.5) + + @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False) + def test_finish_trace_noop_when_disabled(self, caplog): + with caplog.at_level(logging.INFO, logger="QueryTracing"): + finish_trace("GET", "/test") + + assert "[RequestTrace]" not in caplog.text