49 changes: 49 additions & 0 deletions airbyte/caches/base.py
@@ -10,6 +10,7 @@
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy import text

from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -145,6 +146,54 @@ def processor(self) -> SqlProcessorBase:
"""Return the SQL processor instance."""
return self._read_processor

def run_sql_query(
self,
sql_query: str,
*,
max_records: int | None = None,
) -> list[dict[str, Any]]:
"""Run a SQL query against the cache and return results as a list of dictionaries.

This method is designed for single read-only statements such as SELECT, SHOW, or DESCRIBE.
For DDL or DML statements, or multi-statement scripts, use the processor directly.

Args:
sql_query: The SQL query to execute
max_records: Maximum number of records to return. If None, returns all records.

Returns:
List of dictionaries representing the query results
"""
# Execute the SQL within a connection context to ensure the connection stays open
# while we fetch the results
sql_text = text(sql_query) if isinstance(sql_query, str) else sql_query

with self.processor.get_sql_connection() as conn:
try:
result = conn.execute(sql_text)
except (
sqlalchemy_exc.ProgrammingError,
sqlalchemy_exc.SQLAlchemyError,
) as ex:
msg = f"Error when executing SQL:\n{sql_query}\n{type(ex).__name__}: {ex!s}"
raise RuntimeError(msg) from ex

# Convert the result to a list of dictionaries while connection is still open
if result.returns_rows:
# Get column names
columns = list(result.keys())

# Fetch rows efficiently based on limit
if max_records is not None:
rows = result.fetchmany(max_records)
else:
rows = result.fetchall()

return [dict(zip(columns, row, strict=True)) for row in rows]

# For non-SELECT queries (INSERT, UPDATE, DELETE, etc.)
return []

def get_record_processor(
self,
source_name: str,
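A minimal usage sketch for the new `run_sql_query` method, assuming a default cache that already contains a synced stream (the `users` table name is illustrative, not part of this change):

from airbyte.caches.util import get_default_cache

cache = get_default_cache()
# Rows are materialized while the connection is open, so the returned
# list of dicts is safe to use after the connection context exits.
rows = cache.run_sql_query(
    "SELECT * FROM users",  # 'users' is a hypothetical synced stream table
    max_records=10,
)
for row in rows:
    print(row)
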
131 changes: 128 additions & 3 deletions airbyte/mcp/_local_ops.py
@@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Annotated, Any

from fastmcp import FastMCP
from pydantic import BaseModel, Field

from airbyte import get_source
from airbyte.caches.util import get_default_cache
@@ -19,6 +19,7 @@


if TYPE_CHECKING:
from airbyte.caches.duckdb import DuckDBCache
from airbyte.sources.base import Source


@@ -270,21 +271,145 @@ def sync_source_to_cache(
cache=cache,
streams=streams,
)
del cache  # Release the reference so the cache can be finalized and closed

summary: str = f"Sync completed for '{source_connector_name}'!\n\n"
summary += "Data written to default DuckDB cache\n"
return summary


class CachedDatasetInfo(BaseModel):
"""Class to hold information about a cached dataset."""

stream_name: str
"""The name of the stream in the cache."""
table_name: str
"""The table name in the cache, including any table prefix."""
schema_name: str | None = None
"""The schema containing the table, if applicable."""


def list_cached_streams() -> list[CachedDatasetInfo]:
"""List all streams available in the default DuckDB cache."""
cache: DuckDBCache = get_default_cache()
result = [
CachedDatasetInfo(
stream_name=stream_name,
table_name=(cache.table_prefix or "") + stream_name,
schema_name=cache.schema_name,
)
for stream_name in cache.streams
]
del cache  # Release the reference so the cache can be finalized and closed
return result
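
A short sketch of how a caller might consume `list_cached_streams`, assuming at least one prior sync into the default cache:

for info in list_cached_streams():
    # Build a qualified table name from the CachedDatasetInfo fields above.
    qualified = (
        f"{info.schema_name}.{info.table_name}" if info.schema_name else info.table_name
    )
    print(f"{info.stream_name} -> {qualified}")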


def describe_default_cache() -> dict[str, Any]:
"""Describe the currently configured default cache."""
cache = get_default_cache()
return {
"cache_type": type(cache).__name__,
"cache_dir": str(cache.cache_dir),
"cache_db_path": str(Path(cache.db_path).absolute()),
"cached_streams": list(cache.streams.keys()),
}


def _is_safe_sql(sql_query: str) -> bool:
"""Check if a SQL query is safe to execute.

For security reasons, we only allow read-only operations like SELECT, DESCRIBE, and SHOW.
Multi-statement queries (containing semicolons) are also disallowed for security.

Note: SQLAlchemy will also validate downstream, but this is a first-pass check.

Args:
sql_query: The SQL query to check

Returns:
True if the query is safe to execute, False otherwise
"""
# Remove leading/trailing whitespace and convert to uppercase for checking
normalized_query = sql_query.strip().upper()

# Disallow multi-statement queries (containing semicolons)
# Note: We check the original query to catch semicolons anywhere, including in comments
if ";" in sql_query:
return False

# List of allowed SQL statement prefixes (read-only operations)
allowed_prefixes = (
"SELECT",
"DESCRIBE",
"DESC", # Short form of DESCRIBE
"SHOW",
"EXPLAIN", # Also safe - shows query execution plan
)

# Check if the query starts with any allowed prefix
return any(normalized_query.startswith(prefix) for prefix in allowed_prefixes)
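
Illustrative checks for the guard above; a sketch of the expected behavior rather than part of the change:

assert _is_safe_sql("SELECT * FROM users")  # allowed read-only prefix
assert _is_safe_sql("  explain SELECT 1")  # prefix match is case-insensitive
assert not _is_safe_sql("DROP TABLE users")  # DDL is rejected
assert not _is_safe_sql("SELECT 1; DROP TABLE users")  # semicolons are rejected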


def run_sql_query(
sql_query: Annotated[
str,
Field(description="The SQL query to execute."),
],
max_records: Annotated[
int,
Field(description="Maximum number of records to return."),
] = 1000,
) -> list[dict[str, Any]]:
"""Run a SQL query against the default cache.

The dialect of SQL should match the dialect of the default cache.
Use `describe_default_cache` to see the cache type.

For DuckDB-type caches:
- Use `SHOW TABLES` to list all tables.
- Use `DESCRIBE <table_name>` to get the schema of a specific table.

For security reasons, only read-only operations are allowed: SELECT, DESCRIBE, SHOW, EXPLAIN.
"""
# Check if the query is safe to execute
if not _is_safe_sql(sql_query):
return [
{
"ERROR": "Unsafe SQL query detected. Only read-only operations are allowed: "
"SELECT, DESCRIBE, SHOW, EXPLAIN",
"SQL_QUERY": sql_query,
}
]

cache: DuckDBCache = get_default_cache()
try:
return cache.run_sql_query(
sql_query,
max_records=max_records,
)
except Exception as ex:
tb_str = traceback.format_exc()
return [
{
"ERROR": f"Error running SQL query: {ex!r}, {ex!s}",
"TRACEBACK": tb_str,
"SQL_QUERY": sql_query,
}
]
finally:
del cache  # Release the reference so the cache can be finalized and closed
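
A sketch of how callers can distinguish results from failures, based on the error-as-data shape this tool returns; the query is illustrative:

results = run_sql_query("SHOW TABLES", max_records=50)
if results and "ERROR" in results[0]:
    # Rejected or failed queries come back as data rather than raising.
    print(results[0]["ERROR"])
else:
    for row in results:
        print(row)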


def register_local_ops_tools(app: FastMCP) -> None:
"""Register tools with the FastMCP app."""
app.tool(list_connector_config_secrets)
for tool in (
describe_default_cache,
get_source_stream_json_schema,
list_cached_streams,
list_source_streams,
read_source_stream_records,
run_sql_query,
sync_source_to_cache,
validate_connector_config,
):
# Register each tool with the FastMCP app.
app.tool(
1 change: 1 addition & 0 deletions tests/integration_tests/test_duckdb_cache.py
@@ -19,6 +19,7 @@
from airbyte.caches.duckdb import DuckDBCache
from airbyte.caches.util import new_local_cache


# Product count is always the same, regardless of faker scale.
NUM_PRODUCTS = 100

Expand Down