-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathconftest.py
More file actions
265 lines (193 loc) · 9.09 KB
/
conftest.py
File metadata and controls
265 lines (193 loc) · 9.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""
Shared fixtures for MCP integration tests.
These tests require a live Databricks workspace with a pre-created UC function.
They are NOT run by default — set RUN_MCP_INTEGRATION_TESTS=1 to enable.
Prerequisites:
The test UC function must exist in the workspace.
Environment Variables:
======================
Required:
RUN_MCP_INTEGRATION_TESTS - Set to "1" to enable MCP integration tests
DATABRICKS_HOST - Workspace URL
DATABRICKS_CLIENT_ID - Service principal client ID
DATABRICKS_CLIENT_SECRET - Service principal client secret
"""
from __future__ import annotations

import os
from typing import Final, NoReturn

import pytest
from mcp.shared.exceptions import McpError

from databricks_mcp import DatabricksMCPClient
def _find_mcp_error(exc_group: ExceptionGroup) -> McpError | None:  # ty: ignore[unresolved-reference]
    """
    Return the first McpError contained in *exc_group*, searching nested groups.

    Members are inspected in order. A nested ExceptionGroup is searched fully
    (depth-first) before the next sibling is looked at. Returns None when no
    McpError appears at any depth.
    """
    for member in exc_group.exceptions:
        if isinstance(member, McpError):
            return member
        if not isinstance(member, ExceptionGroup):  # ty: ignore[unresolved-reference]
            continue
        nested_match = _find_mcp_error(member)
        if nested_match is not None:
            return nested_match
    return None
def _skip_if_not_found(exc_group: ExceptionGroup, context: str) -> NoReturn:  # ty: ignore[unresolved-reference]
    """
    Skip the current test if *exc_group* wraps a "not found" McpError; otherwise re-raise it.

    Meant to be called from an ``except ExceptionGroup`` handler around an MCP call.
    This function never returns normally: it either calls ``pytest.skip`` (which
    raises) or re-raises *exc_group*. Callers may therefore rely on control not
    continuing past the call.

    Args:
        exc_group: The exception group raised by the MCP client.
        context: Prefix for the skip reason shown in the pytest report.

    Raises:
        ExceptionGroup: *exc_group* itself, when it contains no McpError or the
            McpError's message does not indicate a missing resource.
    """
    mcp_error = _find_mcp_error(exc_group)
    if mcp_error is not None:
        message = str(mcp_error)
        # Accept either the literal token "NOT_FOUND" or the phrase "not found" in any case.
        if "NOT_FOUND" in message or "not found" in message.lower():
            pytest.skip(f"{context}: {mcp_error}")
    raise exc_group
# =============================================================================
# Constants
# =============================================================================
CATALOG = "integration_testing"
SCHEMA = "databricks_ai_bridge_mcp_test"
FUNCTION_NAME = "echo_message"
VS_SCHEMA = "databricks_ai_bridge_vs_test"
VS_INDEX = "delta_sync_managed"
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def workspace_client():
"""
Create a WorkspaceClient using environment variables.
The SDK auto-detects auth from env vars (e.g. DATABRICKS_HOST,
DATABRICKS_CLIENT_ID, DATABRICKS_CLIENT_SECRET for OAuth M2M).
"""
from databricks.sdk import WorkspaceClient
return WorkspaceClient()
@pytest.fixture(scope="session")
def uc_function_url(workspace_client):
"""Construct MCP URL for the single test UC function."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/functions/{CATALOG}/{SCHEMA}/{FUNCTION_NAME}"
@pytest.fixture(scope="session")
def uc_schema_url(workspace_client):
"""Construct MCP URL for the full test schema (all functions)."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/functions/{CATALOG}/{SCHEMA}"
@pytest.fixture(scope="session")
def mcp_client(uc_function_url, workspace_client):
"""DatabricksMCPClient pointed at the single test UC function."""
return DatabricksMCPClient(uc_function_url, workspace_client)
@pytest.fixture(scope="session")
def schema_mcp_client(uc_schema_url, workspace_client):
"""DatabricksMCPClient pointed at the full test schema."""
return DatabricksMCPClient(uc_schema_url, workspace_client)
@pytest.fixture(scope="session")
def cached_tools_list(mcp_client):
"""
Cache the list_tools() result for the session to minimize API calls.
Skips all dependent tests if the function doesn't exist.
"""
try:
tools = mcp_client.list_tools()
except ExceptionGroup as e: # ty: ignore[unresolved-reference]
_skip_if_not_found(e, "UC function not found in workspace")
assert tools, "list_tools() returned no tools — is the test function set up?"
return tools
@pytest.fixture(scope="session")
def cached_call_result(mcp_client, cached_tools_list):
"""
Cache a call_tool() result for the session.
Uses the first tool from cached_tools_list.
"""
tool_name = cached_tools_list[0].name
return mcp_client.call_tool(tool_name, {"message": "hello"})
# =============================================================================
# Vector Search Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def vs_mcp_url(workspace_client):
"""Construct MCP URL for a single VS index."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{VS_SCHEMA}/{VS_INDEX}"
@pytest.fixture(scope="session")
def vs_schema_mcp_url(workspace_client):
"""Construct MCP URL for all VS indexes in a schema."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/vector-search/{CATALOG}/{VS_SCHEMA}"
@pytest.fixture(scope="session")
def vs_mcp_client(vs_mcp_url, workspace_client):
"""DatabricksMCPClient pointed at a single VS index."""
return DatabricksMCPClient(vs_mcp_url, workspace_client)
@pytest.fixture(scope="session")
def vs_schema_mcp_client(vs_schema_mcp_url, workspace_client):
"""DatabricksMCPClient pointed at all VS indexes in a schema."""
return DatabricksMCPClient(vs_schema_mcp_url, workspace_client)
@pytest.fixture(scope="session")
def cached_vs_tools_list(vs_mcp_client):
"""Cache the VS list_tools() result; skip if VS MCP endpoint unavailable."""
try:
tools = vs_mcp_client.list_tools()
except ExceptionGroup as e: # ty: ignore[unresolved-reference]
_skip_if_not_found(e, "VS MCP endpoint not available in workspace")
assert tools, "VS list_tools() returned no tools — is the VS index set up?"
return tools
@pytest.fixture(scope="session")
def cached_vs_call_result(vs_mcp_client, cached_vs_tools_list):
"""Cache a VS call_tool() result for the session."""
tool = cached_vs_tools_list[0]
properties = tool.inputSchema.get("properties", {})
param_name = next(iter(properties), "query")
return vs_mcp_client.call_tool(tool.name, {param_name: "test"})
# =============================================================================
# DBSQL Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def dbsql_mcp_url(workspace_client):
"""Construct MCP URL for the DBSQL server."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/sql"
@pytest.fixture(scope="session")
def dbsql_mcp_client(dbsql_mcp_url, workspace_client):
"""DatabricksMCPClient pointed at the DBSQL server."""
return DatabricksMCPClient(dbsql_mcp_url, workspace_client)
@pytest.fixture(scope="session")
def cached_dbsql_tools_list(dbsql_mcp_client):
"""Cache the DBSQL list_tools() result."""
tools = dbsql_mcp_client.list_tools()
assert tools, "DBSQL list_tools() returned no tools"
return tools
# =============================================================================
# Genie Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def genie_space_id():
"""Get the Genie Space ID from the GENIE_SPACE_ID environment variable."""
space_id = os.environ.get("GENIE_SPACE_ID")
if not space_id:
pytest.skip("GENIE_SPACE_ID environment variable not set")
return space_id
@pytest.fixture(scope="session")
def genie_mcp_url(workspace_client, genie_space_id):
"""Construct MCP URL for a Genie space."""
base_url = workspace_client.config.host
return f"{base_url}/api/2.0/mcp/genie/{genie_space_id}"
@pytest.fixture(scope="session")
def genie_mcp_client(genie_mcp_url, workspace_client):
"""DatabricksMCPClient pointed at a Genie space."""
return DatabricksMCPClient(genie_mcp_url, workspace_client)
@pytest.fixture(scope="session")
def cached_genie_tools_list(genie_mcp_client):
"""Cache the Genie list_tools() result; skip if Genie MCP endpoint unavailable."""
try:
tools = genie_mcp_client.list_tools()
except ExceptionGroup as e: # ty: ignore[unresolved-reference]
_skip_if_not_found(e, "Genie MCP endpoint not available in workspace")
assert tools, "Genie list_tools() returned no tools — is the Genie space set up?"
return tools
@pytest.fixture(scope="session")
def cached_genie_call_result(genie_mcp_client, cached_genie_tools_list):
"""Cache a Genie call_tool() result for the session."""
tool = cached_genie_tools_list[0]
# Extract the query parameter name from the tool's inputSchema
# rather than hardcoding it, since the server defines the schema.
properties = tool.inputSchema.get("properties", {})
param_name = next(iter(properties), "query")
return genie_mcp_client.call_tool(tool.name, {param_name: "How many rows are there?"})
# =============================================================================
# Markers
# =============================================================================
def pytest_configure(config):
    """Pytest hook: declare the custom ``integration`` marker so pytest recognises it."""
    marker_spec = "integration: mark test as integration test"
    config.addinivalue_line("markers", marker_spec)