diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py
index 12d4a3ad..b0c3fd90 100644
--- a/airbyte/caches/base.py
+++ b/airbyte/caches/base.py
@@ -10,6 +10,7 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 from pydantic import Field, PrivateAttr
+from sqlalchemy import exc as sqlalchemy_exc
 from sqlalchemy import text
 
 from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -145,6 +146,54 @@ def processor(self) -> SqlProcessorBase:
         """Return the SQL processor instance."""
         return self._read_processor
 
+    def run_sql_query(
+        self,
+        sql_query: str,
+        *,
+        max_records: int | None = None,
+    ) -> list[dict[str, Any]]:
+        """Run a SQL query against the cache and return results as a list of dictionaries.
+
+        This method is designed for single read-only statements like SELECT, SHOW, or DESCRIBE.
+        For DDL statements or multi-statement queries, use the processor directly.
+
+        Args:
+            sql_query: The SQL query to execute.
+            max_records: Maximum number of records to return. If None, returns all records.
+
+        Returns:
+            List of dictionaries representing the query results.
+        """
+        # Execute the SQL within a connection context to ensure the connection stays open
+        # while we fetch the results.
+        sql_text = text(sql_query) if isinstance(sql_query, str) else sql_query
+
+        with self.processor.get_sql_connection() as conn:
+            try:
+                result = conn.execute(sql_text)
+            except (
+                sqlalchemy_exc.ProgrammingError,
+                sqlalchemy_exc.SQLAlchemyError,
+            ) as ex:
+                msg = f"Error when executing SQL:\n{sql_query}\n{type(ex).__name__}: {ex!s}"
+                raise RuntimeError(msg) from ex
+
+            # Convert the result to a list of dictionaries while the connection is still open.
+            if result.returns_rows:
+                # Get column names
+                columns = list(result.keys()) if result.keys() else []
+
+                # Fetch rows efficiently based on the limit
+                if max_records is not None:
+                    rows = result.fetchmany(max_records)
+                else:
+                    rows = result.fetchall()
+
+                return [dict(zip(columns, row, strict=True)) for row in rows]
+
+        # For non-SELECT queries (INSERT, UPDATE, DELETE, etc.)
+        return []
+
     def get_record_processor(
         self,
         source_name: str,
diff --git a/airbyte/mcp/_local_ops.py b/airbyte/mcp/_local_ops.py
index f2cb06ae..2c1c8bf9 100644
--- a/airbyte/mcp/_local_ops.py
+++ b/airbyte/mcp/_local_ops.py
@@ -8,7 +8,7 @@
 from typing import TYPE_CHECKING, Annotated, Any
 
 from fastmcp import FastMCP
-from pydantic import Field
+from pydantic import BaseModel, Field
 
 from airbyte import get_source
 from airbyte.caches.util import get_default_cache
@@ -19,6 +19,7 @@
 
 
 if TYPE_CHECKING:
+    from airbyte.caches.duckdb import DuckDBCache
     from airbyte.sources.base import Source
 
 
@@ -270,21 +271,145 @@ def sync_source_to_cache(
         cache=cache,
         streams=streams,
     )
+    del cache  # Ensure the cache is closed properly
 
     summary: str = f"Sync completed for '{source_connector_name}'!\n\n"
     summary += "Data written to default DuckDB cache\n"
     return summary
 
 
+class CachedDatasetInfo(BaseModel):
+    """Class to hold information about a cached dataset."""
+
+    stream_name: str
+    """The name of the stream in the cache."""
+    table_name: str
+    schema_name: str | None = None
+
+
+def list_cached_streams() -> list[CachedDatasetInfo]:
+    """List all streams available in the default DuckDB cache."""
+    cache: DuckDBCache = get_default_cache()
+    result = [
+        CachedDatasetInfo(
+            stream_name=stream_name,
+            table_name=(cache.table_prefix or "") + stream_name,
+            schema_name=cache.schema_name,
+        )
+        for stream_name in cache.streams
+    ]
+    del cache  # Ensure the cache is closed properly
+    return result
+
+
+def describe_default_cache() -> dict[str, Any]:
+    """Describe the currently configured default cache."""
+    cache = get_default_cache()
+    return {
+        "cache_type": type(cache).__name__,
+        "cache_dir": str(cache.cache_dir),
+        "cache_db_path": str(Path(cache.db_path).absolute()),
+        "cached_streams": list(cache.streams.keys()),
+    }
+
+
+def _is_safe_sql(sql_query: str) -> bool:
+    """Check if a SQL query is safe to execute.
+
+    For security reasons, we only allow read-only operations like SELECT, DESCRIBE, and SHOW.
+    Multi-statement queries (containing semicolons) are also disallowed for security.
+
+    Note: SQLAlchemy will also validate downstream, but this is a first-pass check.
+
+    Args:
+        sql_query: The SQL query to check.
+
+    Returns:
+        True if the query is safe to execute, False otherwise.
+    """
+    # Remove leading/trailing whitespace and convert to uppercase for checking.
+    normalized_query = sql_query.strip().upper()
+
+    # Disallow multi-statement queries (containing semicolons).
+    # Note: We check the original query to catch semicolons anywhere, including in comments.
+    if ";" in sql_query:
+        return False
+
+    # List of allowed SQL statement prefixes (read-only operations).
+    allowed_prefixes = (
+        "SELECT",
+        "DESCRIBE",
+        "DESC",  # Short form of DESCRIBE
+        "SHOW",
+        "EXPLAIN",  # Also safe: shows the query execution plan
+    )
+
+    # Check if the query starts with any allowed prefix.
+    return any(normalized_query.startswith(prefix) for prefix in allowed_prefixes)
+
+
+def run_sql_query(
+    sql_query: Annotated[
+        str,
+        Field(description="The SQL query to execute."),
+    ],
+    max_records: Annotated[
+        int,
+        Field(description="Maximum number of records to return."),
+    ] = 1000,
+) -> list[dict[str, Any]]:
+    """Run a SQL query against the default cache.
+
+    The dialect of SQL should match the dialect of the default cache.
+    Use `describe_default_cache` to see the cache type.
+
+    For DuckDB-type caches:
+    - Use `SHOW TABLES` to list all tables.
+    - Use `DESCRIBE <table_name>` to get the schema of a specific table.
+
+    For security reasons, only read-only operations are allowed: SELECT, DESCRIBE, SHOW, EXPLAIN.
+    """
+    # Check if the query is safe to execute.
+    if not _is_safe_sql(sql_query):
+        return [
+            {
+                "ERROR": "Unsafe SQL query detected. Only read-only operations are allowed: "
+                "SELECT, DESCRIBE, SHOW, EXPLAIN",
+                "SQL_QUERY": sql_query,
+            }
+        ]
+
+    cache: DuckDBCache = get_default_cache()
+    try:
+        return cache.run_sql_query(
+            sql_query,
+            max_records=max_records,
+        )
+    except Exception as ex:
+        tb_str = traceback.format_exc()
+        return [
+            {
+                "ERROR": f"Error running SQL query: {ex!r}, {ex!s}",
+                "TRACEBACK": tb_str,
+                "SQL_QUERY": sql_query,
+            }
+        ]
+    finally:
+        del cache  # Ensure the cache is closed properly
+
+
 def register_local_ops_tools(app: FastMCP) -> None:
     """Register tools with the FastMCP app."""
     app.tool(list_connector_config_secrets)
     for tool in (
-        validate_connector_config,
-        list_source_streams,
+        describe_default_cache,
         get_source_stream_json_schema,
+        list_cached_streams,
+        list_source_streams,
         read_source_stream_records,
+        run_sql_query,
         sync_source_to_cache,
+        validate_connector_config,
     ):
         # Register each tool with the FastMCP app.
         app.tool(
diff --git a/tests/integration_tests/test_duckdb_cache.py b/tests/integration_tests/test_duckdb_cache.py
index 6a44dff9..3a29ed1c 100644
--- a/tests/integration_tests/test_duckdb_cache.py
+++ b/tests/integration_tests/test_duckdb_cache.py
@@ -19,6 +19,7 @@
 
 from airbyte.caches.duckdb import DuckDBCache
 from airbyte.caches.util import new_local_cache
 
+
 # Product count is always the same, regardless of faker scale.
 NUM_PRODUCTS = 100
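
A minimal usage sketch of the new `Cache.run_sql_query` API added in the `base.py` hunk above, assuming the default DuckDB cache already holds a synced stream (the `users` table name is hypothetical; `SHOW TABLES` will list what is actually cached):

```python
from airbyte.caches.util import get_default_cache

# Sketch only: query the default DuckDB cache via the new run_sql_query method.
# "users" is a hypothetical table name; adjust to a stream you have synced.
cache = get_default_cache()
rows = cache.run_sql_query("SELECT * FROM users", max_records=10)
for row in rows:
    print(row)  # each row is a plain dict keyed by column name
```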
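And a short sketch of how the read-only guard added in `_local_ops.py` is expected to behave; calling the private `_is_safe_sql` helper directly is for illustration only:

```python
from airbyte.mcp._local_ops import _is_safe_sql

# Read-only statements pass the first-pass prefix check...
assert _is_safe_sql("SELECT * FROM users LIMIT 5")
assert _is_safe_sql("SHOW TABLES")
assert _is_safe_sql("EXPLAIN SELECT 1")

# ...while writes and multi-statement queries are rejected.
assert not _is_safe_sql("DROP TABLE users")
assert not _is_safe_sql("SELECT 1; SELECT 2")
```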