From 14e57204e4b9bd14c555a18966e53c13a3bde03b Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Wed, 7 May 2025 23:15:30 -0400 Subject: [PATCH 1/3] Support for Oauth login into DC --- src/dremioai/api/cli/__main__.py | 12 +- src/dremioai/api/cli/oauth.py | 72 ++++++++++ src/dremioai/api/oauth2.py | 154 ++++++++++++++++++++++ src/dremioai/api/transport.py | 27 ++-- src/dremioai/config/settings.py | 61 ++++++++- src/dremioai/resources/auth_redirect.html | 40 ++++++ src/dremioai/servers/mcp.py | 52 ++++---- 7 files changed, 377 insertions(+), 41 deletions(-) create mode 100644 src/dremioai/api/cli/oauth.py create mode 100644 src/dremioai/api/oauth2.py create mode 100644 src/dremioai/resources/auth_redirect.html diff --git a/src/dremioai/api/cli/__main__.py b/src/dremioai/api/cli/__main__.py index c92c73a..9fa4dfc 100644 --- a/src/dremioai/api/cli/__main__.py +++ b/src/dremioai/api/cli/__main__.py @@ -1,18 +1,18 @@ -# +# # Copyright (C) 2017-2025 Dremio Corporation -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# +# from typer import Typer, Option, BadParameter, Argument from rich import print_json as pj, print as pp @@ -26,6 +26,7 @@ from dremioai.api.cli.engines import app as engines_app from dremioai.api.cli.prometheus import app as prometheus_app from dremioai.api.cli.search import app as search_app +from dremioai.api.cli.oauth import app as oauth_app def common_args( @@ -49,6 +50,7 @@ def common_args( app.add_typer(engines_app, callback=common_args) app.add_typer(prometheus_app, callback=common_args) app.add_typer(search_app, callback=common_args) +app.add_typer(oauth_app, callback=common_args) @catalog_app.command(name="lineage") diff --git a/src/dremioai/api/cli/oauth.py b/src/dremioai/api/cli/oauth.py new file mode 100644 index 0000000..f846885 --- /dev/null +++ b/src/dremioai/api/cli/oauth.py @@ -0,0 +1,72 @@ +from dremioai.api.oauth2 import get_oauth2_tokens +from dremioai.config import settings +from typing import Annotated +from typer import Option, Typer +from rich import print as pp + +app = Typer( + no_args_is_help=True, + name="oauth", + help="Run commands related to oauth", + context_settings=dict(help_option_names=["-h", "--help"]), +) + + +@app.command("login") +def login( + client_id: Annotated[str, Option(help="The client id for the OAuth app")] = None, +): + if not settings.instance().dremio.oauth_supported: + raise RuntimeError("OAuth is not supported for this Dremio instance") + + if client_id is not None: + if settings.instance().dremio.oauth_configured: + settings.instance().dremio.oauth2.client_id = client_id + else: + settings.instance().dremio.oauth2 = settings.OAuth2.model_validate( + {"client_id": client_id} + ) + oauth = get_oauth2_tokens() + oauth.update_settings() + pp( + settings.instance().dremio.model_dump( + exclude_none=True, mode="json", by_alias=True, exclude_unset=True + ) + ) + + +@app.command("status") +def status(): + if not settings.instance().dremio.oauth_supported: + pp( + f"OAuth is supported only for this Dremio cloud (uri={settings.instance().dremio.uri})" + ) + return + + if not settings.instance().dremio.oauth_configured: + pp("OAuth is not configured for this Dremio instance") + return + + tok = ( + f"{settings.instance().dremio.pat[:4]}..." + if settings.instance().dremio.pat + else "" + ) + exp = ( + str(settings.instance().dremio.oauth2.expiry) + if settings.instance().dremio.oauth2.expiry + else "" + ) + if settings.instance().dremio.oauth2.has_expired: + exp += f":(EXPIRED)" + pp( + { + "token": tok, + "expiry": exp, + "user": ( + settings.instance().dremio.oauth2.dremio_user_identifier + if settings.instance().dremio.oauth2.dremio_user_identifier + else "" + ), + } + ) diff --git a/src/dremioai/api/oauth2.py b/src/dremioai/api/oauth2.py new file mode 100644 index 0000000..faca23d --- /dev/null +++ b/src/dremioai/api/oauth2.py @@ -0,0 +1,154 @@ +from aiohttp import ClientSession, web +from threading import Thread +from secrets import token_urlsafe +from hashlib import sha256 +from base64 import urlsafe_b64encode +from urllib.parse import urlencode, urlparse +from dremioai.config import settings +from datetime import datetime, timedelta +from importlib.resources import files +import webbrowser +import asyncio + + +class OAuth2Redirect: + def __init__( + self, client_id, code_verifier, code_challenge, token_url, redirect_port + ): + self.client_id = client_id + self.code_verifier = code_verifier + self.code_challenge = code_challenge + self.token_url = token_url + self.redirect_port = redirect_port + self.stop = asyncio.Event() + self.token = {} + + async def auth_redirect(self, request: web.Request): + print(f"auth_redirect: {request}") + redirect_uri = f"http://localhost:{self.redirect_port}" + params = { + "client_id": self.client_id, + "code_verifier": self.code_verifier, + "code": request.query["code"], + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + } + + async with ClientSession() as session: + async with session.post(self.token_url, data=params) as response: + if response.status != 200: + print(f"Failed to get token: {await response.text()}") + else: + self.token = await response.json() + self.stop.set() + auth_html = files("dremioai.resources") / "auth_redirect.html" + return web.Response(text=auth_html.read_text(), content_type="text/html") + + @property + def access_token(self) -> str: + return self.token.get("access_token") + + @property + def refresh_token(self) -> str: + return self.token.get("refresh_token") + + @property + def user(self) -> str: + return self.token.get("dremio_user_identifier") + + @property + def expiry(self) -> int: + return self.token.get("expires_in") + + async def start_server(self): + app = web.Application() + app.router.add_get("/", self.auth_redirect) + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", self.redirect_port) + await site.start() + await self.stop.wait() + + def update_settings(self): + expiry = datetime.now() + timedelta(seconds=self.expiry - 10) + settings.instance().dremio.pat = self.access_token + settings.instance().dremio.oauth2 = settings.OAuth2.model_validate( + { + "client_id": self.client_id, + "refresh_token": self.refresh_token, + "dremio_user_identifier": self.user, + "expiry": expiry, + } + ) + settings.write_settings() + + +def run_server(oauth: OAuth2Redirect): + print("Starting server") + asyncio.run(oauth.start_server()) + + +def get_pkce_pair(length=96): + length = max(min(length, 128), 43) + code_verifier = token_urlsafe(length) + code_challenge = ( + urlsafe_b64encode(sha256(code_verifier.encode()).digest()).rstrip(b"=").decode() + ) + return code_verifier, code_challenge + + +class OAuth2: + def __init__(self): + if settings.instance().dremio.oauth2.client_id is None: + raise RuntimeError("oauth_client_id is not set in the config file") + + base = urlparse(settings.instance().dremio.uri) + if base.netloc.startswith("api."): + base = base._replace(netloc=f"login.{base.netloc[4:]}") + url = base.geturl() + + self.client_id = settings.instance().dremio.oauth2.client_id + self.authorize_url = f"{url}/oauth/authorize" + self.access_token_url = f"{url}/oauth/token" + self.redirect_port = 8976 + self.scope = "dremio.all offline_access" + self.code_verifier, self.code_challenge = get_pkce_pair() + self.init_params = { + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": f"http://localhost:{self.redirect_port}", + "scope": self.scope, + "code_challenge": self.code_challenge, + "code_challenge_method": "S256", + } + print(self.init_params) + self.oauth_redirect = OAuth2Redirect( + self.client_id, + self.code_verifier, + self.code_challenge, + self.access_token_url, + self.redirect_port, + ) + + +def get_oauth2_tokens() -> OAuth2Redirect: + # client_id = "311658a1-19ae-4851-b6a6-911c794312e2", + # client_id = "a3743893-d849-4c8a-893b-533dd457aac4" + oauth = OAuth2() + server_thread = Thread( + target=run_server, + daemon=True, + args=(oauth.oauth_redirect,), + ) + server_thread.start() + + url = f"{oauth.authorize_url}?{urlencode(oauth.init_params)}" + print(f"Opening browser to {url}") + webbrowser.open(url) + server_thread.join() + print( + f"Access token: {oauth.oauth_redirect.access_token}\n" + f"Refresh token: {oauth.oauth_redirect.refresh_token}\n" + f"Expiry: {oauth.oauth_redirect.expiry}\n" + ) + return oauth.oauth_redirect diff --git a/src/dremioai/api/transport.py b/src/dremioai/api/transport.py index bee00b5..b3597d3 100644 --- a/src/dremioai/api/transport.py +++ b/src/dremioai/api/transport.py @@ -1,18 +1,18 @@ -# +# # Copyright (C) 2017-2025 Dremio Corporation -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# +# from aiohttp import ClientSession, ClientResponse, ClientResponseError from pathlib import Path @@ -22,6 +22,7 @@ from pydantic import BaseModel, ValidationError from dremioai.config import settings +from dremioai.api.oauth2 import get_oauth2_tokens DeserializationStrategy: TypeAlias = Union[Callable, BaseModel] @@ -125,12 +126,20 @@ async def post( class DremioAsyncHttpClient(AsyncHttpClient): def __init__(self, uri: Optional[str] = None, pat: Optional[str] = None): + dremio = settings.instance().dremio + if ( + dremio.oauth_supported + and dremio.oauth_configured + and (dremio.oauth2.has_expired or dremio.pat is None) + ): + oauth = get_oauth2_tokens() + oauth.update_settings() + if uri is None: - uri = settings.instance().dremio.uri + uri = dremio.uri if pat is None: - pat = settings.instance().dremio.pat - if pat.startswith("@"): - pat = Path(pat[1:]).expanduser().read_text().strip() + pat = dremio.pat + if uri is None or pat is None: raise RuntimeError(f"uri={uri} pat={pat} are required") super().__init__(uri, pat) diff --git a/src/dremioai/config/settings.py b/src/dremioai/config/settings.py index 89bc9c1..8d7a385 100644 --- a/src/dremioai/config/settings.py +++ b/src/dremioai/config/settings.py @@ -14,7 +14,14 @@ # limitations under the License. # -from pydantic import Field, HttpUrl, AfterValidator, BaseModel, ConfigDict +from pydantic import ( + Field, + HttpUrl, + AfterValidator, + BaseModel, + ConfigDict, + field_serializer, +) from pydantic_settings import BaseSettings, SettingsConfigDict from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable from dremioai.config.tools import ToolType @@ -27,6 +34,7 @@ from contextvars import ContextVar, copy_context from os import environ from importlib.util import find_spec +from datetime import datetime def _resolve_tools_settings(server_mode: Union[ToolType, int, str]) -> ToolType: @@ -50,6 +58,10 @@ class Tools(BaseModel): ] = Field(default=ToolType.FOR_SELF) model_config = ConfigDict(validate_assignment=True, use_enum_values=True) + @field_serializer("server_mode") + def serialize_server_mode(self, server_mode: ToolType): + return ",".join(m.name for m in ToolType if m & server_mode) + class DremioCloudUri(StrEnum): PROD = auto() @@ -90,15 +102,60 @@ class Model(StrEnum): openai = auto() +class OAuth2(BaseModel): + client_id: str + refresh_token: Optional[str] = None + dremio_user_identifier: Optional[str] = None + expiry: Optional[datetime] = None + model_config = ConfigDict(validate_assignment=True) + + @property + def has_expired(self) -> bool: + return self.expiry is not None and self.expiry < datetime.now() + + class Dremio(BaseModel): uri: Annotated[ Union[str, HttpUrl, DremioCloudUri], AfterValidator(_resolve_dremio_uri) ] - pat: Annotated[str, AfterValidator(_resolve_token_file)] + raw_pat: Optional[str] = Field(default=None, alias="pat") project_id: Optional[str] = None enable_experimental: Optional[bool] = False # enable experimental tools + oauth2: Optional[OAuth2] = None model_config = ConfigDict(validate_assignment=True) + @field_serializer("raw_pat") + def serialize_pat(self, pat: str): + return self.raw_pat if pat != self.raw_pat else pat + + @property + def oauth_configured(self) -> bool: + return self.oauth2 is not None + + @property + def oauth_supported(self) -> bool: + return self.project_id is not None + + # @field_validator("_pat", mode="wrap") + # @classmethod + # def validate_pat(cls, v: str, handler: ValidatorFunctionWrapHandler) -> str: + # v = _resolve_token_file(v) + # return handler(v) + + @property + def pat(self) -> str: + if v := getattr(self, "_pat_resolved", None): + return v + if self.raw_pat is not None and self.raw_pat.startswith("@"): + self._pat_resolved = _resolve_token_file(self.raw_pat) + return self._pat_resolved + return self.raw_pat + + @pat.setter + def pat(self, v: str): + self.raw_pat = v + self._pat_resolved = None + class OpenAi(BaseModel): api_key: Annotated[str, AfterValidator(_resolve_token_file)] = None diff --git a/src/dremioai/resources/auth_redirect.html b/src/dremioai/resources/auth_redirect.html new file mode 100644 index 0000000..9b0b23f --- /dev/null +++ b/src/dremioai/resources/auth_redirect.html @@ -0,0 +1,40 @@ + + + + + Authentication Complete + + + + + +

