Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions services/data/postgres_async_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .db_utils import DBResponse, DBPagination, aiopg_exception_handling, \
get_db_ts_epoch_str, translate_run_key, translate_task_key, new_heartbeat_ts
from .query_tracing import record_query, QUERY_TRACING_ENABLED
from .models import FlowRow, RunRow, StepRow, TaskRow, MetadataRow, ArtifactRow
from services.utils import DBConfiguration, USE_SEPARATE_READER_POOL

Expand Down Expand Up @@ -249,6 +250,8 @@ async def execute_sql(self, select_sql: str, values=[], fetch_single=False,
expanded=False, limit: int = 0, offset: int = 0,
cur: aiopg.Cursor = None, serialize: bool = True) -> Tuple[DBResponse, DBPagination]:
async def _execute_on_cursor(_cur):
_trace_start = time.monotonic() if QUERY_TRACING_ENABLED else None

await _cur.execute(select_sql, values)

rows = []
Expand All @@ -264,6 +267,10 @@ async def _execute_on_cursor(_cur):

count = len(rows)

if _trace_start is not None:
record_query(self.table_name, select_sql, count,
(time.monotonic() - _trace_start) * 1000)

# Will raise IndexError in case fetch_single=True and there's no results
body = rows[0] if fetch_single else rows
pagination = DBPagination(
Expand Down
68 changes: 68 additions & 0 deletions services/data/query_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
import time
import logging
from contextvars import ContextVar
from aiohttp import web

# Dedicated logger so trace output can be filtered or levelled on its own.
logger = logging.getLogger("QueryTracing")

# Tracing is opt-in: only the exact value "1" switches it on.
QUERY_TRACING_ENABLED = os.getenv("QUERY_TRACING_ENABLED", "0") == "1"

# Holds the trace dict of the request being handled in the current context;
# the default of None means "no trace is active".
_request_trace: ContextVar[dict] = ContextVar("request_trace", default=None)


def start_trace():
    """Begin a fresh trace for the current context.

    Stores a new trace dict (monotonic start time plus an empty query list)
    in the context variable. Does nothing when QUERY_TRACING_ENABLED is false.
    """
    if QUERY_TRACING_ENABLED:
        fresh_trace = {
            "start_time": time.monotonic(),
            "queries": [],
        }
        _request_trace.set(fresh_trace)


def record_query(table_name: str, sql: str, row_count: int, elapsed_ms: float):
    """Append one executed query to the active trace, if there is one.

    Args:
        table_name: Name of the table the query was issued against.
        sql: The SQL text; only its first 200 characters are kept.
        row_count: Number of rows the query produced.
        elapsed_ms: Query duration in milliseconds; stored rounded to 2 places.
    """
    active = _request_trace.get(None)
    if active is not None:
        recorded = active["queries"]
        entry = {
            "table": table_name,
            "sql": sql[:200],
            "rows": row_count,
            "time_ms": round(elapsed_ms, 2),
        }
        recorded.append(entry)


def finish_trace(method: str, path: str):
    """Log a summary of the active trace and clear it from the context.

    Emits one INFO line with the number of recorded queries, the total row
    count, the summed query time and the elapsed time since the trace was
    started, followed by one DEBUG line per recorded query. Does nothing when
    no trace is active.

    Args:
        method: HTTP method of the request; used only in the log line.
        path: Request path; used only in the log line.
    """
    trace = _request_trace.get(None)
    if trace is None:
        return

    try:
        queries = trace["queries"]
        total_time = (time.monotonic() - trace["start_time"]) * 1000
        total_queries = len(queries)
        total_rows = sum(q["rows"] for q in queries)
        total_db_time = sum(q["time_ms"] for q in queries)

        logger.info(
            "[RequestTrace] %s %s | queries=%d total_rows=%d "
            "db_time=%.2fms request_time=%.2fms",
            method, path, total_queries, total_rows,
            total_db_time, total_time
        )

        # Skip the per-query loop entirely when DEBUG output is disabled.
        if logger.isEnabledFor(logging.DEBUG):
            for i, q in enumerate(queries, 1):
                logger.debug(
                    "  [Query %d/%d] table=%s rows=%d time=%.2fms sql=%s",
                    i, total_queries, q["table"], q["rows"], q["time_ms"], q["sql"]
                )
    finally:
        # Always drop the trace, even if summarising or logging raised, so a
        # stale trace never lingers in this context.
        _request_trace.set(None)


@web.middleware
async def query_tracing_middleware(request, handler):
    """aiohttp middleware that wraps each request in a query trace.

    Starts a trace, awaits the downstream handler and returns its response.
    The trace is finished in a ``finally`` clause, so finish_trace also runs
    when the handler raises.
    """
    start_trace()
    try:
        return await handler(request)
    finally:
        finish_trace(request.method, request.path)
5 changes: 5 additions & 0 deletions services/metadata_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .api.metadata import MetadataApi
from services.data.postgres_async_db import AsyncPostgresDB
from services.data.query_tracing import query_tracing_middleware, QUERY_TRACING_ENABLED
from services.utils import DBConfiguration

PATH_PREFIX = os.environ.get("PATH_PREFIX", "")
Expand All @@ -23,6 +24,10 @@ def app(loop=None, db_conf: DBConfiguration = None, middlewares=None, path_prefi
loop = loop or asyncio.get_event_loop()

_app = web.Application(loop=loop)

if QUERY_TRACING_ENABLED:
_app.middlewares.append(query_tracing_middleware)

app = web.Application(loop=loop) if path_prefix else _app
async_db = AsyncPostgresDB()
loop.run_until_complete(async_db._init(db_conf))
Expand Down
102 changes: 102 additions & 0 deletions services/metadata_service/tests/unit_tests/query_trace_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import logging
import pytest
from unittest.mock import patch

from services.data.query_tracing import (
start_trace,
record_query,
finish_trace,
_request_trace,
)


class TestQueryTracingEnabled:
    """Tracing helpers with QUERY_TRACING_ENABLED patched to True."""

    def teardown_method(self):
        # Reset the context variable after every test, even a failing one,
        # so an unfinished trace can never leak into later tests.
        _request_trace.set(None)

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_single_query_trace(self, caplog):
        """A single recorded query shows up in the INFO summary line."""
        start_trace()
        record_query("runs_v3", "SELECT * FROM runs_v3 WHERE flow_id = %s", 50, 12.5)

        with caplog.at_level(logging.INFO, logger="QueryTracing"):
            finish_trace("GET", "/flows/TestFlow/runs")

        assert "[RequestTrace]" in caplog.text
        assert "GET /flows/TestFlow/runs" in caplog.text
        assert "queries=1" in caplog.text
        assert "total_rows=50" in caplog.text

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_multiple_queries_aggregated(self, caplog):
        """Query count and row totals are summed across recorded queries."""
        start_trace()
        record_query("tasks_v3", "SELECT * FROM tasks_v3", 30, 5.2)
        record_query("runs_v3", "SELECT * FROM runs_v3", 1, 3.1)

        with caplog.at_level(logging.INFO, logger="QueryTracing"):
            finish_trace("GET", "/flows/TestFlow/runs/1/steps/train/tasks")

        assert "queries=2" in caplog.text
        assert "total_rows=31" in caplog.text

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_sql_truncated_at_200_chars(self):
        """Long SQL strings are truncated to 200 characters."""
        start_trace()
        long_sql = "SELECT " + "x" * 300
        record_query("runs_v3", long_sql, 10, 5.0)

        trace = _request_trace.get()
        # Cleanup is handled by teardown_method, so a failing assertion here
        # no longer leaves the trace behind.
        assert len(trace["queries"][0]["sql"]) == 200

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_trace_cleaned_up_after_finish(self):
        """finish_trace clears the context variable."""
        start_trace()
        record_query("runs_v3", "SELECT *", 10, 1.0)
        finish_trace("GET", "/test")

        assert _request_trace.get(None) is None

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_zero_queries_logged(self, caplog):
        """A trace with no queries still produces a summary line."""
        start_trace()

        with caplog.at_level(logging.INFO, logger="QueryTracing"):
            finish_trace("GET", "/healthcheck")

        assert "queries=0" in caplog.text
        assert "total_rows=0" in caplog.text

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", True)
    def test_debug_level_shows_individual_queries(self, caplog):
        """At DEBUG level each recorded query gets its own log line."""
        start_trace()
        record_query("runs_v3", "SELECT * FROM runs_v3", 50, 12.5)
        record_query("steps_v3", "SELECT * FROM steps_v3", 200, 8.3)

        with caplog.at_level(logging.DEBUG, logger="QueryTracing"):
            finish_trace("GET", "/test")

        assert "[Query 1/2]" in caplog.text
        assert "[Query 2/2]" in caplog.text
        assert "table=runs_v3" in caplog.text
        assert "table=steps_v3" in caplog.text


class TestQueryTracingDisabled:
    """Tracing helpers with QUERY_TRACING_ENABLED patched to False."""

    def setup_method(self):
        # Start every test without an active trace so the "is None" checks
        # do not depend on what earlier tests left in the context variable.
        _request_trace.set(None)

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False)
    def test_start_trace_noop_when_disabled(self):
        """start_trace must not create a trace when tracing is disabled."""
        start_trace()
        assert _request_trace.get(None) is None

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False)
    def test_record_query_noop_when_disabled(self):
        """record_query without an active trace neither raises nor creates one."""
        record_query("runs_v3", "SELECT *", 50, 12.5)
        assert _request_trace.get(None) is None

    @patch("services.data.query_tracing.QUERY_TRACING_ENABLED", False)
    def test_finish_trace_noop_when_disabled(self, caplog):
        """finish_trace logs nothing when no trace was started."""
        with caplog.at_level(logging.INFO, logger="QueryTracing"):
            finish_trace("GET", "/test")

        assert "[RequestTrace]" not in caplog.text