|
8 | 8 | from __future__ import annotations |
9 | 9 |
|
10 | 10 | import os |
| 11 | +from contextlib import asynccontextmanager |
11 | 12 |
|
12 | 13 | import pytest |
13 | 14 | from conftest import _skip_if_not_found |
14 | 15 | from mcp.shared.exceptions import McpError |
15 | 16 | from mcp.types import CallToolResult |
16 | 17 |
|
| 18 | + |
| 19 | +@asynccontextmanager |
| 20 | +async def raw_mcp_session(url, workspace_client): |
| 21 | + """Create a raw MCP ClientSession using streamable_http_client with Databricks OAuth.""" |
| 22 | + import httpx |
| 23 | + from mcp import ClientSession |
| 24 | + from mcp.client.streamable_http import streamable_http_client |
| 25 | + |
| 26 | + from databricks_mcp import DatabricksOAuthClientProvider |
| 27 | + |
| 28 | + async with httpx.AsyncClient( |
| 29 | + auth=DatabricksOAuthClientProvider(workspace_client), |
| 30 | + follow_redirects=True, |
| 31 | + timeout=httpx.Timeout(120.0, read=120.0), |
| 32 | + ) as http_client: |
| 33 | + async with streamable_http_client(url, http_client=http_client) as ( |
| 34 | + read_stream, |
| 35 | + write_stream, |
| 36 | + _, |
| 37 | + ): |
| 38 | + async with ClientSession(read_stream, write_stream) as session: |
| 39 | + await session.initialize() |
| 40 | + yield session |
| 41 | + |
| 42 | + |
17 | 43 | from databricks_mcp import DatabricksMCPClient |
18 | 44 |
|
19 | 45 | pytestmark = pytest.mark.skipif( |
@@ -126,6 +152,94 @@ def test_call_tool_returns_result_with_content(self, cached_genie_call_result): |
126 | 152 | assert len(cached_genie_call_result.content) > 0 |
127 | 153 |
|
128 | 154 |
|
| 155 | +# ============================================================================= |
| 156 | +# DBSQL |
| 157 | +# ============================================================================= |
| 158 | + |
| 159 | + |
| 160 | +@pytest.mark.integration |
| 161 | +class TestMCPClientDBSQL: |
| 162 | + """Verify list_tools() and call_tool() against a live DBSQL MCP server.""" |
| 163 | + |
| 164 | + def test_list_tools_returns_expected_tools(self, cached_dbsql_tools_list): |
| 165 | + tool_names = [t.name for t in cached_dbsql_tools_list] |
| 166 | + for expected in ["execute_sql", "execute_sql_read_only", "poll_sql_result"]: |
| 167 | + assert expected in tool_names, f"Expected tool '{expected}' not found in {tool_names}" |
| 168 | + |
| 169 | + def test_call_tool_execute_sql_read_only(self, dbsql_mcp_client, cached_dbsql_tools_list): |
| 170 | + """execute_sql_read_only with SHOW CATALOGS should return results.""" |
| 171 | + result = dbsql_mcp_client.call_tool("execute_sql_read_only", {"query": "SHOW CATALOGS"}) |
| 172 | + assert isinstance(result, CallToolResult) |
| 173 | + assert result.content, "SHOW CATALOGS should return content" |
| 174 | + assert len(result.content) > 0 |
| 175 | + |
| 176 | + |
| 177 | +# ============================================================================= |
| 178 | +# Raw streamable_http_client |
| 179 | +# ============================================================================= |
| 180 | + |
| 181 | + |
| 182 | +@pytest.mark.integration |
| 183 | +class TestRawStreamableHttpClient: |
| 184 | + """Verify DatabricksOAuthClientProvider works with the raw MCP SDK streamable_http_client. |
| 185 | +
|
| 186 | + This tests the low-level path: httpx.AsyncClient + DatabricksOAuthClientProvider |
| 187 | + + streamable_http_client + ClientSession, without going through DatabricksMCPClient. |
| 188 | + """ |
| 189 | + |
| 190 | + @pytest.mark.asyncio |
| 191 | + async def test_uc_function_list_and_call(self, uc_function_url, workspace_client): |
| 192 | + """list_tools + call_tool via raw streamable_http_client for UC functions.""" |
| 193 | + async with raw_mcp_session(uc_function_url, workspace_client) as session: |
| 194 | + tools_response = await session.list_tools() |
| 195 | + tools = tools_response.tools |
| 196 | + assert len(tools) > 0 |
| 197 | + tool_names = [t.name for t in tools] |
| 198 | + assert any("echo_message" in name for name in tool_names) |
| 199 | + |
| 200 | + tool_name = next(n for n in tool_names if "echo_message" in n) |
| 201 | + result = await session.call_tool(tool_name, {"message": "raw_client_test"}) |
| 202 | + assert result.content |
| 203 | + first = result.content[0] |
| 204 | + assert hasattr(first, "text") |
| 205 | + assert "raw_client_test" in str(first.text) |
| 206 | + |
| 207 | + @pytest.mark.asyncio |
| 208 | + async def test_vs_list_tools(self, vs_mcp_url, workspace_client): |
| 209 | + """list_tools via raw streamable_http_client for Vector Search.""" |
| 210 | + async with raw_mcp_session(vs_mcp_url, workspace_client) as session: |
| 211 | + tools_response = await session.list_tools() |
| 212 | + assert len(tools_response.tools) > 0 |
| 213 | + |
| 214 | + @pytest.mark.asyncio |
| 215 | + async def test_dbsql_list_and_call(self, dbsql_mcp_url, workspace_client): |
| 216 | + """list_tools + call_tool via raw streamable_http_client for DBSQL.""" |
| 217 | + async with raw_mcp_session(dbsql_mcp_url, workspace_client) as session: |
| 218 | + tools_response = await session.list_tools() |
| 219 | + tools = tools_response.tools |
| 220 | + tool_names = [t.name for t in tools] |
| 221 | + assert "execute_sql_read_only" in tool_names |
| 222 | + |
| 223 | + result = await session.call_tool("execute_sql_read_only", {"query": "SHOW CATALOGS"}) |
| 224 | + assert result.content |
| 225 | + assert len(result.content) > 0 |
| 226 | + |
| 227 | + @pytest.mark.asyncio |
| 228 | + async def test_genie_list_and_call(self, genie_mcp_url, workspace_client): |
| 229 | + """list_tools + call_tool via raw streamable_http_client for Genie.""" |
| 230 | + async with raw_mcp_session(genie_mcp_url, workspace_client) as session: |
| 231 | + tools_response = await session.list_tools() |
| 232 | + tools = tools_response.tools |
| 233 | + assert len(tools) > 0 |
| 234 | + |
| 235 | + tool = tools[0] |
| 236 | + properties = tool.inputSchema.get("properties", {}) |
| 237 | + param_name = next(iter(properties), "query") |
| 238 | + result = await session.call_tool(tool.name, {param_name: "How many rows are there?"}) |
| 239 | + assert result.content |
| 240 | + assert len(result.content) > 0 |
| 241 | + |
| 242 | + |
129 | 243 | # ============================================================================= |
130 | 244 | # Error paths |
131 | 245 | # ============================================================================= |
|
0 commit comments