Authentication Successful

+

You may now close this window.

+ Click here to close + + \ No newline at end of file diff --git a/src/dremioai/servers/mcp.py b/src/dremioai/servers/mcp.py index 06bda70..eb7b612 100644 --- a/src/dremioai/servers/mcp.py +++ b/src/dremioai/servers/mcp.py @@ -30,6 +30,7 @@ from rich import console, table, print as pp from click import Choice from dremioai.config import settings +from dremioai.api.oauth2 import get_oauth2_tokens from enum import StrEnum, auto from json import load, dump as jdump from shutil import which @@ -135,6 +136,15 @@ def main( print(tool.__name__) return + dremio = settings.instance().dremio + if ( + dremio.oauth_supported + and dremio.oauth_configured + and (dremio.oauth2.has_expired or dremio.pat is None) + ): + oauth = get_oauth2_tokens() + oauth.update_settings() + app = init( uri=cfg.dremio.uri, pat=cfg.dremio.pat, @@ -188,7 +198,10 @@ def show_default_config( pp( dump( settings.instance().model_dump( - exclude_none=True, mode="json", exclude_unset=True + exclude_none=True, + mode="json", + exclude_unset=True, + by_alias=True, ) ) ) @@ -270,6 +283,10 @@ def create_default_config( enable_experimental: Annotated[ bool, Option(help="Enable experimental features") ] = False, + oauth_client_id: Annotated[ + Optional[str], + Option(help="The ID of OAuth application, for OAuth2 logon support"), + ] = None, dry_run: Annotated[ bool, Option(help="Dry run, do not overwrite the config file. Just print it") ] = False, @@ -281,36 +298,21 @@ def create_default_config( "pat": pat, "project_id": project_id, "enable_experimental": enable_experimental, + "oauth": ( + settings.OAuth2.model_validate({"client_id": oauth_client_id}) + if oauth_client_id + else None + ), } ) ts = settings.Tools.model_validate({"server_mode": mode}) settings.configure(settings.default_config(), force=True) settings.instance().dremio = dremio settings.instance().tools = ts - d = settings.instance().model_dump( - exclude_none=True, mode="json", exclude_unset=True - ) - - add_representer( - str, - lambda dumper, data: dumper.represent_scalar( - "tag:yaml.org,2002:str", data, style=('"' if "@" in data else None) - ), - ) - if pat.startswith("@") and d["dremio"]["pat"] != pat: - d["dremio"]["pat"] = pat - d["tools"]["server_mode"] = mode - - if dry_run: - pp(dump(d)) - return - - dc = settings.default_config() - if not dc.exists(): - dc.parent.mkdir(parents=True, exist_ok=True) - with dc.open("w") as f: - dump(d, f) - pp(f"Created default config file: {dc!s}") + if (d := settings.write_settings(dry_run=dry_run)) is not None and dry_run: + pp(d) + elif not dry_run: + pp(f"Created default config file: {settings.default_config()!s}") # -------------------------------------------------------------------------------- From 79bba7e6ba1760b7d379142c40da957457f118cc Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Mon, 12 May 2025 13:23:54 -0400 Subject: [PATCH 2/3] Correcting tool annotations --- src/dremioai/tools/tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dremioai/tools/tools.py b/src/dremioai/tools/tools.py index 260f9d0..adb6f00 100644 --- a/src/dremioai/tools/tools.py +++ b/src/dremioai/tools/tools.py @@ -439,7 +439,7 @@ def system_prompt(): class GetRelevantMetrics(Tools): - For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS | ToolType.FOR_SELF]] + For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS]] async def invoke(self) -> Dict[str, Any]: """ @@ -462,7 +462,7 @@ async def invoke(self) -> Dict[str, Any]: class GetMetricSchema(Tools): - For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS | ToolType.FOR_SELF]] + For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS]] async def invoke(self, metric: str) -> Dict[str, Any]: """ @@ -480,7 +480,7 @@ async def invoke(self, metric: str) -> Dict[str, Any]: class RunPromQL(Tools): - For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS | ToolType.FOR_SELF]] + For: ClassVar[Annotated[ToolType, ToolType.FOR_PROMETHEUS]] async def invoke(self, promql_query: str) -> Dict[str, Any]: """ From e9b76936c972e7499d38f6c4466cc39953a138ee Mon Sep 17 00:00:00 2001 From: Aniket Kulkarni Date: Mon, 12 May 2025 17:44:16 -0400 Subject: [PATCH 3/3] Removing extra print --- src/dremioai/config/settings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dremioai/config/settings.py b/src/dremioai/config/settings.py index 8d7a385..7a4eaaf 100644 --- a/src/dremioai/config/settings.py +++ b/src/dremioai/config/settings.py @@ -281,7 +281,6 @@ def configure(cfg: Union[str, Path] = None, force=False) -> ContextVar[Settings] cfg = default_config() if not cfg.exists(): - print(f"Creating default config file: {cfg!s}") cfg.parent.mkdir(parents=True, exist_ok=True) cfg.touch()