From 2b961acc4b0df00fb6903b02bc3d4fe8aeedd218 Mon Sep 17 00:00:00 2001 From: Jairus Martinez <114552516+jairus-m@users.noreply.github.com> Date: Thu, 14 May 2026 15:16:24 -0700 Subject: [PATCH 1/4] Add ArtifactStore and extraction layer for in-memory artifact search Introduces DuckDB-backed store, row extractors for all 4 artifact types (manifest, catalog, run_results, sources), table DDL definitions, and error hierarchy for the ARTIFACT_SEARCH toolset (PR 2 of 3). --- .../Under the Hood-20260514-151517.yaml | 3 + pyproject.toml | 1 + .../dbt_admin/run_artifacts/extractors.py | 540 ++++++++++++++++++ src/dbt_mcp/dbt_admin/run_artifacts/store.py | 398 +++++++++++++ src/dbt_mcp/dbt_admin/run_artifacts/tables.py | 358 ++++++++++++ src/dbt_mcp/errors/__init__.py | 33 +- src/dbt_mcp/errors/artifact_search.py | 21 + src/dbt_mcp/errors/classification.py | 43 ++ .../unit/dbt_admin/run_artifacts/__init__.py | 0 .../dbt_admin/run_artifacts/test_store.py | 465 +++++++++++++++ uv.lock | 26 +- 11 files changed, 1867 insertions(+), 21 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20260514-151517.yaml create mode 100644 src/dbt_mcp/dbt_admin/run_artifacts/extractors.py create mode 100644 src/dbt_mcp/dbt_admin/run_artifacts/store.py create mode 100644 src/dbt_mcp/dbt_admin/run_artifacts/tables.py create mode 100644 src/dbt_mcp/errors/artifact_search.py create mode 100644 src/dbt_mcp/errors/classification.py create mode 100644 tests/unit/dbt_admin/run_artifacts/__init__.py create mode 100644 tests/unit/dbt_admin/run_artifacts/test_store.py diff --git a/.changes/unreleased/Under the Hood-20260514-151517.yaml b/.changes/unreleased/Under the Hood-20260514-151517.yaml new file mode 100644 index 000000000..3cdce6351 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20260514-151517.yaml @@ -0,0 +1,3 @@ +kind: Under the Hood +body: Add in-memory DuckDB artifact store with extraction layer for structured artifact querying +time: 2026-05-14T15:15:17.615904-07:00 diff --git a/pyproject.toml b/pyproject.toml index aa1c2ada7..dae08d180 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "filelock~=3.20.3", "starlette~=0.50.0", "dbt-artifacts-parser>=0.13.2", + "duckdb>=1.5.2", ] [tool.uv] exclude-newer = "7 days" diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py new file mode 100644 index 000000000..3bd8189ff --- /dev/null +++ b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py @@ -0,0 +1,540 @@ +"""Extraction functions that convert plain dicts to DuckDB row tuples. + +Each function accepts a ``dict[str, Any]`` returned by ``ARTIFACT_PARSERS`` and +returns a dict mapping table names to lists of row tuples. + +``ARTIFACT_PARSERS`` guarantees the input is always a plain Python dict — either +a ``model_dump(mode="json")`` result from dbt-artifacts-parser (happy path) or the +raw artifact JSON (fallback for dbt Fusion / preview builds). Both have the same +key structure so all field access uses ``.get()`` with safe defaults. + +``ArtifactType`` is re-exported here so callers can import it from one place. +""" + +import json +import logging +from collections.abc import Callable +from typing import Any + +from dbt_mcp.dbt_admin.run_artifacts.artifacts.parsers import ArtifactType + +logger = logging.getLogger(__name__) + + +# ── Helpers ───────────────────────────────────────────────────────────── + + +def _json(data: Any) -> str: + """Serialize ``data`` to a JSON string; returns empty string for falsy values.""" + if data is None: + return "" + if isinstance(data, str): + return json.dumps(data) if data else "" + if isinstance(data, (int, float, bool)): + return json.dumps(data) + if not data: + return "" + try: + return json.dumps(data) + except Exception: + logger.debug(f"Failed to JSON-serialize artifact data: {type(data).__name__}") + return "" + + +def _owner_email_str(email: Any) -> str: + """Normalize an owner email to a plain string (groups may use ``list[str]``).""" + if isinstance(email, list): + return ", ".join(str(e) for e in email) + return str(email) if email else "" + + +# ── Manifest extraction ───────────────────────────────────────────────── + + +def _map_node(idx: int, node: dict[str, Any]) -> tuple: + """Map a manifest node or source dict to a ``nodes`` table row.""" + config = node.get("config") or {} + depends_on = node.get("depends_on") or {} + contract = node.get("contract") or {} + docs = node.get("docs") or {} + checksum = node.get("checksum") or {} + + checksum_str = ( + checksum.get("checksum", "") + if isinstance(checksum, dict) + else (str(checksum) if checksum else "") + ) + node_enabled = node.get("enabled") + config_enabled = config.get("enabled") + unique_key = config.get("unique_key") + + return ( + idx, + node.get("unique_id", ""), + node.get("name", ""), + node.get("resource_type", ""), + node.get("package_name", ""), + node.get("path") or node.get("file_path") or "", + node.get("original_file_path") or "", + _json(node.get("fqn") or []), + node.get("alias") or "", + checksum_str, + node.get("description") or "", + node.get("language") or "", + node.get("raw_code") or node.get("raw_sql") or "", + node.get("database") or "", + node.get("schema") or "", # JSON key is "schema", not "schema_" + node.get("relation_name") or "", + node.get("identifier") or node.get("alias") or "", + node_enabled if node_enabled is not None else config_enabled, + node.get("materialized") or config.get("materialized") or "", + config.get("incremental_strategy") or "", + config.get("on_schema_change") or "", + _json(unique_key) if unique_key else "", + config.get("full_refresh"), + _json(config), + node.get("access") or "", + node.get("group") or "", + contract.get("enforced"), + str(node.get("version")) if node.get("version") is not None else "", + str(node.get("latest_version")) + if node.get("latest_version") is not None + else "", + node.get("deprecation_date") or None, + _json(node.get("constraints") or []), + _json(node.get("tags") or []), + _json(node.get("meta") or {}), + node.get("source_name") or "", + node.get("source_description") or "", + node.get("loader") or "", + node.get("loaded_at_field") or "", + _json(node.get("freshness")), + node.get("compiled_code") or node.get("compiled_sql") or "", + node.get("compiled_path") or "", + _json(node.get("extra_ctes") or []), + node.get("patch_path") or "", + docs.get("show"), + _json(node.get("quoting") or {}), + _json(depends_on.get("nodes") or []), + _json(depends_on.get("macros") or []), + ) + + +def _extract_node_columns(node: dict[str, Any]) -> list[tuple]: + """Extract column rows from a manifest node dict.""" + rows = [] + columns = node.get("columns") or {} + for idx, (col_name, col) in enumerate(columns.items()): + if not isinstance(col, dict): + continue + rows.append( + ( + node.get("unique_id", ""), + col.get("name") or col_name, + idx, + col.get("data_type") or col.get("type") or "", + None, # catalog_type — filled later by catalog merge + None, # data_type resolved + col.get("description") or "", + _json(col.get("tags") or []), + _json(col.get("meta") or {}), + _json(col.get("tests") or []), + None, # catalog_comment + ) + ) + return rows + + +def _extract_edges(node: dict[str, Any]) -> list[tuple]: + """Extract dependency edges from a manifest node dict.""" + depends_on = node.get("depends_on") or {} + dep_nodes = depends_on.get("nodes") or [] + unique_id = node.get("unique_id", "") + return [(parent_id, unique_id, "ref") for parent_id in dep_nodes] + + +def _extract_test_metadata(node: dict[str, Any]) -> tuple | None: + """Extract test metadata if this is a test node.""" + if node.get("resource_type") != "test": + return None + tm = node.get("test_metadata") + if not tm or not isinstance(tm, dict): + return None + config = node.get("config") or {} + depends_on = node.get("depends_on") or {} + dep_nodes = depends_on.get("nodes") or [] + attached = next((n for n in dep_nodes if not n.startswith("test.")), "") + kwargs = tm.get("kwargs") or {} + + return ( + node.get("unique_id", ""), + tm.get("name", ""), + tm.get("namespace"), + _json(kwargs), + kwargs.get("column_name", "") if isinstance(kwargs, dict) else "", + attached, + config.get("severity") or "", + config.get("warn_if") or "", + config.get("error_if") or "", + config.get("fail_calc") or "", + config.get("store_failures"), + config.get("store_failures_as") or "", + ) + + +def _map_exposure(idx: int, exp: dict[str, Any]) -> tuple: + """Map an exposure dict to a row.""" + depends_on = exp.get("depends_on") or {} + owner = exp.get("owner") or {} + return ( + idx, + exp.get("unique_id", ""), + exp.get("name", ""), + exp.get("type"), + exp.get("label") or "", + owner.get("name") or "", + _owner_email_str(owner.get("email") or ""), + exp.get("url") or "", + exp.get("maturity") or "", + exp.get("description") or "", + exp.get("package_name") or "", + exp.get("path") or "", + exp.get("original_file_path") or "", + _json(exp.get("fqn") or []), + _json(depends_on.get("nodes") or []), + _json(depends_on.get("macros") or []), + _json(exp.get("tags") or []), + _json(exp.get("meta") or {}), + _json(exp.get("config") or {}), + ) + + +def _map_metric(idx: int, metric: dict[str, Any]) -> tuple: + """Map a metric dict to a row.""" + depends_on = metric.get("depends_on") or {} + type_params = metric.get("type_params") or {} + measure = type_params.get("measure") or {} if isinstance(type_params, dict) else {} + semantic_model_name = measure.get("name", "") if isinstance(measure, dict) else "" + return ( + idx, + metric.get("unique_id", ""), + metric.get("name", ""), + metric.get("label") or "", + metric.get("type") or metric.get("calculation_method") or "", + metric.get("description") or "", + metric.get("package_name") or "", + metric.get("path") or "", + metric.get("original_file_path") or "", + _json(metric.get("fqn") or []), + _json(type_params), + metric.get("time_granularity") or "", + semantic_model_name, + _json(depends_on.get("nodes") or []), + _json(depends_on.get("macros") or []), + metric.get("group") or "", + _json(metric.get("tags") or []), + _json(metric.get("meta") or {}), + _json(metric.get("config") or {}), + ) + + +def _map_group(idx: int, group: dict[str, Any]) -> tuple: + """Map a group dict to a row.""" + owner = group.get("owner") or {} + return ( + idx, + group.get("unique_id", ""), + group.get("name", ""), + group.get("description") or "", + group.get("package_name") or "", + group.get("path") or "", + group.get("original_file_path") or "", + owner.get("name") or "", + _owner_email_str(owner.get("email") or ""), + ) + + +def _map_macro(idx: int, macro: dict[str, Any]) -> tuple: + """Map a macro dict to a row.""" + depends_on = macro.get("depends_on") or {} + return ( + idx, + macro.get("unique_id", ""), + macro.get("name", ""), + macro.get("package_name", ""), + macro.get("path") or "", + macro.get("original_file_path") or "", + macro.get("macro_sql") or "", + macro.get("description") or "", + _json(depends_on.get("macros") or []), + _json(macro.get("arguments") or []), + _json(macro.get("meta") or {}), + ) + + +def extract_from_manifest(data: dict[str, Any]) -> dict[str, list[tuple]]: + """Extract all tables from a manifest dict.""" + nodes = data.get("nodes") or {} + sources = data.get("sources") or {} + exposures = data.get("exposures") or {} + metrics = data.get("metrics") or {} + groups = data.get("groups") or {} + macros = data.get("macros") or {} + + all_nodes = list(nodes.values()) + list(sources.values()) + + node_rows: list[tuple] = [] + column_rows: list[tuple] = [] + edge_rows: list[tuple] = [] + test_rows: list[tuple] = [] + + for idx, node in enumerate(all_nodes): + if not isinstance(node, dict): + continue + node_rows.append(_map_node(idx, node)) + column_rows.extend(_extract_node_columns(node)) + edge_rows.extend(_extract_edges(node)) + tm = _extract_test_metadata(node) + if tm: + test_rows.append(tm) + + # Exposure → model edges + for exp in exposures.values(): + if not isinstance(exp, dict): + continue + dep_nodes = (exp.get("depends_on") or {}).get("nodes") or [] + exp_uid = exp.get("unique_id", "") + for parent_id in dict.fromkeys(dep_nodes): + edge_rows.append((parent_id, exp_uid, "exposure_ref")) + + # Metric → model edges + for metric in metrics.values(): + if not isinstance(metric, dict): + continue + dep_nodes = (metric.get("depends_on") or {}).get("nodes") or [] + metric_uid = metric.get("unique_id", "") + for parent_id in dict.fromkeys(dep_nodes): + edge_rows.append((parent_id, metric_uid, "metric_ref")) + + # Prepend sequential ids to tables that extracted without them + column_rows = [(i, *row) for i, row in enumerate(column_rows)] + edge_rows = [(i, *row) for i, row in enumerate(edge_rows)] + test_rows = [(i, *row) for i, row in enumerate(test_rows)] + + exposure_rows = [ + _map_exposure(i, e) + for i, e in enumerate(exposures.values()) + if isinstance(e, dict) + ] + metric_rows = [ + _map_metric(i, m) for i, m in enumerate(metrics.values()) if isinstance(m, dict) + ] + group_rows = [ + _map_group(i, g) for i, g in enumerate(groups.values()) if isinstance(g, dict) + ] + macro_rows = [ + _map_macro(i, m) for i, m in enumerate(macros.values()) if isinstance(m, dict) + ] + + return { + "nodes": node_rows, + "node_columns": column_rows, + "edges": edge_rows, + "test_metadata": test_rows, + "exposures": exposure_rows, + "metrics": metric_rows, + "groups": group_rows, + "macros": macro_rows, + } + + +# ── Catalog extraction ────────────────────────────────────────────────── + + +def extract_from_catalog(data: dict[str, Any]) -> dict[str, list[tuple]]: + """Extract tables from a catalog dict.""" + nodes = data.get("nodes") or {} + sources = data.get("sources") or {} + # Iterate over (unique_id, entry) pairs — unique_id is the dict key, not a field in the value + all_entries: list[tuple[str, dict[str, Any]]] = [ + (uid, entry) + for uid, entry in list(nodes.items()) + list(sources.items()) + if isinstance(entry, dict) + ] + + table_rows: list[tuple] = [] + stat_rows: list[tuple] = [] + column_updates: list[tuple] = [] + stat_idx = 0 + + for idx, (unique_id, entry) in enumerate(all_entries): + metadata = entry.get("metadata") or {} + table_rows.append( + ( + idx, + unique_id, + metadata.get("type"), + metadata.get("database"), + metadata.get("schema"), # JSON key is "schema" + metadata.get("name"), + metadata.get("owner"), + metadata.get("comment") or "", + ) + ) + + stats = entry.get("stats") or {} + for stat_id, stat in stats.items(): + if not isinstance(stat, dict): + continue + stat_rows.append( + ( + stat_idx, + unique_id, + stat.get("id") or stat_id, + stat.get("label"), + str(stat.get("value", "")), + stat.get("description") or "", + stat.get("include"), + ) + ) + stat_idx += 1 + + columns = entry.get("columns") or {} + for col_name, col in columns.items(): + if not isinstance(col, dict): + continue + column_updates.append( + ( + unique_id, + col.get("name") or col_name, + col.get("index"), + col.get("type"), + col.get("comment") or "", + ) + ) + + return { + "catalog_tables": table_rows, + "catalog_stats": stat_rows, + "_node_columns_update": column_updates, + } + + +# ── Run results extraction ────────────────────────────────────────────── + + +def extract_from_run_results(data: dict[str, Any]) -> dict[str, list[tuple]]: + """Extract tables from a run_results dict.""" + metadata = data.get("metadata") or {} + invocation_id = metadata.get("invocation_id", "") + args = data.get("args") or {} + results = data.get("results") or [] + + which = args.get("which", "") or args.get("command", "") or "" + select = args.get("select", "") or "" + + invocation_rows: list[tuple] = [ + ( + 0, + invocation_id, + which, + select, + metadata.get("dbt_version", ""), + metadata.get("generated_at", ""), + data.get("elapsed_time", 0.0), + _json(args), + len(results), + ) + ] + + result_rows: list[tuple] = [] + for idx, result in enumerate(results): + if not isinstance(result, dict): + continue + status_str = (result.get("status") or "").lower() + timing = result.get("timing") or [] + adapter_response = result.get("adapter_response") or {} + + result_rows.append( + ( + idx, + result.get("unique_id", ""), + invocation_id, + status_str, + result.get("execution_time", 0.0), + result.get("thread_id", ""), + result.get("message") or "", + result.get("relation_name") or "", + _json(adapter_response), + _json(timing), + ) + ) + + return { + "invocations": invocation_rows, + "run_results": result_rows, + } + + +# ── Sources extraction ────────────────────────────────────────────────── + + +def extract_from_sources(data: dict[str, Any]) -> dict[str, list[tuple]]: + """Extract tables from a sources (freshness) dict.""" + metadata = data.get("metadata") or {} + invocation_id = metadata.get("invocation_id", "") + results = data.get("results") or [] + + rows: list[tuple] = [] + for idx, result in enumerate(results): + if not isinstance(result, dict): + continue + unique_id = result.get("unique_id", "") + parts = unique_id.split(".") + criteria = result.get("criteria") or {} + status_str = (result.get("status") or "").lower() + + rows.append( + ( + idx, + unique_id, + parts[2] if len(parts) > 2 else "", + parts[3] if len(parts) > 3 else "", + invocation_id, + status_str, + result.get("max_loaded_at") or "", + result.get("snapshotted_at") or "", + result.get("max_loaded_at_time_ago_in_s") or 0.0, + result.get("execution_time", 0.0), + result.get("thread_id") or "", + result.get("error") or "", + criteria.get("warn_after", {}).get("count") + if isinstance(criteria, dict) + else None, + criteria.get("warn_after", {}).get("period", "") + if isinstance(criteria, dict) + else "", + criteria.get("error_after", {}).get("count") + if isinstance(criteria, dict) + else None, + criteria.get("error_after", {}).get("period", "") + if isinstance(criteria, dict) + else "", + _json(result.get("adapter_response") or {}), + _json(result.get("timing") or []), + ) + ) + + return {"source_freshness": rows} + + +# ── Pipeline mappings ─────────────────────────────────────────────────── + +ARTIFACT_EXTRACTORS: dict[ + ArtifactType, Callable[[dict[str, Any]], dict[str, list[tuple]]] +] = { + ArtifactType.MANIFEST: extract_from_manifest, + ArtifactType.CATALOG: extract_from_catalog, + ArtifactType.RUN_RESULTS: extract_from_run_results, + ArtifactType.SOURCES: extract_from_sources, +} diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/store.py b/src/dbt_mcp/dbt_admin/run_artifacts/store.py new file mode 100644 index 000000000..9aa07c24b --- /dev/null +++ b/src/dbt_mcp/dbt_admin/run_artifacts/store.py @@ -0,0 +1,398 @@ +"""In-memory DuckDB artifact store. + +Artifacts are loaded via `load_artifact()` from parsed dicts +fetched by the Admin API client. +""" + +import logging +import time +from typing import Any + +import duckdb + +from dbt_mcp.dbt_admin.run_artifacts.artifacts.parsers import ( + ArtifactType, + ARTIFACT_PARSERS, +) +from dbt_mcp.dbt_admin.run_artifacts.extractors import ARTIFACT_EXTRACTORS +from dbt_mcp.dbt_admin.run_artifacts.tables import TABLES, TableConfig +from dbt_mcp.errors.artifact_search import ( + ArtifactNotLoadedError, + ArtifactQueryError, + ArtifactValidationError, +) + +logger = logging.getLogger(__name__) + +READONLY_BLOCKED = frozenset( + { + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "CREATE", + "ALTER", + "TRUNCATE", + "REPLACE", + "MERGE", + "UPSERT", + "COPY", + "LOAD", + "INSTALL", + "ATTACH", + "DETACH", + "EXPORT", + "IMPORT", + "VACUUM", + "PRAGMA", + } +) + +MAX_RESULT_ROWS = 500 + + +class ArtifactStore: + """Manages an in-memory DuckDB database loaded with dbt artifacts.""" + + def __init__(self) -> None: + self.conn = duckdb.connect() + self.conn.execute("INSTALL fts;") + self.conn.execute("LOAD fts;") + self._loaded_tables: set[str] = set() + self._tables_created: bool = False + self._pending_index_tables: set[str] = set() + + # Methods for loading + + def load_artifact( + self, + run_id: int, + artifact_type: ArtifactType, + raw_data: dict[str, Any], + *, + reindex: bool = True, + ) -> dict[str, Any]: + """Validate, extract, and load a single artifact into DuckDB. + + Assumes the store has already been cleared by the caller (load_artifacts + calls reset() before the first artifact). Returns + ``{"tables": {table: row_count}, "timing": {phase: ms}}``. + """ + t_wall = time.perf_counter() + self._ensure_tables_created() + + # Phase 1: parsing (dbt-artifacts-parser handles schema version dispatch) + t = time.perf_counter() + try: + parsed = ARTIFACT_PARSERS[artifact_type](raw_data) + except Exception as e: + raise ArtifactValidationError( + f"Parsing failed for {artifact_type.value}: {e}" + ) from e + validate_ms = round((time.perf_counter() - t) * 1000) + + # Phase 2: row extraction + extractor = ARTIFACT_EXTRACTORS[artifact_type] + t = time.perf_counter() + tables_data = extractor(parsed) + extract_ms = round((time.perf_counter() - t) * 1000) + + # Phase 3: DuckDB inserts + counts: dict[str, int] = {} + affected_tables: set[str] = set() + t = time.perf_counter() + + for table_name, rows in tables_data.items(): + if not rows: + continue + + # Special: catalog column merge into node_columns + if table_name == "_node_columns_update": + self._merge_node_columns(rows, run_id) + affected_tables.add("node_columns") + row = self.conn.execute( + "SELECT COUNT(*) FROM node_columns WHERE run_id = ?", [run_id] + ).fetchone() + counts["node_columns"] = row[0] if row else 0 + continue + + config = TABLES[table_name] + self._insert_rows(config, rows, run_id) + counts[table_name] = len(rows) + self._loaded_tables.add(table_name) + affected_tables.add(table_name) + logger.info(f"Loaded {len(rows)} rows into {table_name}") + + insert_ms = round((time.perf_counter() - t) * 1000) + + # Phase 4: index building (deferred or immediate) + t = time.perf_counter() + if reindex: + for table_name in affected_tables: + if table_name in TABLES: + self._build_indexes(TABLES[table_name]) + else: + self._pending_index_tables |= affected_tables + index_ms = round((time.perf_counter() - t) * 1000) + + total_ms = round((time.perf_counter() - t_wall) * 1000) + + return { + "tables": counts, + "timing": { + "validate_ms": validate_ms, + "extract_ms": extract_ms, + "insert_ms": insert_ms, + "index_ms": index_ms, + "total_ms": total_ms, + }, + } + + def build_all_indexes(self) -> list[str]: + """Build indexes for all tables pending index construction. + + Call this once after loading multiple artifacts with ``reindex=False``. + Returns the list of table names that were indexed. + """ + to_index = set(self._pending_index_tables) + indexed = [] + for table_name in to_index: + if table_name in TABLES: + self._build_indexes(TABLES[table_name]) + indexed.append(table_name) + self._pending_index_tables.clear() + return indexed + + def reset(self) -> dict[str, int]: + """Drop all tables and reset the store to empty state. + + Returns ``{table_name: rows_dropped}`` for the caller's log. + The FTS extension stays loaded. Call ``load_artifact`` to repopulate. + """ + dropped: dict[str, int] = {} + if self._tables_created: + for table_name in TABLES: + row = self.conn.execute( + f'SELECT COUNT(*) FROM "{table_name}"' + ).fetchone() + dropped[table_name] = row[0] if row else 0 + self.conn.execute(f'DROP TABLE IF EXISTS "{table_name}"') + self._loaded_tables.clear() + self._pending_index_tables.clear() + self._tables_created = False + return dropped + + def _ensure_tables_created(self) -> None: + """Lazily create all tables on first load.""" + if self._tables_created: + return + for config in TABLES.values(): + self.conn.execute(f'DROP TABLE IF EXISTS "{config.table_name}";') + self.conn.execute(config.table_ddl) + self._tables_created = True + + def _insert_rows(self, config: TableConfig, rows: list[tuple], run_id: int) -> None: + """Bulk insert rows into a table, injecting run_id as the second column.""" + if not rows: + return + # rows are (id, col1, col2, ...) — insert as (id, run_id, col1, col2, ...) + tagged = [(row[0], run_id, *row[1:]) for row in rows] + placeholders = ", ".join(["?"] * len(tagged[0])) + self.conn.executemany( + f"INSERT INTO {config.table_name} VALUES ({placeholders})", tagged + ) + + def _merge_node_columns(self, rows: list[tuple], run_id: int) -> None: + """Merge catalog column data into node_columns via batch update. + + Each row is (unique_id, column_name, column_index, catalog_type, catalog_comment). + Uses a temp table + batch UPDATE/INSERT for performance. + Only touches rows with the given run_id. + """ + if not rows: + return + + self.conn.execute(""" + CREATE OR REPLACE TEMP TABLE _catalog_cols ( + unique_id VARCHAR, column_name VARCHAR, + column_index INTEGER, catalog_type VARCHAR, + catalog_comment TEXT + ) + """) + try: + placeholders = ", ".join(["?"] * 5) + self.conn.executemany( + f"INSERT INTO _catalog_cols VALUES ({placeholders})", rows + ) + + self.conn.execute( + """ + UPDATE node_columns SET + column_index = COALESCE(cc.column_index, node_columns.column_index), + catalog_type = COALESCE(cc.catalog_type, node_columns.catalog_type), + data_type = COALESCE(cc.catalog_type, node_columns.data_type), + catalog_comment = COALESCE(cc.catalog_comment, node_columns.catalog_comment) + FROM _catalog_cols cc + WHERE node_columns.unique_id = cc.unique_id + AND node_columns.column_name = cc.column_name + AND node_columns.run_id = ? + """, + [run_id], + ) + + row = self.conn.execute( + "SELECT COALESCE(MAX(id), -1) + 1 FROM node_columns" + ).fetchone() + next_id = row[0] if row else 0 + self.conn.execute( + """ + INSERT INTO node_columns (id, run_id, unique_id, column_name, column_index, + catalog_type, data_type, catalog_comment) + SELECT ? + row_number() OVER () - 1, + ?, + cc.unique_id, cc.column_name, cc.column_index, + cc.catalog_type, cc.catalog_type, cc.catalog_comment + FROM _catalog_cols cc + WHERE NOT EXISTS ( + SELECT 1 FROM node_columns nc + WHERE nc.unique_id = cc.unique_id + AND nc.column_name = cc.column_name + AND nc.run_id = ? + ) + """, + [next_id, run_id, run_id], + ) + finally: + self.conn.execute("DROP TABLE IF EXISTS _catalog_cols") + self._loaded_tables.add("node_columns") + logger.info(f"Merged {len(rows)} catalog columns into node_columns") + + def _build_indexes(self, config: TableConfig) -> None: + """Build FTS and B-tree indexes for a table.""" + if config.fts_columns: + fts_cols = ", ".join(f"'{c}'" for c in config.fts_columns) + self.conn.execute( + f"PRAGMA create_fts_index('{config.table_name}', 'id', " + f"{fts_cols}, overwrite=1);" + ) + + for col in config.index_columns: + idx_name = f"idx_{config.table_name[:4]}_{col}" + self.conn.execute( + f"CREATE INDEX IF NOT EXISTS {idx_name} " + f'ON {config.table_name}("{col}");' + ) + + # Methods to query + + @property + def is_loaded(self) -> bool: + """Whether any artifact tables have been loaded.""" + return bool(self._loaded_tables) + + def list_tables(self) -> list[dict[str, Any]]: + """List all loaded artifact tables with row counts.""" + if not self._tables_created: + return [] + + result = self.conn.execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'main'" + ).fetchall() + + tables = [] + for (table_name,) in result: + if table_name.startswith("fts_"): + continue + count_row = self.conn.execute( + f'SELECT COUNT(*) FROM "{table_name}"' + ).fetchone() + row_count = count_row[0] if count_row else 0 + tables.append( + { + "table_name": table_name, + "row_count": row_count, + "status": "loaded" + if table_name in self._loaded_tables + else "not_loaded", + } + ) + return tables + + def describe_table(self, table_name: str) -> list[dict[str, str]]: + """Describe the schema of a loaded table.""" + self._validate_table_name(table_name) + rows = self.conn.execute(f'DESCRIBE "{table_name}"').fetchall() + return [{"column_name": row[0], "column_type": row[1]} for row in rows] + + def query(self, sql: str) -> list[dict[str, Any]]: + """Execute a read-only SQL query. Results capped at 500 rows.""" + tokens = sql.strip().upper().split() + for token in tokens: + if token in READONLY_BLOCKED: + raise ArtifactQueryError( + f"Only read-only queries are allowed. Blocked keyword: {token}" + ) + + try: + result = self.conn.execute(sql) + columns = [desc[0] for desc in result.description] + rows = result.fetchmany(MAX_RESULT_ROWS) + return [ + {col: _serialize(val) for col, val in zip(columns, row)} for row in rows + ] + except duckdb.Error as e: + raise ArtifactQueryError(f"Query failed: {e}") from e + + def search( + self, *, table_name: str, query_text: str, limit: int = 20 + ) -> list[dict[str, Any]]: + """Full-text BM25 search on a loaded table.""" + self._validate_table_name(table_name) + config = TABLES.get(table_name) + if not config or not config.fts_columns: + raise ArtifactQueryError( + f"Table '{table_name}' does not support full-text search" + ) + limit = min(limit, MAX_RESULT_ROWS) + + fts_table = f"fts_main_{table_name}" + sql = f""" + SELECT t.*, {fts_table}.match_bm25(t.id, ?) AS fts_score + FROM "{table_name}" t + WHERE {fts_table}.match_bm25(t.id, ?) IS NOT NULL + ORDER BY fts_score DESC + LIMIT ? + """ + try: + result = self.conn.execute(sql, [query_text, query_text, limit]) + columns = [desc[0] for desc in result.description] + rows = result.fetchall() + return [ + {col: _serialize(val) for col, val in zip(columns, row)} for row in rows + ] + except duckdb.Error as e: + raise ArtifactQueryError(f"Search failed: {e}") from e + + def close(self) -> None: + """Close the DuckDB connection.""" + if self.conn: + self.conn.close() + logger.info("In-memory database closed") + + # Helper + + def _validate_table_name(self, table_name: str) -> None: + """Check table_name is a known loaded table (also prevents SQL injection).""" + if table_name not in self._loaded_tables: + available = ", ".join(sorted(self._loaded_tables)) or "(none)" + raise ArtifactNotLoadedError( + f"Unknown table '{table_name}'. Available: {available}" + ) + + +def _serialize(val: Any) -> Any: + """Ensure values are JSON-serializable.""" + if val is None or isinstance(val, (str, int, float, bool)): + return val + return str(val) diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/tables.py b/src/dbt_mcp/dbt_admin/run_artifacts/tables.py new file mode 100644 index 000000000..6f87f91eb --- /dev/null +++ b/src/dbt_mcp/dbt_admin/run_artifacts/tables.py @@ -0,0 +1,358 @@ +"""DuckDB table definitions for dbt artifacts.""" + +from dataclasses import dataclass, field + + +@dataclass(frozen=True) +class TableConfig: + """Schema definition for a DuckDB table.""" + + table_name: str + table_ddl: str + fts_columns: list[str] = field(default_factory=list) + index_columns: list[str] = field(default_factory=list) + + +NODES = TableConfig( + table_name="nodes", + table_ddl=""" + CREATE TABLE nodes ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + name VARCHAR, + resource_type VARCHAR, + package_name VARCHAR, + file_path VARCHAR, + original_file_path VARCHAR, + fqn TEXT, + alias VARCHAR, + checksum VARCHAR, + description TEXT, + node_language VARCHAR, + raw_code TEXT, + database_name VARCHAR, + schema_name VARCHAR, + relation_name VARCHAR, + identifier VARCHAR, + enabled BOOLEAN, + materialized VARCHAR, + incremental_strategy VARCHAR, + on_schema_change VARCHAR, + unique_key TEXT, + full_refresh BOOLEAN, + config TEXT, + access_level VARCHAR, + group_name VARCHAR, + contract_enforced BOOLEAN, + version VARCHAR, + latest_version VARCHAR, + deprecation_date VARCHAR, + constraints TEXT, + tags TEXT, + meta TEXT, + source_name VARCHAR, + source_description TEXT, + loader VARCHAR, + loaded_at_field VARCHAR, + freshness TEXT, + compiled_code TEXT, + compiled_path VARCHAR, + extra_ctes TEXT, + patch_path VARCHAR, + docs_show BOOLEAN, + quoting TEXT, + depends_on_nodes TEXT, + depends_on_macros TEXT + ) + """, + fts_columns=["name", "description", "raw_code", "compiled_code"], + index_columns=["unique_id", "name", "resource_type", "package_name"], +) + +NODE_COLUMNS = TableConfig( + table_name="node_columns", + table_ddl=""" + CREATE TABLE node_columns ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + column_name VARCHAR NOT NULL, + column_index INTEGER, + declared_type VARCHAR, + catalog_type VARCHAR, + data_type VARCHAR, + description TEXT, + tags TEXT, + meta TEXT, + tests TEXT, + catalog_comment TEXT, + UNIQUE (run_id, unique_id, column_name) + ) + """, + fts_columns=[], + index_columns=["unique_id"], +) + +EDGES = TableConfig( + table_name="edges", + table_ddl=""" + CREATE TABLE edges ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + parent_unique_id VARCHAR NOT NULL, + child_unique_id VARCHAR NOT NULL, + edge_type VARCHAR + ) + """, + index_columns=["parent_unique_id", "child_unique_id"], +) + +TEST_METADATA = TableConfig( + table_name="test_metadata", + table_ddl=""" + CREATE TABLE test_metadata ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + test_name VARCHAR, + test_namespace VARCHAR, + kwargs TEXT, + column_name VARCHAR, + attached_node VARCHAR, + severity VARCHAR, + warn_if VARCHAR, + error_if VARCHAR, + fail_calc VARCHAR, + store_failures BOOLEAN, + store_failures_as VARCHAR + ) + """, + index_columns=["unique_id", "test_name", "attached_node"], +) + +EXPOSURES = TableConfig( + table_name="exposures", + table_ddl=""" + CREATE TABLE exposures ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + name VARCHAR, + exposure_type VARCHAR, + label VARCHAR, + owner_name VARCHAR, + owner_email VARCHAR, + url VARCHAR, + maturity VARCHAR, + description TEXT, + package_name VARCHAR, + file_path VARCHAR, + original_file_path VARCHAR, + fqn TEXT, + depends_on_nodes TEXT, + depends_on_macros TEXT, + tags TEXT, + meta TEXT, + config TEXT + ) + """, + fts_columns=["name", "description"], + index_columns=["unique_id", "name"], +) + +METRICS = TableConfig( + table_name="metrics", + table_ddl=""" + CREATE TABLE metrics ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + name VARCHAR, + label VARCHAR, + metric_type VARCHAR, + description TEXT, + package_name VARCHAR, + file_path VARCHAR, + original_file_path VARCHAR, + fqn TEXT, + type_params TEXT, + time_granularity VARCHAR, + semantic_model_name VARCHAR, + depends_on_nodes TEXT, + depends_on_macros TEXT, + group_name VARCHAR, + tags TEXT, + meta TEXT, + config TEXT + ) + """, + fts_columns=["name", "label", "description"], + index_columns=["unique_id", "name"], +) + +GROUPS = TableConfig( + table_name="groups", + table_ddl=""" + CREATE TABLE groups ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + name VARCHAR, + description TEXT, + package_name VARCHAR, + file_path VARCHAR, + original_file_path VARCHAR, + owner_name VARCHAR, + owner_email VARCHAR + ) + """, + index_columns=["unique_id", "name"], +) + +MACROS = TableConfig( + table_name="macros", + table_ddl=""" + CREATE TABLE macros ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + name VARCHAR, + package_name VARCHAR, + file_path VARCHAR, + original_file_path VARCHAR, + macro_sql TEXT, + description TEXT, + depends_on_macros TEXT, + arguments TEXT, + meta TEXT + ) + """, + fts_columns=["name", "description", "macro_sql"], + index_columns=["unique_id", "name", "package_name"], +) + +CATALOG_TABLES = TableConfig( + table_name="catalog_tables", + table_ddl=""" + CREATE TABLE catalog_tables ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + table_type VARCHAR, + database_name VARCHAR, + schema_name VARCHAR, + table_name VARCHAR, + table_owner VARCHAR, + table_comment TEXT + ) + """, + fts_columns=["table_name", "table_comment"], + index_columns=["unique_id", "table_name"], +) + +CATALOG_STATS = TableConfig( + table_name="catalog_stats", + table_ddl=""" + CREATE TABLE catalog_stats ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + stat_id VARCHAR, + stat_label VARCHAR, + stat_value VARCHAR, + description TEXT, + include_in_stats BOOLEAN + ) + """, + index_columns=["unique_id"], +) + +INVOCATIONS = TableConfig( + table_name="invocations", + table_ddl=""" + CREATE TABLE invocations ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + invocation_id VARCHAR NOT NULL, + command VARCHAR, + selector VARCHAR, + dbt_version VARCHAR, + generated_at VARCHAR, + elapsed_time FLOAT, + args TEXT, + node_count INTEGER + ) + """, + index_columns=["invocation_id"], +) + +RUN_RESULTS = TableConfig( + table_name="run_results", + table_ddl=""" + CREATE TABLE run_results ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + invocation_id VARCHAR, + status VARCHAR, + execution_time FLOAT, + thread_id VARCHAR, + message TEXT, + relation_name VARCHAR, + adapter_response TEXT, + timing TEXT + ) + """, + fts_columns=["status", "message"], + index_columns=["unique_id", "invocation_id", "status"], +) + +SOURCE_FRESHNESS = TableConfig( + table_name="source_freshness", + table_ddl=""" + CREATE TABLE source_freshness ( + id INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + unique_id VARCHAR NOT NULL, + source_name VARCHAR, + table_name VARCHAR, + invocation_id VARCHAR, + status VARCHAR, + max_loaded_at VARCHAR, + snapshotted_at VARCHAR, + max_loaded_at_time_ago FLOAT, + execution_time FLOAT, + thread_id VARCHAR, + error TEXT, + warn_after_count INTEGER, + warn_after_period VARCHAR, + error_after_count INTEGER, + error_after_period VARCHAR, + adapter_response TEXT, + timing TEXT + ) + """, + fts_columns=["source_name", "table_name", "status"], + index_columns=["unique_id", "source_name", "status"], +) + +# All table configs, keyed by table name +TABLES: dict[str, TableConfig] = { + c.table_name: c + for c in [ + NODES, + NODE_COLUMNS, + EDGES, + TEST_METADATA, + EXPOSURES, + METRICS, + GROUPS, + MACROS, + CATALOG_TABLES, + CATALOG_STATS, + INVOCATIONS, + RUN_RESULTS, + SOURCE_FRESHNESS, + ] +} diff --git a/src/dbt_mcp/errors/__init__.py b/src/dbt_mcp/errors/__init__.py index 49aba0e13..3f4fc21d6 100644 --- a/src/dbt_mcp/errors/__init__.py +++ b/src/dbt_mcp/errors/__init__.py @@ -3,7 +3,15 @@ AdminAPIToolCallError, ArtifactRetrievalError, ) +from dbt_mcp.errors.artifact_search import ( + ArtifactLoadError, + ArtifactNotLoadedError, + ArtifactQueryError, + ArtifactSearchError, + ArtifactValidationError, +) from dbt_mcp.errors.base import ToolCallError +from dbt_mcp.errors.classification import ClientToolCallError, ServerToolCallError from dbt_mcp.errors.cli import BinaryExecutionError, CLIToolCallError from dbt_mcp.errors.common import ( ConfigurationError, @@ -17,30 +25,15 @@ ) from dbt_mcp.errors.sql import RemoteToolError, SQLToolCallError -ClientToolCallError = ( - InvalidParameterError - | NotFoundError - | SemanticLayerQueryTimeoutError - | GraphQLError -) - -ServerToolCallError = ( - SemanticLayerToolCallError - | CLIToolCallError - | BinaryExecutionError - | SQLToolCallError - | RemoteToolError - | DiscoveryToolCallError - | AdminAPIToolCallError - | AdminAPIError - | ArtifactRetrievalError - | ConfigurationError -) - __all__ = [ "AdminAPIError", "AdminAPIToolCallError", + "ArtifactLoadError", + "ArtifactNotLoadedError", + "ArtifactQueryError", "ArtifactRetrievalError", + "ArtifactSearchError", + "ArtifactValidationError", "BinaryExecutionError", "CLIToolCallError", "ConfigurationError", diff --git a/src/dbt_mcp/errors/artifact_search.py b/src/dbt_mcp/errors/artifact_search.py new file mode 100644 index 000000000..59528acb5 --- /dev/null +++ b/src/dbt_mcp/errors/artifact_search.py @@ -0,0 +1,21 @@ +from dbt_mcp.errors.base import ToolCallError + + +class ArtifactSearchError(ToolCallError): + """Base exception for server-side artifact store failures (load, query, validation).""" + + +class ArtifactLoadError(ArtifactSearchError): + """Raised when fetching or parsing an artifact from the Admin API fails.""" + + +class ArtifactQueryError(ArtifactSearchError): + """Raised when a SQL query against the artifact store fails.""" + + +class ArtifactValidationError(ArtifactSearchError): + """Raised when Pydantic validation of raw artifact JSON fails.""" + + +class ArtifactNotLoadedError(ToolCallError): + """Raised when querying before any artifacts have been loaded (client error).""" diff --git a/src/dbt_mcp/errors/classification.py b/src/dbt_mcp/errors/classification.py new file mode 100644 index 000000000..0c82f622d --- /dev/null +++ b/src/dbt_mcp/errors/classification.py @@ -0,0 +1,43 @@ +from dbt_mcp.errors.admin_api import ( + AdminAPIError, + AdminAPIToolCallError, + ArtifactRetrievalError, +) +from dbt_mcp.errors.artifact_search import ( + ArtifactNotLoadedError, + ArtifactSearchError, +) +from dbt_mcp.errors.cli import BinaryExecutionError, CLIToolCallError +from dbt_mcp.errors.common import ( + ConfigurationError, + InvalidParameterError, + NotFoundError, +) +from dbt_mcp.errors.discovery import DiscoveryToolCallError, GraphQLError +from dbt_mcp.errors.semantic_layer import ( + SemanticLayerQueryTimeoutError, + SemanticLayerToolCallError, +) +from dbt_mcp.errors.sql import RemoteToolError, SQLToolCallError + +ClientToolCallError = ( + InvalidParameterError + | NotFoundError + | SemanticLayerQueryTimeoutError + | GraphQLError + | ArtifactNotLoadedError +) + +ServerToolCallError = ( + SemanticLayerToolCallError + | CLIToolCallError + | BinaryExecutionError + | SQLToolCallError + | RemoteToolError + | DiscoveryToolCallError + | AdminAPIToolCallError + | AdminAPIError + | ArtifactRetrievalError + | ConfigurationError + | ArtifactSearchError +) diff --git a/tests/unit/dbt_admin/run_artifacts/__init__.py b/tests/unit/dbt_admin/run_artifacts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/dbt_admin/run_artifacts/test_store.py b/tests/unit/dbt_admin/run_artifacts/test_store.py new file mode 100644 index 000000000..230f8ffcd --- /dev/null +++ b/tests/unit/dbt_admin/run_artifacts/test_store.py @@ -0,0 +1,465 @@ +"""Unit tests for ArtifactStore (in-memory DuckDB artifact store). + +Two levels of coverage: +- Direct: _insert_rows with hand-crafted tuples for focused behavior tests. +- Integration: minimal raw artifact dicts through the full load_artifact pipeline. + The parser's _to_dict fallback returns raw dicts unchanged when strict validation + fails, so extractors process them via .get() with safe defaults — no mocking needed. +""" + +from collections.abc import Iterator +from unittest.mock import MagicMock, patch + +import pytest + +from dbt_mcp.dbt_admin.run_artifacts.artifacts.parsers import ArtifactType +from dbt_mcp.dbt_admin.run_artifacts.store import MAX_RESULT_ROWS, ArtifactStore +from dbt_mcp.dbt_admin.run_artifacts.tables import ( + INVOCATIONS, + NODE_COLUMNS, + RUN_RESULTS, +) +from dbt_mcp.errors.artifact_search import ( + ArtifactNotLoadedError, + ArtifactQueryError, + ArtifactValidationError, +) + +# ── Row-tuple helpers ────────────────────────────────────────────────────── +# _insert_rows expects (id, col1, ...) — run_id is injected as col 2 by the store. + + +def _invocation(idx: int = 0) -> tuple: + # id, invocation_id, command, selector, dbt_version, generated_at, elapsed_time, args, node_count + return ( + idx, + f"inv-{idx}", + "run", + None, + "1.9.0", + "2024-01-01T00:00:00", + float(idx), + None, + 0, + ) + + +def _run_result(idx: int, message: str = "ok") -> tuple: + # id, unique_id, invocation_id, status, execution_time, thread_id, message, relation_name, adapter_response, timing + return ( + idx, + f"model.pkg.node_{idx}", + "inv-0", + "pass", + 1.0, + "thread-1", + message, + None, + None, + None, + ) + + +def _node_column(idx: int, unique_id: str = "model.pkg.x", col: str = "col_a") -> tuple: + # id, unique_id, column_name, column_index, declared_type, catalog_type, + # data_type, description, tags, meta, tests, catalog_comment + return (idx, unique_id, col, None, None, None, None, None, None, None, None, None) + + +# ── Fixtures ─────────────────────────────────────────────────────────────── + + +@pytest.fixture +def store() -> Iterator[ArtifactStore]: + s = ArtifactStore() + yield s + s.close() + + +@pytest.fixture +def loaded_store(store: ArtifactStore) -> ArtifactStore: + """Store with one invocations row pre-loaded (run_id=1).""" + store._ensure_tables_created() + store._insert_rows(INVOCATIONS, [_invocation(0)], run_id=1) + store._loaded_tables.add("invocations") + return store + + +# ── Reset ────────────────────────────────────────────────────────────────── + + +class TestReset: + def test_clears_loaded_state_and_returns_row_counts( + self, loaded_store: ArtifactStore + ) -> None: + dropped = loaded_store.reset() + assert not loaded_store.is_loaded + assert not loaded_store._tables_created + assert dropped["invocations"] == 1 + + def test_tables_can_be_reloaded_after_reset( + self, loaded_store: ArtifactStore + ) -> None: + loaded_store.reset() + loaded_store._ensure_tables_created() + loaded_store._insert_rows(INVOCATIONS, [_invocation(0)], run_id=2) + loaded_store._loaded_tables.add("invocations") + rows = loaded_store.query("SELECT run_id FROM invocations") + assert rows == [{"run_id": 2}] + + +# ── Table introspection ──────────────────────────────────────────────────── + + +class TestTableIntrospection: + def test_list_tables_distinguishes_loaded_from_not_loaded( + self, loaded_store: ArtifactStore + ) -> None: + by_name = {t["table_name"]: t for t in loaded_store.list_tables()} + assert by_name["invocations"]["status"] == "loaded" + assert by_name["nodes"]["status"] == "not_loaded" + + def test_describe_unknown_table_raises(self, loaded_store: ArtifactStore) -> None: + with pytest.raises(ArtifactNotLoadedError, match="ghost_table"): + loaded_store.describe_table("ghost_table") + + +# ── Query ────────────────────────────────────────────────────────────────── + + +class TestQuery: + def test_mutating_keywords_are_blocked(self, store: ArtifactStore) -> None: + for keyword in ("INSERT", "UPDATE", "DELETE", "DROP", "CREATE"): + with pytest.raises(ArtifactQueryError, match="Blocked keyword"): + store.query(f"{keyword} something") + + def test_mutating_keywords_blocked_anywhere_in_query( + self, store: ArtifactStore + ) -> None: + with pytest.raises(ArtifactQueryError, match="Blocked keyword"): + store.query("SELECT 1; DROP TABLE nodes") + + def test_invalid_sql_raises(self, store: ArtifactStore) -> None: + with pytest.raises(ArtifactQueryError, match="Query failed"): + store.query("SELECT * FROM table_that_does_not_exist_xyz") + + def test_row_cap_at_max_result_rows(self, store: ArtifactStore) -> None: + store._ensure_tables_created() + store._insert_rows( + INVOCATIONS, + [_invocation(i) for i in range(MAX_RESULT_ROWS + 100)], + run_id=1, + ) + assert len(store.query("SELECT id FROM invocations")) == MAX_RESULT_ROWS + + +# ── Search ───────────────────────────────────────────────────────────────── + + +class TestSearch: + def test_unknown_table_raises(self, loaded_store: ArtifactStore) -> None: + with pytest.raises(ArtifactNotLoadedError): + loaded_store.search(table_name="ghost_table", query_text="foo") + + def test_table_without_fts_columns_raises(self, store: ArtifactStore) -> None: + store._ensure_tables_created() + store._loaded_tables.add("node_columns") # node_columns has no fts_columns + with pytest.raises( + ArtifactQueryError, match="does not support full-text search" + ): + store.search(table_name="node_columns", query_text="anything") + + def test_bm25_returns_matching_rows_with_score(self, store: ArtifactStore) -> None: + store._ensure_tables_created() + store._insert_rows( + RUN_RESULTS, [_run_result(0, message="compilation failure")], run_id=1 + ) + store._loaded_tables.add("run_results") + store._build_indexes(RUN_RESULTS) + + results = store.search(table_name="run_results", query_text="compilation") + assert len(results) == 1 + assert results[0]["message"] == "compilation failure" + assert "fts_score" in results[0] + + +# ── Load artifact ────────────────────────────────────────────────────────── + + +class TestLoadArtifact: + def test_parse_failure_raises_validation_error(self, store: ArtifactStore) -> None: + with patch( + "dbt_mcp.dbt_admin.run_artifacts.store.ARTIFACT_PARSERS" + ) as mock_parsers: + mock_parsers.__getitem__.return_value = MagicMock( + side_effect=ValueError("bad schema") + ) + with pytest.raises(ArtifactValidationError, match="Parsing failed"): + store.load_artifact( + run_id=1, artifact_type=ArtifactType.RUN_RESULTS, raw_data={} + ) + + def test_deferred_indexing_and_build_all(self, store: ArtifactStore) -> None: + fake_rows = [_invocation(0)] + with ( + patch( + "dbt_mcp.dbt_admin.run_artifacts.store.ARTIFACT_PARSERS" + ) as mock_parsers, + patch( + "dbt_mcp.dbt_admin.run_artifacts.store.ARTIFACT_EXTRACTORS" + ) as mock_extractors, + ): + mock_parsers.__getitem__.return_value = MagicMock(return_value={}) + mock_extractors.__getitem__.return_value = MagicMock( + return_value={"invocations": fake_rows} + ) + store.load_artifact( + run_id=1, + artifact_type=ArtifactType.RUN_RESULTS, + raw_data={}, + reindex=False, + ) + + assert "invocations" in store._pending_index_tables + store.build_all_indexes() + assert not store._pending_index_tables + + +# ── Merge node columns ───────────────────────────────────────────────────── + + +class TestMergeNodeColumns: + def _seed(self, store: ArtifactStore, run_id: int = 1) -> None: + store._ensure_tables_created() + store._insert_rows(NODE_COLUMNS, [_node_column(0, col="amount")], run_id=run_id) + + def test_updates_existing_row_with_catalog_data(self, store: ArtifactStore) -> None: + self._seed(store) + store._merge_node_columns( + [("model.pkg.x", "amount", 0, "INTEGER", "dollar amount")], run_id=1 + ) + + row = store.conn.execute( + "SELECT catalog_type, catalog_comment FROM node_columns WHERE column_name = 'amount'" + ).fetchone() + assert row == ("INTEGER", "dollar amount") + + def test_inserts_catalog_only_column(self, store: ArtifactStore) -> None: + self._seed(store) + store._merge_node_columns( + [("model.pkg.x", "revenue", 1, "FLOAT", None)], run_id=1 + ) + + row = store.conn.execute( + "SELECT catalog_type FROM node_columns WHERE column_name = 'revenue'" + ).fetchone() + assert row == ("FLOAT",) + + def test_scoped_to_run_id(self, store: ArtifactStore) -> None: + self._seed(store, run_id=1) + # Merge against run_id=2 — run_id=1 row must stay untouched + store._merge_node_columns( + [("model.pkg.x", "amount", 0, "TEXT", "overwritten")], run_id=2 + ) + + row = store.conn.execute( + "SELECT catalog_type FROM node_columns WHERE column_name = 'amount' AND run_id = 1" + ).fetchone() + assert row is not None + assert row[0] is None + + +# ── Full pipeline integration (no mocking) ──────────────────────────────── +# +# Minimal raw dicts trigger the parser's fallback path (strict validation fails, +# raw dict returned as-is) so extractors run against real — just sparse — data. + +_RUN_RESULTS_RAW: dict = { + "metadata": { + "invocation_id": "inv-rr-001", + "dbt_version": "1.9.0", + "generated_at": "2024-01-01T00:00:00Z", + }, + "args": {"which": "run", "select": "my_model"}, + "elapsed_time": 3.5, + "results": [ + { + "unique_id": "model.pkg.my_model", + "status": "success", + "execution_time": 1.2, + "thread_id": "thread-1", + "message": "1 of 1 OK", + } + ], +} + +_MANIFEST_RAW: dict = { + "metadata": { + "invocation_id": "inv-manifest-001", + "dbt_version": "1.9.0", + "generated_at": "2024-01-01T00:00:00Z", + }, + "nodes": { + "model.pkg.my_model": { + "unique_id": "model.pkg.my_model", + "name": "my_model", + "resource_type": "model", + "package_name": "pkg", + "description": "A test model for artifact loading", + "columns": { + "id": { + "name": "id", + "description": "Primary key", + "data_type": "INTEGER", + }, + "name": {"name": "name", "description": "User name"}, + }, + "depends_on": {"nodes": [], "macros": []}, + } + }, + "sources": {}, + "exposures": {}, + "metrics": {}, + "groups": {}, + "macros": {}, +} + +_CATALOG_RAW: dict = { + "metadata": {"generated_at": "2024-01-01T00:00:00Z"}, + "nodes": { + "model.pkg.my_model": { + "metadata": { + "type": "table", + "database": "my_db", + "schema": "public", + "name": "my_model", + }, + "stats": {}, + "columns": { + "id": {"name": "id", "index": 1, "type": "INTEGER", "comment": "PK"}, + "name": {"name": "name", "index": 2, "type": "VARCHAR"}, + }, + } + }, + "sources": {}, +} + +_SOURCES_RAW: dict = { + "metadata": { + "invocation_id": "inv-src-001", + "dbt_version": "1.9.0", + "generated_at": "2024-01-01T00:00:00Z", + }, + "elapsed_time": 1.0, + "results": [ + { + "unique_id": "source.pkg.raw.orders", + "status": "pass", + "execution_time": 0.5, + "thread_id": "thread-1", + "max_loaded_at": "2024-01-01T00:00:00Z", + "snapshotted_at": "2024-01-01T01:00:00Z", + "criteria": { + "warn_after": {"count": 24, "period": "hour"}, + "error_after": {}, + }, + } + ], +} + + +class TestLoadArtifactIntegration: + def test_run_results_loads_invocation_and_results( + self, store: ArtifactStore + ) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.RUN_RESULTS, raw_data=_RUN_RESULTS_RAW + ) + + invocations = store.query("SELECT invocation_id, command FROM invocations") + assert invocations == [{"invocation_id": "inv-rr-001", "command": "run"}] + + results = store.query("SELECT unique_id, status FROM run_results") + assert results == [{"unique_id": "model.pkg.my_model", "status": "success"}] + + def test_manifest_loads_nodes_and_columns(self, store: ArtifactStore) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.MANIFEST, raw_data=_MANIFEST_RAW + ) + + nodes = store.query("SELECT unique_id, description FROM nodes") + assert nodes[0]["description"] == "A test model for artifact loading" + + columns = store.query( + "SELECT column_name FROM node_columns ORDER BY column_name" + ) + assert [r["column_name"] for r in columns] == ["id", "name"] + + def test_sources_loads_freshness_rows(self, store: ArtifactStore) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.SOURCES, raw_data=_SOURCES_RAW + ) + + rows = store.query( + "SELECT unique_id, status, warn_after_count FROM source_freshness" + ) + assert rows == [ + { + "unique_id": "source.pkg.raw.orders", + "status": "pass", + "warn_after_count": 24, + } + ] + + def test_catalog_after_manifest_merges_catalog_types( + self, store: ArtifactStore + ) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.MANIFEST, raw_data=_MANIFEST_RAW + ) + assert ( + store.query( + "SELECT catalog_type FROM node_columns WHERE column_name = 'id'" + )[0]["catalog_type"] + is None + ) + + store.load_artifact( + run_id=1, artifact_type=ArtifactType.CATALOG, raw_data=_CATALOG_RAW + ) + assert ( + store.query( + "SELECT catalog_type FROM node_columns WHERE column_name = 'id'" + )[0]["catalog_type"] + == "INTEGER" + ) + + def test_fts_search_works_after_load(self, store: ArtifactStore) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.MANIFEST, raw_data=_MANIFEST_RAW + ) + results = store.search(table_name="nodes", query_text="test model") + assert len(results) >= 1 + assert results[0]["unique_id"] == "model.pkg.my_model" + + def test_multiple_artifact_types_coexist(self, store: ArtifactStore) -> None: + store.load_artifact( + run_id=1, artifact_type=ArtifactType.RUN_RESULTS, raw_data=_RUN_RESULTS_RAW + ) + store.load_artifact( + run_id=1, artifact_type=ArtifactType.MANIFEST, raw_data=_MANIFEST_RAW + ) + store.load_artifact( + run_id=1, artifact_type=ArtifactType.SOURCES, raw_data=_SOURCES_RAW + ) + + loaded = { + t["table_name"] for t in store.list_tables() if t["status"] == "loaded" + } + assert { + "invocations", + "run_results", + "nodes", + "node_columns", + "source_freshness", + }.issubset(loaded) diff --git a/uv.lock b/uv.lock index 3a2a0ff85..2011bd2ff 100644 --- a/uv.lock +++ b/uv.lock @@ -7,7 +7,7 @@ resolution-markers = [ ] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer = "2026-05-06T22:00:02.961844Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -284,6 +284,7 @@ dependencies = [ { name = "dbt-protos" }, { name = "dbt-sl-sdk", extra = ["sync"] }, { name = "dbtlabs-vortex" }, + { name = "duckdb" }, { name = "fastapi" }, { name = "filelock" }, { name = "httpx" }, @@ -315,6 +316,7 @@ requires-dist = [ { name = "dbt-protos", specifier = "~=1.0.431" }, { name = "dbt-sl-sdk", extras = ["sync"], specifier = "~=0.13.2" }, { name = "dbtlabs-vortex", specifier = "~=0.2.0" }, + { name = "duckdb", specifier = ">=1.5.2" }, { name = "fastapi", specifier = "~=0.128.0" }, { name = "filelock", specifier = "~=3.20.3" }, { name = "httpx", specifier = "~=0.28.1" }, @@ -405,6 +407,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "duckdb" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/66/744b4931b799a42f8cb9bc7a6f169e7b8e51195b62b246db407fd90bf15f/duckdb-1.5.2.tar.gz", hash = "sha256:638da0d5102b6cb6f7d47f83d0600708ac1d3cb46c5e9aaabc845f9ba4d69246", size = 18017166, upload-time = "2026-04-13T11:30:09.065Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/de/ebe66bbe78125fc610f4fd415447a65349d94245950f3b3dfb31d028af02/duckdb-1.5.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e6495b00cad16888384119842797c49316a96ae1cb132bb03856d980d95afee1", size = 30064950, upload-time = "2026-04-13T11:29:11.468Z" }, + { url = "https://files.pythonhosted.org/packages/2d/8a/3e25b5d03bcf1fb99d189912f8ce92b1db4f9c8778e1b1f55745973a855a/duckdb-1.5.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d72b8856b1839d35648f38301b058f6232f4d36b463fe4dc8f4d3fdff2df1a2e", size = 15969113, upload-time = "2026-04-13T11:29:14.139Z" }, + { url = "https://files.pythonhosted.org/packages/19/bb/58001f0815002b1a93431bf907f77854085c7d049b83d521814a07b9db0b/duckdb-1.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2a1de4f4d454b8c97aec546c82003fc834d3422ce4bc6a19902f3462ef293bed", size = 14224774, upload-time = "2026-04-13T11:29:16.758Z" }, + { url = "https://files.pythonhosted.org/packages/d3/2f/a7f0de9509d1cef35608aeb382919041cdd70f58c173865c3da6a0d87979/duckdb-1.5.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce0b8141a10d37ecef729c45bc41d334854013f4389f1488bd6035c5579aaac1", size = 19313510, upload-time = "2026-04-13T11:29:19.574Z" }, + { url = "https://files.pythonhosted.org/packages/26/78/eb1e064ea8b9df3b87b167bfd7a407b2f615a4291e06cba756727adfa06c/duckdb-1.5.2-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c99ef73a277c8921bc0a1f16dee38d924484251d9cfd20951748c20fcd5ed855", size = 21429692, upload-time = "2026-04-13T11:29:22.575Z" }, + { url = "https://files.pythonhosted.org/packages/5b/12/05b0c47d14839925c5e35b79081d918ca82e3f236bb724a6f58409dd5291/duckdb-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:8d599758b4e48bf12e18c9b960cf491d219f0c4972d19a45489c05cc5ab36f83", size = 13107594, upload-time = "2026-04-13T11:29:25.43Z" }, + { url = "https://files.pythonhosted.org/packages/0b/2c/80558a82b236e044330e84a154b96aacddb343316b479f3d49be03ea11cb/duckdb-1.5.2-cp312-cp312-win_arm64.whl", hash = "sha256:fc85a5dbcbe6eccac1113c72370d1d3aacfdd49198d63950bdf7d8638a307f00", size = 13927537, upload-time = "2026-04-13T11:29:27.842Z" }, + { url = "https://files.pythonhosted.org/packages/98/f2/e3d742808f138d374be4bb516fade3d1f33749b813650810ab7885cdc363/duckdb-1.5.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4420b3f47027a7849d0e1815532007f377fa95ee5810b47ea717d35525c12f79", size = 30064879, upload-time = "2026-04-13T11:29:30.763Z" }, + { url = "https://files.pythonhosted.org/packages/72/0d/f3dc1cf97e1267ca15e4307d456f96ce583961f0703fd75e62b2ad8d64fa/duckdb-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:bb42e6ed543902e14eae647850da24103a89f0bc2587dec5601b1c1f213bd2ed", size = 15969327, upload-time = "2026-04-13T11:29:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e0/d5418def53ae4e05a63075705ff44ed5af5a1a5932627eb2b600c5df1c93/duckdb-1.5.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:98c0535cd6d901f61a5ea3c2e26a1fd28482953d794deb183daf568e3aa5dda6", size = 14225107, upload-time = "2026-04-13T11:29:35.882Z" }, + { url = "https://files.pythonhosted.org/packages/16/a7/15aaa59dbecc35e9711980fcdbf525b32a52470b32d18ef678193a146213/duckdb-1.5.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:486c862bf7f163c0110b6d85b3e5c031d224a671cca468f12ebb1d3a348f6b39", size = 19313433, upload-time = "2026-04-13T11:29:38.367Z" }, + { url = "https://files.pythonhosted.org/packages/bd/21/d903cc63a5140c822b7b62b373a87dc557e60c29b321dfb435061c5e67cf/duckdb-1.5.2-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70631c847ca918ee710ec874241b00cf9d2e5be90762cbb2a0389f17823c08f7", size = 21429837, upload-time = "2026-04-13T11:29:41.135Z" }, + { url = "https://files.pythonhosted.org/packages/e3/0a/b770d1f60c70597302130d6247f418549b7094251a02348fbaf1c7e147ae/duckdb-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:52a21823f3fbb52f0f0e5425e20b07391ad882464b955879499b5ff0b45a376b", size = 13107699, upload-time = "2026-04-13T11:29:43.905Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cf/e200fe431d700962d1a908d2ce89f53ccee1cc8db260174ae663ba09686b/duckdb-1.5.2-cp313-cp313-win_arm64.whl", hash = "sha256:411ad438bd4140f189a10e7f515781335962c5d18bd07837dc6d202e3985253d", size = 13927646, upload-time = "2026-04-13T11:29:46.598Z" }, +] + [[package]] name = "fastapi" version = "0.128.0" From a8c801ca8ef2822cb0bda384e0ea1730067bcc75 Mon Sep 17 00:00:00 2001 From: Jairus Martinez <114552516+jairus-m@users.noreply.github.com> Date: Wed, 20 May 2026 20:36:15 -0700 Subject: [PATCH 2/4] Address copilot review --- .../dbt_admin/run_artifacts/extractors.py | 2 +- src/dbt_mcp/dbt_admin/run_artifacts/store.py | 28 ++++++++++++++++--- .../dbt_admin/run_artifacts/test_store.py | 10 ++++--- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py index 3bd8189ff..f29de8835 100644 --- a/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py +++ b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py @@ -409,7 +409,7 @@ def extract_from_catalog(data: dict[str, Any]) -> dict[str, list[tuple]]: col.get("name") or col_name, col.get("index"), col.get("type"), - col.get("comment") or "", + col.get("comment"), ) ) diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/store.py b/src/dbt_mcp/dbt_admin/run_artifacts/store.py index 9aa07c24b..debef3cb8 100644 --- a/src/dbt_mcp/dbt_admin/run_artifacts/store.py +++ b/src/dbt_mcp/dbt_admin/run_artifacts/store.py @@ -5,6 +5,7 @@ """ import logging +import re import time from typing import Any @@ -19,6 +20,7 @@ from dbt_mcp.errors.artifact_search import ( ArtifactNotLoadedError, ArtifactQueryError, + ArtifactSearchError, ArtifactValidationError, ) @@ -56,8 +58,16 @@ class ArtifactStore: def __init__(self) -> None: self.conn = duckdb.connect() - self.conn.execute("INSTALL fts;") - self.conn.execute("LOAD fts;") + try: + self.conn.execute("LOAD fts;") + except duckdb.Error: + try: + self.conn.execute("INSTALL fts;") + self.conn.execute("LOAD fts;") + except duckdb.Error as e: + raise ArtifactSearchError( + f"Failed to load the DuckDB FTS extension: {e}" + ) from e self._loaded_tables: set[str] = set() self._tables_created: bool = False self._pending_index_tables: set[str] = set() @@ -277,7 +287,7 @@ def _build_indexes(self, config: TableConfig) -> None: ) for col in config.index_columns: - idx_name = f"idx_{config.table_name[:4]}_{col}" + idx_name = f"idx_{config.table_name}_{col}" self.conn.execute( f"CREATE INDEX IF NOT EXISTS {idx_name} " f'ON {config.table_name}("{col}");' @@ -327,7 +337,10 @@ def describe_table(self, table_name: str) -> list[dict[str, str]]: def query(self, sql: str) -> list[dict[str, Any]]: """Execute a read-only SQL query. Results capped at 500 rows.""" - tokens = sql.strip().upper().split() + sanitized = _strip_sql_comments(sql) + if ";" in sanitized: + raise ArtifactQueryError("Multi-statement queries are not allowed.") + tokens = sanitized.strip().upper().split() for token in tokens: if token in READONLY_BLOCKED: raise ArtifactQueryError( @@ -391,6 +404,13 @@ def _validate_table_name(self, table_name: str) -> None: ) +def _strip_sql_comments(sql: str) -> str: + """Remove SQL block (/* */) and line (--) comments.""" + sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) + sql = re.sub(r"--[^\n]*", " ", sql) + return sql + + def _serialize(val: Any) -> Any: """Ensure values are JSON-serializable.""" if val is None or isinstance(val, (str, int, float, bool)): diff --git a/tests/unit/dbt_admin/run_artifacts/test_store.py b/tests/unit/dbt_admin/run_artifacts/test_store.py index 230f8ffcd..cc40f63a3 100644 --- a/tests/unit/dbt_admin/run_artifacts/test_store.py +++ b/tests/unit/dbt_admin/run_artifacts/test_store.py @@ -133,12 +133,14 @@ def test_mutating_keywords_are_blocked(self, store: ArtifactStore) -> None: with pytest.raises(ArtifactQueryError, match="Blocked keyword"): store.query(f"{keyword} something") - def test_mutating_keywords_blocked_anywhere_in_query( - self, store: ArtifactStore - ) -> None: - with pytest.raises(ArtifactQueryError, match="Blocked keyword"): + def test_multi_statement_query_blocked(self, store: ArtifactStore) -> None: + with pytest.raises(ArtifactQueryError, match="Multi-statement"): store.query("SELECT 1; DROP TABLE nodes") + def test_comment_bypass_blocked(self, store: ArtifactStore) -> None: + with pytest.raises(ArtifactQueryError, match="Multi-statement"): + store.query("SELECT 1;/**/DROP TABLE nodes") + def test_invalid_sql_raises(self, store: ArtifactStore) -> None: with pytest.raises(ArtifactQueryError, match="Query failed"): store.query("SELECT * FROM table_that_does_not_exist_xyz") From 1110bd48136bb6ca28ea81fb1474341267f498f2 Mon Sep 17 00:00:00 2001 From: Jairus Martinez <114552516+jairus-m@users.noreply.github.com> Date: Wed, 20 May 2026 20:47:03 -0700 Subject: [PATCH 3/4] Update docstrings per copilot --- src/dbt_mcp/dbt_admin/run_artifacts/extractors.py | 6 +++++- src/dbt_mcp/errors/artifact_search.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py index f29de8835..3cd2981a8 100644 --- a/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py +++ b/src/dbt_mcp/dbt_admin/run_artifacts/extractors.py @@ -25,7 +25,11 @@ def _json(data: Any) -> str: - """Serialize ``data`` to a JSON string; returns empty string for falsy values.""" + """Serialize ``data`` to a JSON string. + + Returns empty string for ``None``, empty strings, and empty collections. + Falsy scalars (``0``, ``False``) are serialized as their JSON representation. + """ if data is None: return "" if isinstance(data, str): diff --git a/src/dbt_mcp/errors/artifact_search.py b/src/dbt_mcp/errors/artifact_search.py index 59528acb5..72cc14dd1 100644 --- a/src/dbt_mcp/errors/artifact_search.py +++ b/src/dbt_mcp/errors/artifact_search.py @@ -18,4 +18,4 @@ class ArtifactValidationError(ArtifactSearchError): class ArtifactNotLoadedError(ToolCallError): - """Raised when querying before any artifacts have been loaded (client error).""" + """Raised when querying a table that has not been loaded or does not exist (client error).""" From 7556d23db71ec010b7293f758430cd3aa0c18aea Mon Sep 17 00:00:00 2001 From: Jairus Martinez <114552516+jairus-m@users.noreply.github.com> Date: Wed, 20 May 2026 20:50:10 -0700 Subject: [PATCH 4/4] Fix SQL query handling to allow trailing semicolons (bug from copilot fix) and improve error checking --- src/dbt_mcp/dbt_admin/run_artifacts/store.py | 19 +++++++++++++------ .../dbt_admin/run_artifacts/test_store.py | 5 +++++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/dbt_mcp/dbt_admin/run_artifacts/store.py b/src/dbt_mcp/dbt_admin/run_artifacts/store.py index debef3cb8..62b768bc5 100644 --- a/src/dbt_mcp/dbt_admin/run_artifacts/store.py +++ b/src/dbt_mcp/dbt_admin/run_artifacts/store.py @@ -337,10 +337,12 @@ def describe_table(self, table_name: str) -> list[dict[str, str]]: def query(self, sql: str) -> list[dict[str, Any]]: """Execute a read-only SQL query. Results capped at 500 rows.""" - sanitized = _strip_sql_comments(sql) + sanitized = _strip_sql_comments(sql).strip() + if sanitized.endswith(";"): + sanitized = sanitized[:-1] if ";" in sanitized: raise ArtifactQueryError("Multi-statement queries are not allowed.") - tokens = sanitized.strip().upper().split() + tokens = sanitized.upper().split() for token in tokens: if token in READONLY_BLOCKED: raise ArtifactQueryError( @@ -371,14 +373,19 @@ def search( fts_table = f"fts_main_{table_name}" sql = f""" - SELECT t.*, {fts_table}.match_bm25(t.id, ?) AS fts_score + WITH scored AS ( + SELECT id, {fts_table}.match_bm25(id, ?) AS fts_score + FROM "{table_name}" + ) + SELECT t.*, s.fts_score FROM "{table_name}" t - WHERE {fts_table}.match_bm25(t.id, ?) IS NOT NULL - ORDER BY fts_score DESC + JOIN scored s ON t.id = s.id + WHERE s.fts_score IS NOT NULL + ORDER BY s.fts_score DESC LIMIT ? """ try: - result = self.conn.execute(sql, [query_text, query_text, limit]) + result = self.conn.execute(sql, [query_text, limit]) columns = [desc[0] for desc in result.description] rows = result.fetchall() return [ diff --git a/tests/unit/dbt_admin/run_artifacts/test_store.py b/tests/unit/dbt_admin/run_artifacts/test_store.py index cc40f63a3..c4ab79dbe 100644 --- a/tests/unit/dbt_admin/run_artifacts/test_store.py +++ b/tests/unit/dbt_admin/run_artifacts/test_store.py @@ -141,6 +141,11 @@ def test_comment_bypass_blocked(self, store: ArtifactStore) -> None: with pytest.raises(ArtifactQueryError, match="Multi-statement"): store.query("SELECT 1;/**/DROP TABLE nodes") + def test_trailing_semicolon_is_allowed(self, store: ArtifactStore) -> None: + store._ensure_tables_created() + result = store.query("SELECT 1 AS n;") + assert result == [{"n": 1}] + def test_invalid_sql_raises(self, store: ArtifactStore) -> None: with pytest.raises(ArtifactQueryError, match="Query failed"): store.query("SELECT * FROM table_that_does_not_exist_xyz")