Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/dremioai/api/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from typer import Typer, Option, BadParameter, Argument
from rich import print_json as pj, print as pp
Expand All @@ -26,6 +26,7 @@
from dremioai.api.cli.engines import app as engines_app
from dremioai.api.cli.prometheus import app as prometheus_app
from dremioai.api.cli.search import app as search_app
from dremioai.api.cli.oauth import app as oauth_app


def common_args(
Expand All @@ -49,6 +50,7 @@ def common_args(
app.add_typer(engines_app, callback=common_args)
app.add_typer(prometheus_app, callback=common_args)
app.add_typer(search_app, callback=common_args)
app.add_typer(oauth_app, callback=common_args)


@catalog_app.command(name="lineage")
Expand Down
72 changes: 72 additions & 0 deletions src/dremioai/api/cli/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dremioai.api.oauth2 import get_oauth2_tokens
from dremioai.config import settings
from typing import Annotated
from typer import Option, Typer
from rich import print as pp

app = Typer(
no_args_is_help=True,
name="oauth",
help="Run commands related to oauth",
context_settings=dict(help_option_names=["-h", "--help"]),
)


@app.command("login")
def login(
client_id: Annotated[str, Option(help="The client id for the OAuth app")] = None,
):
if not settings.instance().dremio.oauth_supported:
raise RuntimeError("OAuth is not supported for this Dremio instance")

if client_id is not None:
if settings.instance().dremio.oauth_configured:
settings.instance().dremio.oauth2.client_id = client_id
else:
settings.instance().dremio.oauth2 = settings.OAuth2.model_validate(
{"client_id": client_id}
)
oauth = get_oauth2_tokens()
oauth.update_settings()
pp(
settings.instance().dremio.model_dump(
exclude_none=True, mode="json", by_alias=True, exclude_unset=True
)
)


@app.command("status")
def status():
if not settings.instance().dremio.oauth_supported:
pp(
f"OAuth is supported only for this Dremio cloud (uri={settings.instance().dremio.uri})"
)
return

if not settings.instance().dremio.oauth_configured:
pp("OAuth is not configured for this Dremio instance")
return

tok = (
f"{settings.instance().dremio.pat[:4]}..."
if settings.instance().dremio.pat
else "<not set>"
)
exp = (
str(settings.instance().dremio.oauth2.expiry)
if settings.instance().dremio.oauth2.expiry
else ""
)
if settings.instance().dremio.oauth2.has_expired:
exp += f":(EXPIRED)"
pp(
{
"token": tok,
"expiry": exp,
"user": (
settings.instance().dremio.oauth2.dremio_user_identifier
if settings.instance().dremio.oauth2.dremio_user_identifier
else ""
),
}
)
154 changes: 154 additions & 0 deletions src/dremioai/api/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from aiohttp import ClientSession, web
from threading import Thread
from secrets import token_urlsafe
from hashlib import sha256
from base64 import urlsafe_b64encode
from urllib.parse import urlencode, urlparse
from dremioai.config import settings
from datetime import datetime, timedelta
from importlib.resources import files
import webbrowser
import asyncio


class OAuth2Redirect:
def __init__(
self, client_id, code_verifier, code_challenge, token_url, redirect_port
):
self.client_id = client_id
self.code_verifier = code_verifier
self.code_challenge = code_challenge
self.token_url = token_url
self.redirect_port = redirect_port
self.stop = asyncio.Event()
self.token = {}

async def auth_redirect(self, request: web.Request):
print(f"auth_redirect: {request}")
redirect_uri = f"http://localhost:{self.redirect_port}"
params = {
"client_id": self.client_id,
"code_verifier": self.code_verifier,
"code": request.query["code"],
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
}

async with ClientSession() as session:
async with session.post(self.token_url, data=params) as response:
if response.status != 200:
print(f"Failed to get token: {await response.text()}")
else:
self.token = await response.json()
self.stop.set()
auth_html = files("dremioai.resources") / "auth_redirect.html"
return web.Response(text=auth_html.read_text(), content_type="text/html")

@property
def access_token(self) -> str:
return self.token.get("access_token")

@property
def refresh_token(self) -> str:
return self.token.get("refresh_token")

@property
def user(self) -> str:
return self.token.get("dremio_user_identifier")

@property
def expiry(self) -> int:
return self.token.get("expires_in")

async def start_server(self):
app = web.Application()
app.router.add_get("/", self.auth_redirect)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "localhost", self.redirect_port)
await site.start()
await self.stop.wait()

def update_settings(self):
expiry = datetime.now() + timedelta(seconds=self.expiry - 10)
settings.instance().dremio.pat = self.access_token
settings.instance().dremio.oauth2 = settings.OAuth2.model_validate(
{
"client_id": self.client_id,
"refresh_token": self.refresh_token,
"dremio_user_identifier": self.user,
"expiry": expiry,
}
)
settings.write_settings()


def run_server(oauth: OAuth2Redirect):
print("Starting server")
asyncio.run(oauth.start_server())


def get_pkce_pair(length=96):
length = max(min(length, 128), 43)
code_verifier = token_urlsafe(length)
code_challenge = (
urlsafe_b64encode(sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
)
return code_verifier, code_challenge


class OAuth2:
def __init__(self):
if settings.instance().dremio.oauth2.client_id is None:
raise RuntimeError("oauth_client_id is not set in the config file")

base = urlparse(settings.instance().dremio.uri)
if base.netloc.startswith("api."):
base = base._replace(netloc=f"login.{base.netloc[4:]}")
url = base.geturl()

self.client_id = settings.instance().dremio.oauth2.client_id
self.authorize_url = f"{url}/oauth/authorize"
self.access_token_url = f"{url}/oauth/token"
self.redirect_port = 8976
self.scope = "dremio.all offline_access"
self.code_verifier, self.code_challenge = get_pkce_pair()
self.init_params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": f"http://localhost:{self.redirect_port}",
"scope": self.scope,
"code_challenge": self.code_challenge,
"code_challenge_method": "S256",
}
print(self.init_params)
self.oauth_redirect = OAuth2Redirect(
self.client_id,
self.code_verifier,
self.code_challenge,
self.access_token_url,
self.redirect_port,
)


def get_oauth2_tokens() -> OAuth2Redirect:
# client_id = "311658a1-19ae-4851-b6a6-911c794312e2",
# client_id = "a3743893-d849-4c8a-893b-533dd457aac4"
oauth = OAuth2()
server_thread = Thread(
target=run_server,
daemon=True,
args=(oauth.oauth_redirect,),
)
server_thread.start()

url = f"{oauth.authorize_url}?{urlencode(oauth.init_params)}"
print(f"Opening browser to {url}")
webbrowser.open(url)
server_thread.join()
print(
f"Access token: {oauth.oauth_redirect.access_token}\n"
f"Refresh token: {oauth.oauth_redirect.refresh_token}\n"
f"Expiry: {oauth.oauth_redirect.expiry}\n"
)
return oauth.oauth_redirect
27 changes: 18 additions & 9 deletions src/dremioai/api/transport.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from aiohttp import ClientSession, ClientResponse, ClientResponseError
from pathlib import Path
Expand All @@ -22,6 +22,7 @@
from pydantic import BaseModel, ValidationError

from dremioai.config import settings
from dremioai.api.oauth2 import get_oauth2_tokens

DeserializationStrategy: TypeAlias = Union[Callable, BaseModel]

Expand Down Expand Up @@ -125,12 +126,20 @@ async def post(

class DremioAsyncHttpClient(AsyncHttpClient):
def __init__(self, uri: Optional[str] = None, pat: Optional[str] = None):
dremio = settings.instance().dremio
if (
dremio.oauth_supported
and dremio.oauth_configured
and (dremio.oauth2.has_expired or dremio.pat is None)
):
oauth = get_oauth2_tokens()
oauth.update_settings()

if uri is None:
uri = settings.instance().dremio.uri
uri = dremio.uri
if pat is None:
pat = settings.instance().dremio.pat
if pat.startswith("@"):
pat = Path(pat[1:]).expanduser().read_text().strip()
pat = dremio.pat

if uri is None or pat is None:
raise RuntimeError(f"uri={uri} pat={pat} are required")
super().__init__(uri, pat)
Loading