|
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import logging |
| 8 | +import re |
8 | 9 | import time |
9 | 10 | from typing import Any |
10 | 11 |
|
|
19 | 20 | from dbt_mcp.errors.artifact_search import ( |
20 | 21 | ArtifactNotLoadedError, |
21 | 22 | ArtifactQueryError, |
| 23 | + ArtifactSearchError, |
22 | 24 | ArtifactValidationError, |
23 | 25 | ) |
24 | 26 |
|
@@ -56,8 +58,16 @@ class ArtifactStore: |
56 | 58 |
|
57 | 59 | def __init__(self) -> None: |
58 | 60 | self.conn = duckdb.connect() |
59 | | - self.conn.execute("INSTALL fts;") |
60 | | - self.conn.execute("LOAD fts;") |
| 61 | + try: |
| 62 | + self.conn.execute("LOAD fts;") |
| 63 | + except duckdb.Error: |
| 64 | + try: |
| 65 | + self.conn.execute("INSTALL fts;") |
| 66 | + self.conn.execute("LOAD fts;") |
| 67 | + except duckdb.Error as e: |
| 68 | + raise ArtifactSearchError( |
| 69 | + f"Failed to load the DuckDB FTS extension: {e}" |
| 70 | + ) from e |
61 | 71 | self._loaded_tables: set[str] = set() |
62 | 72 | self._tables_created: bool = False |
63 | 73 | self._pending_index_tables: set[str] = set() |
@@ -277,7 +287,7 @@ def _build_indexes(self, config: TableConfig) -> None: |
277 | 287 | ) |
278 | 288 |
|
279 | 289 | for col in config.index_columns: |
280 | | - idx_name = f"idx_{config.table_name[:4]}_{col}" |
| 290 | + idx_name = f"idx_{config.table_name}_{col}" |
281 | 291 | self.conn.execute( |
282 | 292 | f"CREATE INDEX IF NOT EXISTS {idx_name} " |
283 | 293 | f'ON {config.table_name}("{col}");' |
@@ -327,7 +337,10 @@ def describe_table(self, table_name: str) -> list[dict[str, str]]: |
327 | 337 |
|
328 | 338 | def query(self, sql: str) -> list[dict[str, Any]]: |
329 | 339 | """Execute a read-only SQL query. Results capped at 500 rows.""" |
330 | | - tokens = sql.strip().upper().split() |
| 340 | + sanitized = _strip_sql_comments(sql) |
| 341 | + if ";" in sanitized: |
| 342 | + raise ArtifactQueryError("Multi-statement queries are not allowed.") |
| 343 | + tokens = sanitized.strip().upper().split() |
331 | 344 | for token in tokens: |
332 | 345 | if token in READONLY_BLOCKED: |
333 | 346 | raise ArtifactQueryError( |
@@ -391,6 +404,13 @@ def _validate_table_name(self, table_name: str) -> None: |
391 | 404 | ) |
392 | 405 |
|
393 | 406 |
|
| 407 | +def _strip_sql_comments(sql: str) -> str: |
| 408 | + """Remove SQL block (/* */) and line (--) comments.""" |
| 409 | + sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL) |
| 410 | + sql = re.sub(r"--[^\n]*", " ", sql) |
| 411 | + return sql |
| 412 | + |
| 413 | + |
394 | 414 | def _serialize(val: Any) -> Any: |
395 | 415 | """Ensure values are JSON-serializable.""" |
396 | 416 | if val is None or isinstance(val, (str, int, float, bool)): |
|
0 commit comments