Skip to content

Commit 5df65ed

Browse files
Merge pull request #16 from dremio/dc-oauth-support
Dc oauth support
2 parents c27a792 + e9b7693 commit 5df65ed

8 files changed

Lines changed: 380 additions & 45 deletions

File tree

src/dremioai/api/cli/__main__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
#
1+
#
22
# Copyright (C) 2017-2025 Dremio Corporation
3-
#
3+
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at
7-
#
7+
#
88
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
9+
#
1010
# Unless required by applicable law or agreed to in writing, software
1111
# distributed under the License is distributed on an "AS IS" BASIS,
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
#
15+
#
1616

1717
from typer import Typer, Option, BadParameter, Argument
1818
from rich import print_json as pj, print as pp
@@ -26,6 +26,7 @@
2626
from dremioai.api.cli.engines import app as engines_app
2727
from dremioai.api.cli.prometheus import app as prometheus_app
2828
from dremioai.api.cli.search import app as search_app
29+
from dremioai.api.cli.oauth import app as oauth_app
2930

3031

3132
def common_args(
@@ -49,6 +50,7 @@ def common_args(
4950
app.add_typer(engines_app, callback=common_args)
5051
app.add_typer(prometheus_app, callback=common_args)
5152
app.add_typer(search_app, callback=common_args)
53+
app.add_typer(oauth_app, callback=common_args)
5254

5355

5456
@catalog_app.command(name="lineage")

src/dremioai/api/cli/oauth.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from dremioai.api.oauth2 import get_oauth2_tokens
2+
from dremioai.config import settings
3+
from typing import Annotated
4+
from typer import Option, Typer
5+
from rich import print as pp
6+
7+
app = Typer(
8+
no_args_is_help=True,
9+
name="oauth",
10+
help="Run commands related to oauth",
11+
context_settings=dict(help_option_names=["-h", "--help"]),
12+
)
13+
14+
15+
@app.command("login")
16+
def login(
17+
client_id: Annotated[str, Option(help="The client id for the OAuth app")] = None,
18+
):
19+
if not settings.instance().dremio.oauth_supported:
20+
raise RuntimeError("OAuth is not supported for this Dremio instance")
21+
22+
if client_id is not None:
23+
if settings.instance().dremio.oauth_configured:
24+
settings.instance().dremio.oauth2.client_id = client_id
25+
else:
26+
settings.instance().dremio.oauth2 = settings.OAuth2.model_validate(
27+
{"client_id": client_id}
28+
)
29+
oauth = get_oauth2_tokens()
30+
oauth.update_settings()
31+
pp(
32+
settings.instance().dremio.model_dump(
33+
exclude_none=True, mode="json", by_alias=True, exclude_unset=True
34+
)
35+
)
36+
37+
38+
@app.command("status")
39+
def status():
40+
if not settings.instance().dremio.oauth_supported:
41+
pp(
42+
f"OAuth is supported only for this Dremio cloud (uri={settings.instance().dremio.uri})"
43+
)
44+
return
45+
46+
if not settings.instance().dremio.oauth_configured:
47+
pp("OAuth is not configured for this Dremio instance")
48+
return
49+
50+
tok = (
51+
f"{settings.instance().dremio.pat[:4]}..."
52+
if settings.instance().dremio.pat
53+
else "<not set>"
54+
)
55+
exp = (
56+
str(settings.instance().dremio.oauth2.expiry)
57+
if settings.instance().dremio.oauth2.expiry
58+
else ""
59+
)
60+
if settings.instance().dremio.oauth2.has_expired:
61+
exp += f":(EXPIRED)"
62+
pp(
63+
{
64+
"token": tok,
65+
"expiry": exp,
66+
"user": (
67+
settings.instance().dremio.oauth2.dremio_user_identifier
68+
if settings.instance().dremio.oauth2.dremio_user_identifier
69+
else ""
70+
),
71+
}
72+
)

src/dremioai/api/oauth2.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from aiohttp import ClientSession, web
2+
from threading import Thread
3+
from secrets import token_urlsafe
4+
from hashlib import sha256
5+
from base64 import urlsafe_b64encode
6+
from urllib.parse import urlencode, urlparse
7+
from dremioai.config import settings
8+
from datetime import datetime, timedelta
9+
from importlib.resources import files
10+
import webbrowser
11+
import asyncio
12+
13+
14+
class OAuth2Redirect:
15+
def __init__(
16+
self, client_id, code_verifier, code_challenge, token_url, redirect_port
17+
):
18+
self.client_id = client_id
19+
self.code_verifier = code_verifier
20+
self.code_challenge = code_challenge
21+
self.token_url = token_url
22+
self.redirect_port = redirect_port
23+
self.stop = asyncio.Event()
24+
self.token = {}
25+
26+
async def auth_redirect(self, request: web.Request):
27+
print(f"auth_redirect: {request}")
28+
redirect_uri = f"http://localhost:{self.redirect_port}"
29+
params = {
30+
"client_id": self.client_id,
31+
"code_verifier": self.code_verifier,
32+
"code": request.query["code"],
33+
"grant_type": "authorization_code",
34+
"redirect_uri": redirect_uri,
35+
}
36+
37+
async with ClientSession() as session:
38+
async with session.post(self.token_url, data=params) as response:
39+
if response.status != 200:
40+
print(f"Failed to get token: {await response.text()}")
41+
else:
42+
self.token = await response.json()
43+
self.stop.set()
44+
auth_html = files("dremioai.resources") / "auth_redirect.html"
45+
return web.Response(text=auth_html.read_text(), content_type="text/html")
46+
47+
@property
48+
def access_token(self) -> str:
49+
return self.token.get("access_token")
50+
51+
@property
52+
def refresh_token(self) -> str:
53+
return self.token.get("refresh_token")
54+
55+
@property
56+
def user(self) -> str:
57+
return self.token.get("dremio_user_identifier")
58+
59+
@property
60+
def expiry(self) -> int:
61+
return self.token.get("expires_in")
62+
63+
async def start_server(self):
64+
app = web.Application()
65+
app.router.add_get("/", self.auth_redirect)
66+
runner = web.AppRunner(app)
67+
await runner.setup()
68+
site = web.TCPSite(runner, "localhost", self.redirect_port)
69+
await site.start()
70+
await self.stop.wait()
71+
72+
def update_settings(self):
73+
expiry = datetime.now() + timedelta(seconds=self.expiry - 10)
74+
settings.instance().dremio.pat = self.access_token
75+
settings.instance().dremio.oauth2 = settings.OAuth2.model_validate(
76+
{
77+
"client_id": self.client_id,
78+
"refresh_token": self.refresh_token,
79+
"dremio_user_identifier": self.user,
80+
"expiry": expiry,
81+
}
82+
)
83+
settings.write_settings()
84+
85+
86+
def run_server(oauth: OAuth2Redirect):
87+
print("Starting server")
88+
asyncio.run(oauth.start_server())
89+
90+
91+
def get_pkce_pair(length=96):
92+
length = max(min(length, 128), 43)
93+
code_verifier = token_urlsafe(length)
94+
code_challenge = (
95+
urlsafe_b64encode(sha256(code_verifier.encode()).digest()).rstrip(b"=").decode()
96+
)
97+
return code_verifier, code_challenge
98+
99+
100+
class OAuth2:
101+
def __init__(self):
102+
if settings.instance().dremio.oauth2.client_id is None:
103+
raise RuntimeError("oauth_client_id is not set in the config file")
104+
105+
base = urlparse(settings.instance().dremio.uri)
106+
if base.netloc.startswith("api."):
107+
base = base._replace(netloc=f"login.{base.netloc[4:]}")
108+
url = base.geturl()
109+
110+
self.client_id = settings.instance().dremio.oauth2.client_id
111+
self.authorize_url = f"{url}/oauth/authorize"
112+
self.access_token_url = f"{url}/oauth/token"
113+
self.redirect_port = 8976
114+
self.scope = "dremio.all offline_access"
115+
self.code_verifier, self.code_challenge = get_pkce_pair()
116+
self.init_params = {
117+
"client_id": self.client_id,
118+
"response_type": "code",
119+
"redirect_uri": f"http://localhost:{self.redirect_port}",
120+
"scope": self.scope,
121+
"code_challenge": self.code_challenge,
122+
"code_challenge_method": "S256",
123+
}
124+
print(self.init_params)
125+
self.oauth_redirect = OAuth2Redirect(
126+
self.client_id,
127+
self.code_verifier,
128+
self.code_challenge,
129+
self.access_token_url,
130+
self.redirect_port,
131+
)
132+
133+
134+
def get_oauth2_tokens() -> OAuth2Redirect:
135+
# client_id = "311658a1-19ae-4851-b6a6-911c794312e2",
136+
# client_id = "a3743893-d849-4c8a-893b-533dd457aac4"
137+
oauth = OAuth2()
138+
server_thread = Thread(
139+
target=run_server,
140+
daemon=True,
141+
args=(oauth.oauth_redirect,),
142+
)
143+
server_thread.start()
144+
145+
url = f"{oauth.authorize_url}?{urlencode(oauth.init_params)}"
146+
print(f"Opening browser to {url}")
147+
webbrowser.open(url)
148+
server_thread.join()
149+
print(
150+
f"Access token: {oauth.oauth_redirect.access_token}\n"
151+
f"Refresh token: {oauth.oauth_redirect.refresh_token}\n"
152+
f"Expiry: {oauth.oauth_redirect.expiry}\n"
153+
)
154+
return oauth.oauth_redirect

src/dremioai/api/transport.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
#
1+
#
22
# Copyright (C) 2017-2025 Dremio Corporation
3-
#
3+
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
66
# You may obtain a copy of the License at
7-
#
7+
#
88
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
9+
#
1010
# Unless required by applicable law or agreed to in writing, software
1111
# distributed under the License is distributed on an "AS IS" BASIS,
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
#
15+
#
1616

1717
from aiohttp import ClientSession, ClientResponse, ClientResponseError
1818
from pathlib import Path
@@ -22,6 +22,7 @@
2222
from pydantic import BaseModel, ValidationError
2323

2424
from dremioai.config import settings
25+
from dremioai.api.oauth2 import get_oauth2_tokens
2526

2627
DeserializationStrategy: TypeAlias = Union[Callable, BaseModel]
2728

@@ -125,12 +126,20 @@ async def post(
125126

126127
class DremioAsyncHttpClient(AsyncHttpClient):
127128
def __init__(self, uri: Optional[str] = None, pat: Optional[str] = None):
129+
dremio = settings.instance().dremio
130+
if (
131+
dremio.oauth_supported
132+
and dremio.oauth_configured
133+
and (dremio.oauth2.has_expired or dremio.pat is None)
134+
):
135+
oauth = get_oauth2_tokens()
136+
oauth.update_settings()
137+
128138
if uri is None:
129-
uri = settings.instance().dremio.uri
139+
uri = dremio.uri
130140
if pat is None:
131-
pat = settings.instance().dremio.pat
132-
if pat.startswith("@"):
133-
pat = Path(pat[1:]).expanduser().read_text().strip()
141+
pat = dremio.pat
142+
134143
if uri is None or pat is None:
135144
raise RuntimeError(f"uri={uri} pat={pat} are required")
136145
super().__init__(uri, pat)

0 commit comments

Comments
 (0)