Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from os import environ
from importlib.util import find_spec
from datetime import datetime
from dremioai import log

ProjectId = Union[UUID, Literal["DREMIO_DYNAMIC"]]

Expand Down Expand Up @@ -144,6 +145,7 @@ class Dremio(BaseModel):
)
oauth2: Optional[OAuth2] = None
allow_dml: Optional[bool] = False
auth_issuer_uri_override: Optional[str] = None
model_config = ConfigDict(validate_assignment=True)

@field_serializer("raw_pat")
Expand Down Expand Up @@ -186,11 +188,14 @@ def is_cloud(self) -> bool:

@property
def auth_issuer_uri(self) -> Optional[str]:
if self.auth_issuer_uri_override is not None:
return self.auth_issuer_uri_override
if self.is_cloud:
uri = urlparse(self.uri)
if uri.netloc.startswith("api."):
uri = uri._replace(netloc=f"login.{uri.netloc[4:]}")
return uri.geturl()
log.logger("settings").error("Oauth not supported for non-cloud instances")
return None

@property
Expand Down
20 changes: 17 additions & 3 deletions tests/config/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,14 @@ def test_env_file(mock_config_dir):


@pytest.mark.parametrize(
"uri,project_id,issuer,error",
"uri,project_id,issuer,error,iss_override",
[
pytest.param(
uri,
project_id,
iss,
project_id is None,
iss_override,
id=f"{label} with {plabel}",
)
for uri, iss, label in (
Expand All @@ -158,10 +159,23 @@ def test_env_file(mock_config_dir):
("DREMIO_DYNAMIC", "dynamic-project-id"),
(str(uuid.uuid4()), "project-id"),
)
for iss_override in (None, "https://my-override")
],
)
def test_auth_urls(uri: str, project_id: str | None, issuer: str, error: bool):
d = settings.Dremio.model_validate({"uri": uri, "project_id": project_id})
def test_auth_urls(
uri: str, project_id: str | None, issuer: str, error: bool, iss_override: str | None
):
d = settings.Dremio.model_validate(
{
"uri": uri,
"project_id": project_id,
"auth_issuer_uri_override": (
iss_override if iss_override and not error else None
),
}
)
if iss_override:
issuer = iss_override
auth = (f"{issuer}/oauth/authorize", f"{issuer}/oauth/token") if not error else None
issuer = issuer if not error else None
assert d.auth_issuer_uri == issuer
Expand Down