Skip to content
34 changes: 17 additions & 17 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,24 @@ jobs:
-e DREMIOAI_DREMIO__OAUTH_SUPPORTED=false \
dremio-mcp-server:${{ github.sha }} \
dremio-mcp-server tools list
- name: Start container
run: |
docker run \
--detach \
-e DREMIOAI_TOOLS__SERVER_MODE=FOR_DATA_PATTERNS \
-e DREMIOAI_DREMIO__URI=https://fake \
-e DREMIOAI_DREMIO__OAUTH_SUPPORTED=false \
-p 6789:6789 \
--name mcp \
--rm \
--network host \
dremio-mcp-server:${{ github.sha }} \
dremio-mcp-server run --enable-streaming-http --port 6789 --no-log-to-file \
--enable-json-logging
# - name: Start container
# run: |
# docker run \
# --detach \
# -e DREMIOAI_TOOLS__SERVER_MODE=FOR_DATA_PATTERNS \
# -e DREMIOAI_DREMIO__URI=https://fake \
# -e DREMIOAI_DREMIO__OAUTH_SUPPORTED=false \
# -p 6789:6789 \
# --name mcp \
# --rm \
# --network host \
# dremio-mcp-server:${{ github.sha }} \
# dremio-mcp-server run --enable-streaming-http --port 6789 --no-log-to-file \
# --enable-json-logging

