Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ dremio:
project_id: <string> # Optional: Project ID for Dremio Cloud
enable_search: <bool> # Optional: Enable semantic search
allow_dml: <bool> # Optional: Allow MCP Server to create views in Dremio
wlm: # Optional: WLM settings, for running MCP server in stdio mode
engine_name: <string> # Optional: Direct all SQL to this engine if set
```

URI can be specified as:
Expand Down Expand Up @@ -157,9 +159,9 @@ tools:
Settings can be configured using environment variables with nested delimiter '\_':

```bash
DREMIO_URI="https://api.dremio.cloud"
DREMIO_PAT="your-pat-here"
TOOLS_SERVER_MODE="FOR_SELF"
DREMIOAI_DREMIO__URI="https://api.dremio.cloud"
DREMIOAI_DREMIO__PAT="your-pat-here"
DREMIOAI_TOOLS__SERVER_MODE="FOR_SELF"
```

### Programmatic Configuration
Expand Down
18 changes: 10 additions & 8 deletions src/dremioai/api/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def run_catalog(
# _qg = "Query / Job ID "
@sql_app.command("run")
def run_sql(
uri: Annotated[str, Option(envvar="DREMIO_URI", show_envvar=True, default=...)],
project_id: Annotated[
str, Option(envvar="DREMIO_PROJECT_ID", show_envvar=True, default=...)
],
pat: Annotated[str, Option(envvar="DREMIO_PAT", show_envvar=True, default=...)],
query: Annotated[
Optional[str],
Option(
Expand All @@ -88,17 +83,24 @@ def run_sql(
use_df: Annotated[
Optional[bool], Option(help="Convert results to pandas dataframe")
] = False,
engine_name: Annotated[
Optional[str], Option(help="The engine name to run the query on")
] = None,
):
if query is None and job_id is None:
raise BadParameter("Either query or job_id must be provided")

if query is not None:
query = Path(query[1:]).read_text().strip() if query.startswith("@") else query
query = f"/* dremioai: submitter=cli */\n{query}"
result = asyncio.run(sql.run_query(uri, pat, project_id, query, as_df=use_df))
query = f"/* dremioai: submitter=cli {'engine=' + engine_name if engine_name is not None else ''} */\n{query}"
if engine_name is not None:
query = sql.Query(sql=query, engineName=engine_name)
result = asyncio.run(sql.run_query(query, use_df=use_df))
else:
result = asyncio.run(
sql.get_results(project_id, job_id, as_df=use_df, uri=uri, pat=pat)
sql.get_results(
settings.instance().dremio.project_id, job_id, use_df=use_df
)
)

pp(result if use_df else [r for jr in result for r in jr.rows])
Expand Down
26 changes: 17 additions & 9 deletions src/dremioai/api/cli/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from typer import Option, Argument, Typer
import asyncio
from rich import print as pp
from dremioai.api.cli import engines
from dremioai.api.dremio import engines
from dremioai.config import settings

app = Typer(
no_args_is_help=True,
Expand All @@ -30,22 +31,23 @@

@app.command("list")
def elist(
uri: Annotated[str, Option(envvar="DREMIO_URI", default=...)],
project_id: Annotated[str, Option(envvar="DREMIO_PROJECT_ID", default=...)],
pat: Annotated[str, Option(envvar="DREMIO_PAT", default=...)],
use_df: Annotated[
Optional[bool], Option(help="Convert results to pandas dataframe")
] = False,
):
result = asyncio.run(engines.get_engines(uri, pat, project_id, use_df=use_df))
result = asyncio.run(
engines.get_engines(
settings.instance().dremio.uri,
settings.instance().dremio.pat,
settings.instance().dremio.project_id,
use_df=use_df,
)
)
pp(result)


@app.command("get")
def eget(
uri: Annotated[str, Option(envvar="DREMIO_URI", default=...)],
project_id: Annotated[str, Option(envvar="DREMIO_PROJECT_ID", default=...)],
pat: Annotated[str, Option(envvar="DREMIO_PAT", default=...)],
engine_ids: Annotated[
List[str], Argument(help="Engine IDs to retrieve details for")
],
Expand All @@ -54,6 +56,12 @@ def eget(
] = False,
):
result = asyncio.run(
engines.get_engines(uri, pat, project_id, engine_ids=engine_ids, use_df=use_df)
engines.get_engines(
settings.instance().dremio.uri,
settings.instance().dremio.pat,
settings.instance().dremio.project_id,
engine_ids=engine_ids,
use_df=use_df,
)
)
pp(result)
12 changes: 10 additions & 2 deletions src/dremioai/api/dremio/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ArcticSource(BaseModel):
class Query(BaseModel):
sql: str = Field(..., alias="sql")
context: Optional[List[str]] = None
engine_name: Optional[str] = Field(default=None, alias="engineName")
references: Optional[Dict[str, ArcticSource]] = None


Expand Down Expand Up @@ -233,11 +234,18 @@ async def run_query(
) -> Union[JobResultsWrapper, pd.DataFrame]:
client = AsyncHttpClient()
if not isinstance(query, Query):
query = Query(sql=query)
engine_name = (
settings.instance().dremio.wlm.engine_name
if settings.instance().dremio.wlm is not None
else None
)
query = Query(sql=query, engineName=engine_name)

project_id = settings.instance().dremio.project_id
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
qs: QuerySubmission = await client.post(
f"{endpoint}/sql", body=query.model_dump(), deser=QuerySubmission
f"{endpoint}/sql",
body=query.model_dump(by_alias=True, exclude_none=True),
deser=QuerySubmission,
)
return await get_results(project_id, qs, use_df=use_df, client=client)
6 changes: 6 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def has_expired(self) -> bool:
return self.expiry is not None and self.expiry < datetime.now()


class Wlm(BaseModel):
engine_name: Optional[str] = None
model_config = ConfigDict(validate_assignment=True)


class Dremio(BaseModel):
uri: Annotated[
Union[str, HttpUrl, DremioCloudUri], AfterValidator(_resolve_dremio_uri)
Expand All @@ -146,6 +151,7 @@ class Dremio(BaseModel):
oauth2: Optional[OAuth2] = None
allow_dml: Optional[bool] = False
auth_issuer_uri_override: Optional[str] = None
wlm: Optional[Wlm] = None
model_config = ConfigDict(validate_assignment=True)

@field_serializer("raw_pat")
Expand Down
30 changes: 16 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,25 +184,27 @@ class StreamableMcpServerFixture(NamedTuple):

@contextlib.asynccontextmanager
async def http_streamable_mcp_server(
logging_server: LoggingServerFixture, logging_level: str, project_id: str = None
logging_server: LoggingServerFixture,
logging_level: str,
project_id: str = None,
wlm_engine: str = None,
) -> AsyncGenerator[StreamableMcpServerFixture]:
old = settings.instance()
sf = None
try:
settings.configure(force=True)
settings._settings.set(
settings.Settings.model_validate(
{
"dremio": {
"uri": logging_server.url,
"project_id": uuid.uuid4(),
"pat": "test-pat",
"enable_search": True,
},
"tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name},
}
)
)
config = {
"dremio": {
"uri": logging_server.url,
"project_id": uuid.uuid4(),
"pat": "test-pat",
"enable_search": True,
},
"tools": {"server_mode": ToolType.FOR_DATA_PATTERNS.name},
}
if wlm_engine:
config["dremio"]["wlm"] = {"engine_name": wlm_engine}
settings._settings.set(settings.Settings.model_validate(config))
settings.write_settings()
port = random.randrange(9000, 12000)
set_level(logging_level.upper())
Expand Down
1 change: 0 additions & 1 deletion tests/e2e/test_e2e_pat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ async def test_tool_pat(mock_config_dir, logging_server, logging_level, project_
assert result.structuredContent["result"]["result"][0]["test_column"] == 1
from rich import print as pp

pp(logging_server.logs())
for le in logging_server.logs():
assert (
le.headers.get("authorization") == "Bearer my-token"
Expand Down
42 changes: 42 additions & 0 deletions tests/e2e/test_mcp_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import pytest
from conftest import http_streamable_client_server, http_streamable_mcp_server
from mcp.types import CallToolResult
from rich import print as pp

from dremioai.tools.tools import get_tools
from dremioai.config import settings
Expand All @@ -35,6 +37,7 @@ async def test_basic(mock_config_dir, logging_server, logging_level):
t.__name__ for t in get_tools(For=settings.instance().tools.server_mode)
}


@pytest.mark.asyncio
async def test_healthz(mock_config_dir, logging_server, logging_level):
async with http_streamable_mcp_server(logging_server, logging_level) as sf:
Expand All @@ -45,3 +48,42 @@ async def test_healthz(mock_config_dir, logging_server, logging_level):
assert (
r.status_code == 200
), f"/healthz failed with {r.text}, {r.status_code}"


@pytest.mark.asyncio
@pytest.mark.parametrize(
"engine_name",
[
pytest.param(None, id="no_engine_name"),
pytest.param("test-engine"),
pytest.param("test-engine-2"),
],
)
async def test_wlm_engine_name(
mock_config_dir, logging_server, logging_level, engine_name
):
async with http_streamable_mcp_server(
logging_server, logging_level, wlm_engine=engine_name
) as sf:
async with http_streamable_client_server(
sf.mcp_server, token="my-token"
) as session:
result: CallToolResult = await session.call_tool(
"RunSqlQuery", {"s": "SELECT 1"}
)
assert (
result is not None
and result.structuredContent is not None
and result.structuredContent["result"]["result"][0]["test_column"] == 1
), f"Error running tool {result}"

for le in logging_server.logs():
if le.path.endswith("/sql") and le.method == "POST":
if engine_name is None:
assert (
le.json.get("engineName") is None
), f"{le.json} has engineName"
else:
assert (
le.json.get("engineName") == engine_name
), f"{le.json} does not have the right engineName"
6 changes: 6 additions & 0 deletions tests/mocks/http_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ class LogEntry(BaseModel):
query_params: Dict[str, Any]
headers: Dict[str, Any]
response_status: Optional[int] = None
json: Optional[Dict[str, Any]] = None
model_config = ConfigDict(validate_assignment=True)


Expand All @@ -218,13 +219,18 @@ def __init__(self, app, log_file: Union[str, io.TextIOWrapper, Path]):

async def dispatch(self, request: Request, call_next):
# Capture request details
try:
body = await request.json()
except Exception:
body = None
log_entry = LogEntry.model_validate(
{
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"query_params": dict(request.query_params),
"headers": dict(request.headers),
"json": body,
}
)

Expand Down