Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def get_notebook_pat_token() -> Optional[str]:
host=cfg.host,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def inner() -> Dict[str, str]:
Expand Down Expand Up @@ -251,6 +252,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def inner() -> Dict[str, str]:
Expand All @@ -277,7 +279,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
elif cfg.azure_client_id:
client_id = cfg.azure_client_id
client_secret = cfg.azure_client_secret
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host)
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds)
if not client_id:
client_id = "databricks-cli"
oidc_endpoints = cfg.databricks_oidc_endpoints
Expand All @@ -301,6 +303,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
redirect_url=redirect_url,
scopes=scopes,
profile=cfg.profile,
timeout=cfg.http_timeout_seconds,
)
credentials = token_cache.load()
if credentials:
Expand All @@ -320,6 +323,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
redirect_url=redirect_url,
client_secret=client_secret,
scopes=scopes,
timeout=cfg.http_timeout_seconds,
)
consent = oauth_client.initiate_consent()
if not consent:
Expand Down Expand Up @@ -367,6 +371,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

_ensure_host_present(cfg, token_source_for)
Expand Down Expand Up @@ -489,6 +494,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -551,7 +557,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
aad_endpoint = cfg.arm_environment.active_directory_endpoint
if not cfg.azure_tenant_id:
# detect Azure AD Tenant ID if it's not specified directly
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds).token_endpoint
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]

inner = oauth.ClientCredentials(
Expand All @@ -567,6 +573,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down
29 changes: 25 additions & 4 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def retrieve_token(
use_params=False,
use_header=False,
headers=None,
timeout=60,
) -> Token:
logger.debug(f"Retrieving token for {client_id}")
if use_params:
Expand All @@ -206,7 +207,7 @@ def retrieve_token(
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
else:
auth = IgnoreNetrcAuth()
resp = requests.post(token_url, params, auth=auth, headers=headers)
resp = requests.post(token_url, params, auth=auth, headers=headers, timeout=timeout)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down Expand Up @@ -536,6 +537,7 @@ def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _Bas

def get_azure_entra_id_workspace_endpoints(
host: str,
timeout=60,
) -> Optional[OidcEndpoints]:
"""
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
Expand All @@ -545,7 +547,7 @@ def get_azure_entra_id_workspace_endpoints(
"""
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
host = _fix_host_if_needed(host)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False, timeout=timeout)
real_auth_url = res.headers.get("location")
if not real_auth_url:
return None
Expand All @@ -564,11 +566,13 @@ def __init__(
client_secret: str = None,
redirect_url: str = None,
disable_async: bool = True,
timeout: float = 60,
):
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._redirect_url = redirect_url
self._timeout = timeout
super().__init__(
token=token,
disable_async=disable_async,
Expand All @@ -584,13 +588,15 @@ def from_dict(
client_id: str,
client_secret: str = None,
redirect_url: str = None,
timeout: float = 60,
) -> "SessionCredentials":
return SessionCredentials(
token=Token.from_dict(raw["token"]),
token_endpoint=token_endpoint,
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url,
timeout=timeout,
)

def auth_type(self):
Expand Down Expand Up @@ -626,6 +632,7 @@ def refresh(self) -> Token:
params=params,
use_params=True,
headers=headers,
timeout=self._timeout,
)


Expand All @@ -639,6 +646,7 @@ def __init__(
token_endpoint: str,
client_id: str,
client_secret: str = None,
timeout: float = 60,
) -> None:
self._verifier = verifier
self._state = state
Expand All @@ -647,6 +655,7 @@ def __init__(
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._timeout = timeout

def as_dict(self) -> dict:
return {
Expand All @@ -663,7 +672,7 @@ def authorization_url(self) -> str:
return self._authorization_url

@staticmethod
def from_dict(raw: dict, client_secret: str = None) -> "Consent":
def from_dict(raw: dict, client_secret: str = None, timeout: float = 60) -> "Consent":
return Consent(
raw["state"],
raw["verifier"],
Expand All @@ -672,6 +681,7 @@ def from_dict(raw: dict, client_secret: str = None) -> "Consent":
token_endpoint=raw["token_endpoint"],
client_id=raw["client_id"],
client_secret=client_secret,
timeout=timeout,
)

def launch_external_browser(self) -> SessionCredentials:
Expand Down Expand Up @@ -717,13 +727,15 @@ def exchange(self, code: str, state: str) -> SessionCredentials:
params=params,
headers=headers,
use_params=True,
timeout=self._timeout,
)
return SessionCredentials(
token,
self._token_endpoint,
self._client_id,
self._client_secret,
self._redirect_url,
timeout=self._timeout,
)
except ValueError as e:
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
Expand Down Expand Up @@ -762,6 +774,7 @@ def __init__(
client_id: str,
scopes: List[str] = None,
client_secret: str = None,
timeout: float = 60,
):
if not scopes:
# Default for direct OAuthClient users (e.g., via from_host()).
Expand All @@ -775,6 +788,7 @@ def __init__(
self._client_secret = client_secret
self._oidc_endpoints = oidc_endpoints
self._scopes = scopes
self._timeout = timeout

@staticmethod
def from_host(
Expand Down Expand Up @@ -824,6 +838,7 @@ def initiate_consent(self) -> Consent:
token_endpoint=self._oidc_endpoints.token_endpoint,
client_id=self._client_id,
client_secret=self._client_secret,
timeout=self._timeout,
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -851,6 +866,7 @@ class ClientCredentials(Refreshable):
use_header: bool = False
disable_async: bool = True
authorization_details: str = None
timeout: float = 60

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -871,6 +887,7 @@ def refresh(self) -> Token:
params,
use_params=self.use_params,
use_header=self.use_header,
timeout=self.timeout,
)


Expand All @@ -897,6 +914,7 @@ class PATOAuthTokenExchange(Refreshable):
scopes: str
authorization_details: str = None
disable_async: bool = True
timeout: float = 60

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -913,7 +931,7 @@ def refresh(self) -> Token:
if self.authorization_details:
params["authorization_details"] = self.authorization_details

resp = requests.post(token_exchange_url, params)
resp = requests.post(token_exchange_url, params, timeout=self.timeout)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down Expand Up @@ -947,6 +965,7 @@ def __init__(
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
profile: Optional[str] = None,
timeout: float = 60,
) -> None:
self._host = host
self._client_id = client_id
Expand All @@ -955,6 +974,7 @@ def __init__(
self._client_secret = client_secret
self._scopes = scopes or []
self._profile = profile
self._timeout = timeout

@property
def filename(self) -> str:
Expand Down Expand Up @@ -985,6 +1005,7 @@ def load(self) -> Optional[SessionCredentials]:
client_id=self._client_id,
client_secret=self._client_secret,
redirect_url=self._redirect_url,
timeout=self._timeout,
)
except Exception:
return None
Expand Down
Loading