- name: Test container
run: |
uv run python tests/stremable_http_cli.py --url http://127.0.0.1:6789/mcp --token fake
# - name: Test container
# run: |
# uv run python tests/stremable_http_cli.py --url http://127.0.0.1:6789/mcp --token fake
- id: tag
name: Create tag
run: |
Expand Down
1 change: 0 additions & 1 deletion src/dremioai/api/cli/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def login(
{"client_id": client_id}
)
oauth = get_oauth2_tokens()
oauth.update_settings()
pp(
settings.instance().dremio.model_dump(
exclude_none=True, mode="json", by_alias=True, exclude_unset=True
Expand Down
50 changes: 34 additions & 16 deletions src/dremioai/api/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from aiohttp import ClientSession, web
from threading import Thread
from secrets import token_urlsafe
Expand All @@ -7,6 +9,7 @@
from dremioai.config import settings
from datetime import datetime, timedelta
from importlib.resources import files
from dremioai.log import logger
import webbrowser
import asyncio

Expand Down Expand Up @@ -98,18 +101,30 @@ def get_pkce_pair(length=96):


class OAuth2:
def __init__(self):
if settings.instance().dremio.oauth2.client_id is None:
raise RuntimeError("oauth_client_id is not set in the config file")

base = urlparse(settings.instance().dremio.uri)
if base.netloc.startswith("api."):
base = base._replace(netloc=f"login.{base.netloc[4:]}")
url = base.geturl()

self.client_id = settings.instance().dremio.oauth2.client_id
self.authorize_url = f"{url}/oauth/authorize"
self.access_token_url = f"{url}/oauth/token"
def __init__(
self,
client_id: Optional[str] = None,
auth_url: Optional[str] = None,
token_url: Optional[str] = None,
):
if not auth_url:
if settings.instance().dremio.oauth2.client_id is None:
raise RuntimeError("oauth_client_id is not set in the config file")

base = urlparse(settings.instance().dremio.uri)
if base.netloc.startswith("api."):
base = base._replace(netloc=f"login.{base.netloc[4:]}")
url = base.geturl()

self.client_id = settings.instance().dremio.oauth2.client_id
self.authorize_url = f"{url}/oauth/authorize"
self.access_token_url = f"{url}/oauth/token"
else:
self.authorize_url = auth_url
self.access_token_url = token_url
self.client_id = client_id
if self.client_id is None:
raise RuntimeError("client_id is required")
self.redirect_port = 8976
self.scope = "dremio.all offline_access"
self.code_verifier, self.code_challenge = get_pkce_pair()
Expand All @@ -121,7 +136,6 @@ def __init__(self):
"code_challenge": self.code_challenge,
"code_challenge_method": "S256",
}
print(self.init_params)
self.oauth_redirect = OAuth2Redirect(
self.client_id,
self.code_verifier,
Expand All @@ -131,10 +145,14 @@ def __init__(self):
)


def get_oauth2_tokens() -> OAuth2Redirect:
def get_oauth2_tokens(
client_id: Optional[str] = None,
auth_url: Optional[str] = None,
token_url: Optional[str] = None,
) -> OAuth2Redirect:
# client_id = "311658a1-19ae-4851-b6a6-911c794312e2",
# client_id = "a3743893-d849-4c8a-893b-533dd457aac4"
oauth = OAuth2()
oauth = OAuth2(client_id, auth_url, token_url)
server_thread = Thread(
target=run_server,
daemon=True,
Expand All @@ -146,7 +164,7 @@ def get_oauth2_tokens() -> OAuth2Redirect:
print(f"Opening browser to {url}")
webbrowser.open(url)
server_thread.join()
print(
logger("oauth2").debug(
f"Access token: {oauth.oauth_redirect.access_token}\n"
f"Refresh token: {oauth.oauth_redirect.refresh_token}\n"
f"Expiry: {oauth.oauth_redirect.expiry}\n"
Expand Down
5 changes: 5 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from os import environ
from importlib.util import find_spec
from datetime import datetime
from dremioai import log

ProjectId = Union[UUID, Literal["DREMIO_DYNAMIC"]]

Expand Down Expand Up @@ -144,6 +145,7 @@ class Dremio(BaseModel):
)
oauth2: Optional[OAuth2] = None
allow_dml: Optional[bool] = False
auth_issuer_uri_override: Optional[str] = None
model_config = ConfigDict(validate_assignment=True)

@field_serializer("raw_pat")
Expand Down Expand Up @@ -186,11 +188,14 @@ def is_cloud(self) -> bool:

@property
def auth_issuer_uri(self) -> Optional[str]:
if self.auth_issuer_uri_override is not None:
return self.auth_issuer_uri_override
if self.is_cloud:
uri = urlparse(self.uri)
if uri.netloc.startswith("api."):
uri = uri._replace(netloc=f"login.{uri.netloc[4:]}")
return uri.geturl()
log.logger("settings").error("Oauth not supported for non-cloud instances")
return None

@property
Expand Down
20 changes: 17 additions & 3 deletions tests/config/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,14 @@ def test_env_file(mock_config_dir):


@pytest.mark.parametrize(
"uri,project_id,issuer,error",
"uri,project_id,issuer,error,iss_override",
[
pytest.param(
uri,
project_id,
iss,
project_id is None,
iss_override,
id=f"{label} with {plabel}",
)
for uri, iss, label in (
Expand All @@ -158,10 +159,23 @@ def test_env_file(mock_config_dir):
("DREMIO_DYNAMIC", "dynamic-project-id"),
(str(uuid.uuid4()), "project-id"),
)
for iss_override in (None, "https://my-override")
],
)
def test_auth_urls(uri: str, project_id: str | None, issuer: str, error: bool):
d = settings.Dremio.model_validate({"uri": uri, "project_id": project_id})
def test_auth_urls(
uri: str, project_id: str | None, issuer: str, error: bool, iss_override: str | None
):
d = settings.Dremio.model_validate(
{
"uri": uri,
"project_id": project_id,
"auth_issuer_uri_override": (
iss_override if iss_override and not error else None
),
}
)
if iss_override:
issuer = iss_override
auth = (f"{issuer}/oauth/authorize", f"{issuer}/oauth/token") if not error else None
issuer = issuer if not error else None
assert d.auth_issuer_uri == issuer
Expand Down
150 changes: 137 additions & 13 deletions tests/stremable_http_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,33 @@
"""

import asyncio
import functools
import random
import time
from contextlib import asynccontextmanager
from typing import Annotated, Optional, AsyncGenerator, Callable, Any
from urllib.parse import urlparse

from mcp import ClientSession, types
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthMetadata
from typer import Typer, Option
from rich import print as pp
from typing import Annotated, Optional
import requests

from dremioai import log
from dremioai.api.oauth2 import get_oauth2_tokens, OAuth2Redirect


def async_command(func: Callable) -> Callable:
"""Decorator to run async functions in Typer commands."""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(func(*args, **kwargs))

return wrapper


app = Typer(
no_args_is_help=True,
Expand All @@ -20,27 +41,130 @@
context_settings=dict(help_option_names=["-h", "--help"]),
)

auth = Typer(
no_args_is_help=True,
name="auth",
help="Auth related sub commands",
context_settings=dict(help_option_names=["-h", "--help"]),
)


# Example usage and demonstration
async def cli(url, token):
async with streamablehttp_client(
url=url, headers={"Authorization": f"Bearer {token}"}
) as (read_stream, write_stream, gid):
def get_oauth_config(url: str) -> OAuthMetadata:
u = urlparse(url)
u = u._replace(path="/.well-known/oauth-authorization-server")
log.logger("auth").info(f"Checking auth for {u.geturl()}")
r = requests.get(u.geturl())
if r.status_code != 200:
pp(f"Cannot get oauth config: {u.geturl()}")
r.raise_for_status()
return OAuthMetadata.model_validate(r.json())


@auth.command("list")
def list_auth(
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
):
pp(get_oauth_config(url))


@auth.command("check")
def check_auth(
client_id: Annotated[str, Option(help="The client id to check")],
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
) -> OAuth2Redirect:
md = get_oauth_config(url)
oauth = get_oauth2_tokens(
client_id, str(md.authorization_endpoint), str(md.token_endpoint)
)
pp(oauth.access_token)
return oauth


cli = Typer(
no_args_is_help=True,
name="cli",
help="MCP client session related sub commands",
context_settings=dict(help_option_names=["-h", "--help"]),
)


@asynccontextmanager
async def mcp_client_session(
url: str, token: Optional[str] = None
) -> AsyncGenerator[ClientSession, None]:
headers = {"Authorization": f"Bearer {token}"} if token is not None else None
async with streamablehttp_client(url=url, headers=headers) as (
read_stream,
write_stream,
gid,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
for t in await session.list_tools():
pp(t)
yield session


@app.command()
def main(
token: Annotated[Optional[str], Option(help="The authorization token to use")],
@cli.command("list-tools")
@async_command
async def list_tools(
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
token: Annotated[
Optional[str], Option(help="The authorization token to use")
] = None,
):
asyncio.run(cli(url, token))
async with mcp_client_session(url, token) as session:
tools = await session.list_tools()
for tool in tools:
pp(tool)


@app.command("test", help="Run a quick smoketest for a deployed MCP server")
@async_command
async def run_test(
client_id: Annotated[str, Option(help="The OAuth client id")],
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
):
pp("Checking auth..", end=" ")
a = check_auth(client_id, url)
pp("[green]OK[/green]\nConnecting to server..")
async with mcp_client_session(url, a.access_token) as session:
tools = await session.list_tools()
pp([i.name for i in tools.tools])

pp("[green]OK[/green]\nCalling tool..")
n = int(time.time())
query = f"SELECT {n} as n"
result = await session.call_tool("RunSqlQuery", {"s": query})
result = result.structuredContent["result"]["result"]
pp(result)

query2 = f"""
SELECT query
FROM sys.project.jobs_recent
WHERE query_type = 'REST' and submitted_ts > CURRENT_TIMESTAMP() - INTERVAL '1' minute
and query like '/* dremioai: submitter=RunS%' and query like '%SELECT {n} as n';
"""
result = await session.call_tool("RunSqlQuery", {"s": query2})
result = result.structuredContent["result"]["result"]
pp(result)

if len(result) != 1:
pp("[red]FAIL[/red]")
pp("[green]OK[/green]")


# Add the CLI subcommand to the main app
app.add_typer(cli)
app.add_typer(auth)


if __name__ == "__main__":
log.configure(enable_json_logging=False, to_file=False)
app()