Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ jobs:

- name: Run tests
run: |
uv run pytest tests
make test
30 changes: 30 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Makefile for dremio-mcp testing

.PHONY: help test test-unit test-e2e-separate clean

# Default target
help:
@echo "Available targets:"
@echo " test - Run unit tests first, then each e2e test separately"
@echo " test-unit - Run only unit tests (excluding e2e)"
@echo " test-e2e - Run each e2e test file separately"

# Main test target - runs unit tests first, then e2e separately
.PHONY: test test-unit test-e2e
test: test-unit test-e2e
@echo "All tests completed!"

# Run only unit tests (excluding e2e)
test-unit:
@echo "Running unit tests ..."
@uv run pytest tests --ignore=tests/e2e -v -x

# Run each e2e test file separately
test-e2e:
@echo "Running e2e tests ..."
@for file in tests/e2e/test_*.py; do \
if [ -f "$$file" ]; then \
echo "Running $$file..."; \
uv run pytest "$$file" -v || exit 1; \
fi; \
done
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"requests>=2.32.3",
"rich>=13.9.4",
"sqlglot>=26.23.0",
"starlette>=0.46.1",
"structlog>=25.1.0",
"typer>=0.15.2",
"uvicorn>=0.34.0",
Expand Down
22 changes: 0 additions & 22 deletions tests/config/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,6 @@
from dremioai.config.tools import ToolType


@pytest.fixture
def temp_config_dir():
"""Create a temporary directory for config files"""
with TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def mock_config_dir(temp_config_dir):
"""Mock the home directory to use our temporary directory"""
with patch.object(Path, "home", return_value=temp_config_dir):
# Also patch XDG_CONFIG_HOME environment variable
old_env = os.environ.get("XDG_CONFIG_HOME")
os.environ["XDG_CONFIG_HOME"] = str(temp_config_dir)
yield temp_config_dir
# Restore original environment
if old_env:
os.environ["XDG_CONFIG_HOME"] = old_env
else:
os.environ.pop("XDG_CONFIG_HOME", None)


def test_configure_with_no_file_works(mock_config_dir):
s = settings.instance()
assert settings.instance() is not None
Expand Down
151 changes: 151 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,29 @@
Global pytest fixtures for dremio-mcp tests.
"""
import os
import random
from typing import AsyncGenerator, NamedTuple

import pytest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import patch
from collections import OrderedDict

from dremioai.config import settings
from dremioai.config.tools import ToolType
from dremioai.servers.mcp import Transports, init

from mocks.http_mock import (
create_pytest_logging_server_fixture,
start_server,
ServerFixture,
LoggingServerFixture,
)
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
import contextlib
from dremioai.log import set_level


@pytest.fixture
Expand Down Expand Up @@ -69,3 +85,138 @@ def mock_settings_instance():
yield settings.instance()
finally:
settings._settings.set(old_settings)


@pytest.fixture
def temp_config_dir():
"""Create a temporary directory for config files"""
with TemporaryDirectory() as temp_dir:
yield Path(temp_dir)


@pytest.fixture
def mock_config_dir(temp_config_dir):
"""Mock the home directory to use our temporary directory"""
with patch.object(Path, "home", return_value=temp_config_dir):
# Also patch XDG_CONFIG_HOME environment variable
old_env = os.environ.get("XDG_CONFIG_HOME")
os.environ["XDG_CONFIG_HOME"] = str(temp_config_dir)
yield temp_config_dir
# Restore original environment
if old_env:
os.environ["XDG_CONFIG_HOME"] = old_env
else:
os.environ.pop("XDG_CONFIG_HOME", None)


def _create_logging_server(log_level="warning"):
# Mock data for HTTP endpoints that tools will call
mock_data = OrderedDict(
[
(r"/sql", "sql/job_submission.json"), # SQL query submission
(r"/job/test-job-12345$", "sql/job_status.json"), # Job status check
(r"/job/test-job-12345/results$", "sql/job_results.json"), # Job results
(r"/search", "search/search_results.json"), # Search endpoints
(r"/catalog/.*/wiki", "catalog/wiki.json"), # Wiki endpoints
(r"/catalog/.*/tags", "catalog/tags.json"), # Tags endpoints
(r"/catalog/.*/graph", "catalog/lineage.json"), # Lineage endpoints
(r"/catalog(/by-path)?", "catalog/table_schema.json"), # Schema endpoints
]
)

return create_pytest_logging_server_fixture(
mock_data=mock_data, port=8000, log_level=log_level
)


@pytest.fixture(scope="module")
def logging_level(request):
return "info"
if request.config.get_verbosity() > 2:
return "debug"
if request.config.get_verbosity() > 1:
return "info"
return "warning"


@pytest.fixture(scope="module")
def logging_server(logging_level):
server = _create_logging_server(logging_level)
try:
yield server
finally:
try:
server.close()
except:
from rich import traceback

traceback.print_exc()


class StreamableMcpServerFixture(NamedTuple):
mcp_server: ServerFixture
logging_server: LoggingServerFixture


@pytest.fixture
def http_streamable_mcp_server(logging_server, mock_config_dir, logging_level):
old = settings.instance()
sf = None
try:
settings.configure(force=True)
settings._settings.set(
settings.Settings.model_validate(
{
"dremio": {
"uri": logging_server.url,
"project_id": "test-project-id",
"pat": "test-pat",
"enable_search": True,
},
"tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name},
}
)
)
settings.write_settings()
port = random.randrange(9000, 12000)
set_level(logging_level.upper())
mcp_server = init(
transport=Transports.streamable_http,
port=port,
mode=settings.instance().tools.server_mode,
)

def should_exit(v: bool):
mcp_server.should_exit = v

server, stop_event = start_server(
mcp_server.run_streamable_http_async(), should_exit
)
sf = ServerFixture(f"http://127.0.0.1:{port}/mcp/", stop_event, server)
yield StreamableMcpServerFixture(sf, logging_server)
finally:
if sf is not None:
try:
sf.close()
except:
from rich import traceback

traceback.print_exc()
print(f"{sf} closed")
settings._settings.set(old)


@contextlib.asynccontextmanager
async def http_streamable_client_server(
sf: ServerFixture, token=None
) -> AsyncGenerator[ClientSession]:
headers = {"Authorization": f"Bearer {token}"} if token is not None else None
async with streamablehttp_client(url=sf.url, headers=headers) as (
read_stream,
write_stream,
gid,
):
print(f"Client connected to {sf.url}")
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
19 changes: 19 additions & 0 deletions tests/e2e/test_e2e_pat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
from mcp.types import CallToolResult
from conftest import http_streamable_client_server


@pytest.mark.asyncio
async def test_tool_pat(http_streamable_mcp_server):
async with http_streamable_client_server(
http_streamable_mcp_server.mcp_server,
token="my-token",
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
)
assert result.structuredContent["result"]["result"][0]["test_column"] == 1
for le in http_streamable_mcp_server.logging_server.logs():
assert (
le.headers.get("authorization") == "Bearer my-token"
), f"{le} does not have the right auth header"
17 changes: 17 additions & 0 deletions tests/e2e/test_mcp_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
from conftest import http_streamable_client_server

from dremioai.tools.tools import get_tools
from dremioai.config import settings


@pytest.mark.asyncio
async def test_basic(http_streamable_mcp_server):
async with http_streamable_client_server(
http_streamable_mcp_server.mcp_server
) as session:
lts = await session.list_tools()
tr = {t.name for t in lts.tools}
assert tr == {
t.__name__ for t in get_tools(For=settings.instance().tools.server_mode)
}
Loading