Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/dremioai/api/cli/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def login(
{"client_id": client_id}
)
oauth = get_oauth2_tokens()
oauth.update_settings()
pp(
settings.instance().dremio.model_dump(
exclude_none=True, mode="json", by_alias=True, exclude_unset=True
Expand Down
50 changes: 34 additions & 16 deletions src/dremioai/api/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from aiohttp import ClientSession, web
from threading import Thread
from secrets import token_urlsafe
Expand All @@ -7,6 +9,7 @@
from dremioai.config import settings
from datetime import datetime, timedelta
from importlib.resources import files
from dremioai.log import logger
import webbrowser
import asyncio

Expand Down Expand Up @@ -98,18 +101,30 @@ def get_pkce_pair(length=96):


class OAuth2:
def __init__(self):
if settings.instance().dremio.oauth2.client_id is None:
raise RuntimeError("oauth_client_id is not set in the config file")

base = urlparse(settings.instance().dremio.uri)
if base.netloc.startswith("api."):
base = base._replace(netloc=f"login.{base.netloc[4:]}")
url = base.geturl()

self.client_id = settings.instance().dremio.oauth2.client_id
self.authorize_url = f"{url}/oauth/authorize"
self.access_token_url = f"{url}/oauth/token"
def __init__(
self,
client_id: Optional[str] = None,
auth_url: Optional[str] = None,
token_url: Optional[str] = None,
):
if not auth_url:
if settings.instance().dremio.oauth2.client_id is None:
raise RuntimeError("oauth_client_id is not set in the config file")

base = urlparse(settings.instance().dremio.uri)
if base.netloc.startswith("api."):
base = base._replace(netloc=f"login.{base.netloc[4:]}")
url = base.geturl()

self.client_id = settings.instance().dremio.oauth2.client_id
self.authorize_url = f"{url}/oauth/authorize"
self.access_token_url = f"{url}/oauth/token"
else:
self.authorize_url = auth_url
self.access_token_url = token_url
self.client_id = client_id
if self.client_id is None:
raise RuntimeError("client_id is required")
self.redirect_port = 8976
self.scope = "dremio.all offline_access"
self.code_verifier, self.code_challenge = get_pkce_pair()
Expand All @@ -121,7 +136,6 @@ def __init__(self):
"code_challenge": self.code_challenge,
"code_challenge_method": "S256",
}
print(self.init_params)
self.oauth_redirect = OAuth2Redirect(
self.client_id,
self.code_verifier,
Expand All @@ -131,10 +145,14 @@ def __init__(self):
)


def get_oauth2_tokens() -> OAuth2Redirect:
def get_oauth2_tokens(
client_id: Optional[str] = None,
auth_url: Optional[str] = None,
token_url: Optional[str] = None,
) -> OAuth2Redirect:
# client_id = "311658a1-19ae-4851-b6a6-911c794312e2",
# client_id = "a3743893-d849-4c8a-893b-533dd457aac4"
oauth = OAuth2()
oauth = OAuth2(client_id, auth_url, token_url)
server_thread = Thread(
target=run_server,
daemon=True,
Expand All @@ -146,7 +164,7 @@ def get_oauth2_tokens() -> OAuth2Redirect:
print(f"Opening browser to {url}")
webbrowser.open(url)
server_thread.join()
print(
logger("oauth2").debug(
f"Access token: {oauth.oauth_redirect.access_token}\n"
f"Refresh token: {oauth.oauth_redirect.refresh_token}\n"
f"Expiry: {oauth.oauth_redirect.expiry}\n"
Expand Down
150 changes: 137 additions & 13 deletions tests/stremable_http_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,33 @@
"""

import asyncio
import functools
import random
import time
from contextlib import asynccontextmanager
from typing import Annotated, Optional, AsyncGenerator, Callable, Any
from urllib.parse import urlparse

from mcp import ClientSession, types
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthMetadata
from typer import Typer, Option
from rich import print as pp
from typing import Annotated, Optional
import requests

from dremioai import log
from dremioai.api.oauth2 import get_oauth2_tokens, OAuth2Redirect


def async_command(func: Callable) -> Callable:
"""Decorator to run async functions in Typer commands."""

@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(func(*args, **kwargs))

return wrapper


app = Typer(
no_args_is_help=True,
Expand All @@ -20,27 +41,130 @@
context_settings=dict(help_option_names=["-h", "--help"]),
)

auth = Typer(
no_args_is_help=True,
name="auth",
help="Auth related sub commands",
context_settings=dict(help_option_names=["-h", "--help"]),
)


# Example usage and demonstration
async def cli(url, token):
async with streamablehttp_client(
url=url, headers={"Authorization": f"Bearer {token}"}
) as (read_stream, write_stream, gid):
def get_oauth_config(url: str) -> OAuthMetadata:
u = urlparse(url)
u = u._replace(path="/.well-known/oauth-authorization-server")
log.logger("auth").info(f"Checking auth for {u.geturl()}")
r = requests.get(u.geturl())
if r.status_code != 200:
pp(f"Cannot get oauth config: {u.geturl()}")
r.raise_for_status()
return OAuthMetadata.model_validate(r.json())


@auth.command("list")
def list_auth(
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
):
pp(get_oauth_config(url))


@auth.command("check")
def check_auth(
client_id: Annotated[str, Option(help="The client id to check")],
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
) -> OAuth2Redirect:
md = get_oauth_config(url)
oauth = get_oauth2_tokens(
client_id, str(md.authorization_endpoint), str(md.token_endpoint)
)
pp(oauth.access_token)
return oauth


cli = Typer(
no_args_is_help=True,
name="cli",
help="MCP client session related sub commands",
context_settings=dict(help_option_names=["-h", "--help"]),
)


@asynccontextmanager
async def mcp_client_session(
url: str, token: Optional[str] = None
) -> AsyncGenerator[ClientSession, None]:
headers = {"Authorization": f"Bearer {token}"} if token is not None else None
async with streamablehttp_client(url=url, headers=headers) as (
read_stream,
write_stream,
gid,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
for t in await session.list_tools():
pp(t)
yield session


@app.command()
def main(
token: Annotated[Optional[str], Option(help="The authorization token to use")],
@cli.command("list-tools")
@async_command
async def list_tools(
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
token: Annotated[
Optional[str], Option(help="The authorization token to use")
] = None,
):
asyncio.run(cli(url, token))
async with mcp_client_session(url, token) as session:
tools = await session.list_tools()
for tool in tools:
pp(tool)


@app.command("test", help="Run a quick smoketest for a deployed MCP server")
@async_command
async def run_test(
client_id: Annotated[str, Option(help="The OAuth client id")],
url: Annotated[
Optional[str], Option(help="The URL of the MCP server")
] = "http://127.0.0.1:8000/mcp",
):
pp("Checking auth..", end=" ")
a = check_auth(client_id, url)
pp("[green]OK[/green]\nConnecting to server..")
async with mcp_client_session(url, a.access_token) as session:
tools = await session.list_tools()
pp([i.name for i in tools.tools])

pp("[green]OK[/green]\nCalling tool..")
n = int(time.time())
query = f"SELECT {n} as n"
result = await session.call_tool("RunSqlQuery", {"s": query})
result = result.structuredContent["result"]["result"]
pp(result)

query2 = f"""
SELECT query
FROM sys.project.jobs_recent
WHERE query_type = 'REST' and submitted_ts > CURRENT_TIMESTAMP() - INTERVAL '1' minute
and query like '/* dremioai: submitter=RunS%' and query like '%SELECT {n} as n';
"""
result = await session.call_tool("RunSqlQuery", {"s": query2})
result = result.structuredContent["result"]["result"]
pp(result)

if len(result) != 1:
pp("[red]FAIL[/red]")
pp("[green]OK[/green]")


# Add the CLI subcommand to the main app
app.add_typer(cli)
app.add_typer(auth)


if __name__ == "__main__":
log.configure(enable_json_logging=False, to_file=False)
app()