Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/dremioai/api/dremio/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ArcticSource(BaseModel):
class Query(BaseModel):
sql: str = Field(..., alias="sql")
context: Optional[List[str]] = None
engine_name: Optional[str] = Field(default=None, alias="engineName")
references: Optional[Dict[str, ArcticSource]] = None


Expand Down Expand Up @@ -233,11 +234,18 @@ async def run_query(
) -> Union[JobResultsWrapper, pd.DataFrame]:
client = AsyncHttpClient()
if not isinstance(query, Query):
query = Query(sql=query)
engine_name = (
settings.instance().dremio.wlm.engine_name
if settings.instance().dremio.wlm is not None
else None
)
query = Query(sql=query, engineName=engine_name)

project_id = settings.instance().dremio.project_id
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
qs: QuerySubmission = await client.post(
f"{endpoint}/sql", body=query.model_dump(), deser=QuerySubmission
f"{endpoint}/sql",
body=query.model_dump(by_alias=True, exclude_none=True),
deser=QuerySubmission,
)
return await get_results(project_id, qs, use_df=use_df, client=client)
6 changes: 6 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def has_expired(self) -> bool:
return self.expiry is not None and self.expiry < datetime.now()


class Wlm(BaseModel):
engine_name: Optional[str] = None
model_config = ConfigDict(validate_assignment=True)


class Dremio(BaseModel):
uri: Annotated[
Union[str, HttpUrl, DremioCloudUri], AfterValidator(_resolve_dremio_uri)
Expand All @@ -146,6 +151,7 @@ class Dremio(BaseModel):
oauth2: Optional[OAuth2] = None
allow_dml: Optional[bool] = False
auth_issuer_uri_override: Optional[str] = None
wlm: Optional[Wlm] = None
model_config = ConfigDict(validate_assignment=True)

@field_serializer("raw_pat")
Expand Down
30 changes: 16 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,27 @@ class StreamableMcpServerFixture(NamedTuple):

@contextlib.asynccontextmanager
async def http_streamable_mcp_server(
logging_server: LoggingServerFixture, logging_level: str, project_id: str = None
logging_server: LoggingServerFixture,
logging_level: str,
project_id: str = None,
wlm_engine: str = None,
) -> AsyncGenerator[StreamableMcpServerFixture]:
old = settings.instance()
sf = None
try:
settings.configure(force=True)
settings._settings.set(
settings.Settings.model_validate(
{
"dremio": {
"uri": logging_server.url,
"project_id": uuid.uuid4(),
"pat": "test-pat",
"enable_search": True,
},
"tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name},
}
)
)
config = {
"dremio": {
"uri": logging_server.url,
"project_id": uuid.uuid4(),
"pat": "test-pat",
"enable_search": True,
},
"tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name},
}
if wlm_engine:
config["dremio"]["wlm"] = {"engine_name": wlm_engine}
settings._settings.set(settings.Settings.model_validate(config))
settings.write_settings()
port = random.randrange(9000, 12000)
set_level(logging_level.upper())
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/test_e2e_pat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ async def test_tool_pat(mock_config_dir, logging_server, logging_level, project_
assert result.structuredContent["result"]["result"][0]["test_column"] == 1
from rich import print as pp

pp(logging_server.logs())
for le in logging_server.logs():
assert (
le.headers.get("authorization") == "Bearer my-token"
Expand Down
42 changes: 42 additions & 0 deletions tests/e2e/test_mcp_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import pytest
from conftest import http_streamable_client_server, http_streamable_mcp_server
from mcp.types import CallToolResult
from rich import print as pp

from dremioai.tools.tools import get_tools
from dremioai.config import settings
Expand All @@ -35,6 +37,7 @@ async def test_basic(mock_config_dir, logging_server, logging_level):
t.__name__ for t in get_tools(For=settings.instance().tools.server_mode)
}


@pytest.mark.asyncio
async def test_healthz(mock_config_dir, logging_server, logging_level):
async with http_streamable_mcp_server(logging_server, logging_level) as sf:
Expand All @@ -45,3 +48,42 @@ async def test_healthz(mock_config_dir, logging_server, logging_level):
assert (
r.status_code == 200
), f"/healthz failed with {r.text}, {r.status_code}"


@pytest.mark.asyncio
@pytest.mark.parametrize(
"engine_name",
[
pytest.param(None, id="no_engine_name"),
pytest.param("test-engine"),
pytest.param("test-engine-2"),
],
)
async def test_wlm_engine_name(
mock_config_dir, logging_server, logging_level, engine_name
):
async with http_streamable_mcp_server(
logging_server, logging_level, wlm_engine=engine_name
) as sf:
async with http_streamable_client_server(
sf.mcp_server, token="my-token"
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
)
assert (
result is not None
and result.structuredContent is not None
and result.structuredContent["result"]["result"][0]["test_column"] == 1
), f"Error running tool {result}"

for le in logging_server.logs():
if le.path.endswith("/sql") and le.method == "POST":
if engine_name is None:
assert (
le.json.get("engineName") is None
), f"{le.json} has engineName"
else:
assert (
le.json.get("engineName") == engine_name
), f"{le.json} does not have the right engineName"
6 changes: 6 additions & 0 deletions tests/mocks/http_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class LogEntry(BaseModel):
query_params: Dict[str, Any]
headers: Dict[str, Any]
response_status: Optional[int] = None
json: Optional[Dict[str, Any]] = None
model_config = ConfigDict(validate_assignment=True)


Expand All @@ -218,13 +219,18 @@ def __init__(self, app, log_file: Union[str, io.TextIOWrapper, Path]):

async def dispatch(self, request: Request, call_next):
# Capture request details
try:
body = await request.json()
except Exception:
body = None
log_entry = LogEntry.model_validate(
{
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"headers": dict(request.headers),
"json": body,
}
)

Expand Down