Skip to content

Commit 47fea4e

Browse files
Supporting an override for OAuth issuer uri (#60)
1 parent 8ae378d commit 47fea4e

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/dremioai/config/settings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from os import environ
5151
from importlib.util import find_spec
5252
from datetime import datetime
53+
from dremioai import log
5354

5455
ProjectId = Union[UUID, Literal["DREMIO_DYNAMIC"]]
5556

@@ -144,6 +145,7 @@ class Dremio(BaseModel):
144145
)
145146
oauth2: Optional[OAuth2] = None
146147
allow_dml: Optional[bool] = False
148+
auth_issuer_uri_override: Optional[str] = None
147149
model_config = ConfigDict(validate_assignment=True)
148150

149151
@field_serializer("raw_pat")
@@ -186,11 +188,14 @@ def is_cloud(self) -> bool:
186188

187189
@property
188190
def auth_issuer_uri(self) -> Optional[str]:
191+
if self.auth_issuer_uri_override is not None:
192+
return self.auth_issuer_uri_override
189193
if self.is_cloud:
190194
uri = urlparse(self.uri)
191195
if uri.netloc.startswith("api."):
192196
uri = uri._replace(netloc=f"login.{uri.netloc[4:]}")
193197
return uri.geturl()
198+
log.logger("settings").error("Oauth not supported for non-cloud instances")
194199
return None
195200

196201
@property

tests/config/test_settings.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,14 @@ def test_env_file(mock_config_dir):
134134

135135

136136
@pytest.mark.parametrize(
137-
"uri,project_id,issuer,error",
137+
"uri,project_id,issuer,error,iss_override",
138138
[
139139
pytest.param(
140140
uri,
141141
project_id,
142142
iss,
143143
project_id is None,
144+
iss_override,
144145
id=f"{label} with {plabel}",
145146
)
146147
for uri, iss, label in (
@@ -158,10 +159,23 @@ def test_env_file(mock_config_dir):
158159
("DREMIO_DYNAMIC", "dynamic-project-id"),
159160
(str(uuid.uuid4()), "project-id"),
160161
)
162+
for iss_override in (None, "https://my-override")
161163
],
162164
)
163-
def test_auth_urls(uri: str, project_id: str | None, issuer: str, error: bool):
164-
d = settings.Dremio.model_validate({"uri": uri, "project_id": project_id})
165+
def test_auth_urls(
166+
uri: str, project_id: str | None, issuer: str, error: bool, iss_override: str | None
167+
):
168+
d = settings.Dremio.model_validate(
169+
{
170+
"uri": uri,
171+
"project_id": project_id,
172+
"auth_issuer_uri_override": (
173+
iss_override if iss_override and not error else None
174+
),
175+
}
176+
)
177+
if iss_override:
178+
issuer = iss_override
165179
auth = (f"{issuer}/oauth/authorize", f"{issuer}/oauth/token") if not error else None
166180
issuer = issuer if not error else None
167181
assert d.auth_issuer_uri == issuer

0 commit comments

Comments
 (0)