|
18 | 18 | Global pytest fixtures for dremio-mcp tests. |
19 | 19 | """ |
20 | 20 | import os |
| 21 | +import random |
| 22 | +from typing import AsyncGenerator, NamedTuple |
| 23 | + |
21 | 24 | import pytest |
22 | 25 | from pathlib import Path |
23 | 26 | from tempfile import TemporaryDirectory |
24 | 27 | from unittest.mock import patch |
| 28 | +from collections import OrderedDict |
25 | 29 |
|
26 | 30 | from dremioai.config import settings |
27 | 31 | from dremioai.config.tools import ToolType |
| 32 | +from dremioai.servers.mcp import Transports, init |
| 33 | + |
| 34 | +from mocks.http_mock import ( |
| 35 | + create_pytest_logging_server_fixture, |
| 36 | + start_server, |
| 37 | + ServerFixture, |
| 38 | + LoggingServerFixture, |
| 39 | +) |
| 40 | +from mcp import ClientSession |
| 41 | +from mcp.client.streamable_http import streamablehttp_client |
| 42 | +import contextlib |
| 43 | +from dremioai.log import set_level |
28 | 44 |
|
29 | 45 |
|
30 | 46 | @pytest.fixture |
@@ -69,3 +85,138 @@ def mock_settings_instance(): |
69 | 85 | yield settings.instance() |
70 | 86 | finally: |
71 | 87 | settings._settings.set(old_settings) |
| 88 | + |
| 89 | + |
| 90 | +@pytest.fixture |
| 91 | +def temp_config_dir(): |
| 92 | + """Create a temporary directory for config files""" |
| 93 | + with TemporaryDirectory() as temp_dir: |
| 94 | + yield Path(temp_dir) |
| 95 | + |
| 96 | + |
| 97 | +@pytest.fixture |
| 98 | +def mock_config_dir(temp_config_dir): |
| 99 | + """Mock the home directory to use our temporary directory""" |
| 100 | + with patch.object(Path, "home", return_value=temp_config_dir): |
| 101 | + # Also patch XDG_CONFIG_HOME environment variable |
| 102 | + old_env = os.environ.get("XDG_CONFIG_HOME") |
| 103 | + os.environ["XDG_CONFIG_HOME"] = str(temp_config_dir) |
| 104 | + yield temp_config_dir |
| 105 | + # Restore original environment |
| 106 | + if old_env: |
| 107 | + os.environ["XDG_CONFIG_HOME"] = old_env |
| 108 | + else: |
| 109 | + os.environ.pop("XDG_CONFIG_HOME", None) |
| 110 | + |
| 111 | + |
| 112 | +def _create_logging_server(log_level="warning"): |
| 113 | + # Mock data for HTTP endpoints that tools will call |
| 114 | + mock_data = OrderedDict( |
| 115 | + [ |
| 116 | + (r"/sql", "sql/job_submission.json"), # SQL query submission |
| 117 | + (r"/job/test-job-12345$", "sql/job_status.json"), # Job status check |
| 118 | + (r"/job/test-job-12345/results$", "sql/job_results.json"), # Job results |
| 119 | + (r"/search", "search/search_results.json"), # Search endpoints |
| 120 | + (r"/catalog/.*/wiki", "catalog/wiki.json"), # Wiki endpoints |
| 121 | + (r"/catalog/.*/tags", "catalog/tags.json"), # Tags endpoints |
| 122 | + (r"/catalog/.*/graph", "catalog/lineage.json"), # Lineage endpoints |
| 123 | + (r"/catalog(/by-path)?", "catalog/table_schema.json"), # Schema endpoints |
| 124 | + ] |
| 125 | + ) |
| 126 | + |
| 127 | + return create_pytest_logging_server_fixture( |
| 128 | + mock_data=mock_data, port=8000, log_level=log_level |
| 129 | + ) |
| 130 | + |
| 131 | + |
| 132 | +@pytest.fixture(scope="module") |
| 133 | +def logging_level(request): |
| 134 | + return "info" |
| 135 | + if request.config.get_verbosity() > 2: |
| 136 | + return "debug" |
| 137 | + if request.config.get_verbosity() > 1: |
| 138 | + return "info" |
| 139 | + return "warning" |
| 140 | + |
| 141 | + |
| 142 | +@pytest.fixture(scope="module") |
| 143 | +def logging_server(logging_level): |
| 144 | + server = _create_logging_server(logging_level) |
| 145 | + try: |
| 146 | + yield server |
| 147 | + finally: |
| 148 | + try: |
| 149 | + server.close() |
| 150 | + except: |
| 151 | + from rich import traceback |
| 152 | + |
| 153 | + traceback.print_exc() |
| 154 | + |
| 155 | + |
| 156 | +class StreamableMcpServerFixture(NamedTuple): |
| 157 | + mcp_server: ServerFixture |
| 158 | + logging_server: LoggingServerFixture |
| 159 | + |
| 160 | + |
| 161 | +@pytest.fixture |
| 162 | +def http_streamable_mcp_server(logging_server, mock_config_dir, logging_level): |
| 163 | + old = settings.instance() |
| 164 | + sf = None |
| 165 | + try: |
| 166 | + settings.configure(force=True) |
| 167 | + settings._settings.set( |
| 168 | + settings.Settings.model_validate( |
| 169 | + { |
| 170 | + "dremio": { |
| 171 | + "uri": logging_server.url, |
| 172 | + "project_id": "test-project-id", |
| 173 | + "pat": "test-pat", |
| 174 | + "enable_search": True, |
| 175 | + }, |
| 176 | + "tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name}, |
| 177 | + } |
| 178 | + ) |
| 179 | + ) |
| 180 | + settings.write_settings() |
| 181 | + port = random.randrange(9000, 12000) |
| 182 | + set_level(logging_level.upper()) |
| 183 | + mcp_server = init( |
| 184 | + transport=Transports.streamable_http, |
| 185 | + port=port, |
| 186 | + mode=settings.instance().tools.server_mode, |
| 187 | + ) |
| 188 | + |
| 189 | + def should_exit(v: bool): |
| 190 | + mcp_server.should_exit = v |
| 191 | + |
| 192 | + server, stop_event = start_server( |
| 193 | + mcp_server.run_streamable_http_async(), should_exit |
| 194 | + ) |
| 195 | + sf = ServerFixture(f"http://127.0.0.1:{port}/mcp/", stop_event, server) |
| 196 | + yield StreamableMcpServerFixture(sf, logging_server) |
| 197 | + finally: |
| 198 | + if sf is not None: |
| 199 | + try: |
| 200 | + sf.close() |
| 201 | + except: |
| 202 | + from rich import traceback |
| 203 | + |
| 204 | + traceback.print_exc() |
| 205 | + print(f"{sf} closed") |
| 206 | + settings._settings.set(old) |
| 207 | + |
| 208 | + |
| 209 | +@contextlib.asynccontextmanager |
| 210 | +async def http_streamable_client_server( |
| 211 | + sf: ServerFixture, token=None |
| 212 | +) -> AsyncGenerator[ClientSession]: |
| 213 | + headers = {"Authorization": f"Bearer {token}"} if token is not None else None |
| 214 | + async with streamablehttp_client(url=sf.url, headers=headers) as ( |
| 215 | + read_stream, |
| 216 | + write_stream, |
| 217 | + gid, |
| 218 | + ): |
| 219 | + print(f"Client connected to {sf.url}") |
| 220 | + async with ClientSession(read_stream, write_stream) as session: |
| 221 | + await session.initialize() |
| 222 | + yield session |
0 commit comments