Skip to content

Commit 52036de

Browse files
authored
Fix langchain integration test failures + add DBSQL/streamable MCP tests (#369)
1 parent 55b84a2 commit 52036de

File tree

6 files changed

+217
-89
lines changed

6 files changed

+217
-89
lines changed

databricks_mcp/tests/integration_tests/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,32 @@ def cached_vs_call_result(vs_mcp_client, cached_vs_tools_list):
180180
return vs_mcp_client.call_tool(tool.name, {param_name: "test"})
181181

182182

183+
# =============================================================================
184+
# DBSQL Fixtures
185+
# =============================================================================
186+
187+
188+
@pytest.fixture(scope="session")
189+
def dbsql_mcp_url(workspace_client):
190+
"""Construct MCP URL for the DBSQL server."""
191+
base_url = workspace_client.config.host
192+
return f"{base_url}/api/2.0/mcp/sql"
193+
194+
195+
@pytest.fixture(scope="session")
196+
def dbsql_mcp_client(dbsql_mcp_url, workspace_client):
197+
"""DatabricksMCPClient pointed at the DBSQL server."""
198+
return DatabricksMCPClient(dbsql_mcp_url, workspace_client)
199+
200+
201+
@pytest.fixture(scope="session")
202+
def cached_dbsql_tools_list(dbsql_mcp_client):
203+
"""Cache the DBSQL list_tools() result."""
204+
tools = dbsql_mcp_client.list_tools()
205+
assert tools, "DBSQL list_tools() returned no tools"
206+
return tools
207+
208+
183209
# =============================================================================
184210
# Genie Fixtures
185211
# =============================================================================

databricks_mcp/tests/integration_tests/test_mcp_core.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,38 @@
88
from __future__ import annotations
99

1010
import os
11+
from contextlib import asynccontextmanager
1112

1213
import pytest
1314
from conftest import _skip_if_not_found
1415
from mcp.shared.exceptions import McpError
1516
from mcp.types import CallToolResult
1617

18+
19+
@asynccontextmanager
20+
async def raw_mcp_session(url, workspace_client):
21+
"""Create a raw MCP ClientSession using streamable_http_client with Databricks OAuth."""
22+
import httpx
23+
from mcp import ClientSession
24+
from mcp.client.streamable_http import streamable_http_client
25+
26+
from databricks_mcp import DatabricksOAuthClientProvider
27+
28+
async with httpx.AsyncClient(
29+
auth=DatabricksOAuthClientProvider(workspace_client),
30+
follow_redirects=True,
31+
timeout=httpx.Timeout(120.0, read=120.0),
32+
) as http_client:
33+
async with streamable_http_client(url, http_client=http_client) as (
34+
read_stream,
35+
write_stream,
36+
_,
37+
):
38+
async with ClientSession(read_stream, write_stream) as session:
39+
await session.initialize()
40+
yield session
41+
42+
1743
from databricks_mcp import DatabricksMCPClient
1844

1945
pytestmark = pytest.mark.skipif(
@@ -126,6 +152,94 @@ def test_call_tool_returns_result_with_content(self, cached_genie_call_result):
126152
assert len(cached_genie_call_result.content) > 0
127153

128154

155+
# =============================================================================
156+
# DBSQL
157+
# =============================================================================
158+
159+
160+
@pytest.mark.integration
161+
class TestMCPClientDBSQL:
162+
"""Verify list_tools() and call_tool() against a live DBSQL MCP server."""
163+
164+
def test_list_tools_returns_expected_tools(self, cached_dbsql_tools_list):
165+
tool_names = [t.name for t in cached_dbsql_tools_list]
166+
for expected in ["execute_sql", "execute_sql_read_only", "poll_sql_result"]:
167+
assert expected in tool_names, f"Expected tool '{expected}' not found in {tool_names}"
168+
169+
def test_call_tool_execute_sql_read_only(self, dbsql_mcp_client, cached_dbsql_tools_list):
170+
"""execute_sql_read_only with SHOW CATALOGS should return results."""
171+
result = dbsql_mcp_client.call_tool("execute_sql_read_only", {"query": "SHOW CATALOGS"})
172+
assert isinstance(result, CallToolResult)
173+
assert result.content, "SHOW CATALOGS should return content"
174+
assert len(result.content) > 0
175+
176+
177+
# =============================================================================
178+
# Raw streamable_http_client
179+
# =============================================================================
180+
181+
182+
@pytest.mark.integration
183+
class TestRawStreamableHttpClient:
184+
"""Verify DatabricksOAuthClientProvider works with the raw MCP SDK streamable_http_client.
185+
186+
This tests the low-level path: httpx.AsyncClient + DatabricksOAuthClientProvider
187+
+ streamable_http_client + ClientSession, without going through DatabricksMCPClient.
188+
"""
189+
190+
@pytest.mark.asyncio
191+
async def test_uc_function_list_and_call(self, uc_function_url, workspace_client):
192+
"""list_tools + call_tool via raw streamable_http_client for UC functions."""
193+
async with raw_mcp_session(uc_function_url, workspace_client) as session:
194+
tools_response = await session.list_tools()
195+
tools = tools_response.tools
196+
assert len(tools) > 0
197+
tool_names = [t.name for t in tools]
198+
assert any("echo_message" in name for name in tool_names)
199+
200+
tool_name = next(n for n in tool_names if "echo_message" in n)
201+
result = await session.call_tool(tool_name, {"message": "raw_client_test"})
202+
assert result.content
203+
first = result.content[0]
204+
assert hasattr(first, "text")
205+
assert "raw_client_test" in str(first.text)
206+
207+
@pytest.mark.asyncio
208+
async def test_vs_list_tools(self, vs_mcp_url, workspace_client):
209+
"""list_tools via raw streamable_http_client for Vector Search."""
210+
async with raw_mcp_session(vs_mcp_url, workspace_client) as session:
211+
tools_response = await session.list_tools()
212+
assert len(tools_response.tools) > 0
213+
214+
@pytest.mark.asyncio
215+
async def test_dbsql_list_and_call(self, dbsql_mcp_url, workspace_client):
216+
"""list_tools + call_tool via raw streamable_http_client for DBSQL."""
217+
async with raw_mcp_session(dbsql_mcp_url, workspace_client) as session:
218+
tools_response = await session.list_tools()
219+
tools = tools_response.tools
220+
tool_names = [t.name for t in tools]
221+
assert "execute_sql_read_only" in tool_names
222+
223+
result = await session.call_tool("execute_sql_read_only", {"query": "SHOW CATALOGS"})
224+
assert result.content
225+
assert len(result.content) > 0
226+
227+
@pytest.mark.asyncio
228+
async def test_genie_list_and_call(self, genie_mcp_url, workspace_client):
229+
"""list_tools + call_tool via raw streamable_http_client for Genie."""
230+
async with raw_mcp_session(genie_mcp_url, workspace_client) as session:
231+
tools_response = await session.list_tools()
232+
tools = tools_response.tools
233+
assert len(tools) > 0
234+
235+
tool = tools[0]
236+
properties = tool.inputSchema.get("properties", {})
237+
param_name = next(iter(properties), "query")
238+
result = await session.call_tool(tool.name, {param_name: "How many rows are there?"})
239+
assert result.content
240+
assert len(result.content) > 0
241+
242+
129243
# =============================================================================
130244
# Error paths
131245
# =============================================================================

0 commit comments

Comments
 (0)