Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import logging
import os
import urllib.parse
from abc import ABC, abstractmethod
from datetime import timedelta
Expand Down Expand Up @@ -53,6 +54,10 @@ def __init__(
debug_headers: Optional[bool] = False,
clock: Optional[Clock] = None,
streaming_buffer_size: int = 1024 * 1024,
proxy_url: Optional[str] = None,
proxy_username: Optional[str] = None,
proxy_password: Optional[str] = None,
proxy_auth_type: Optional[str] = None,
): # 1MB
"""
:param debug_truncate_bytes:
Expand Down Expand Up @@ -84,6 +89,45 @@ def __init__(
self._session.auth = self._authenticate
self._streaming_buffer_size = streaming_buffer_size

# Resolve proxy settings: an explicit constructor argument wins, otherwise
# fall back to the matching DATABRICKS_PROXY_* environment variable.
self._proxy_url = proxy_url or os.environ.get("DATABRICKS_PROXY_URL")
self._proxy_username = proxy_username or os.environ.get("DATABRICKS_PROXY_USERNAME")
self._proxy_password = proxy_password or os.environ.get("DATABRICKS_PROXY_PASSWORD")
# NOTE(review): the auth type is only stored here; nothing in this section
# acts on it -- confirm it is consumed elsewhere.
self._proxy_auth_type = proxy_auth_type or os.environ.get("DATABRICKS_PROXY_AUTH_TYPE")

if self._proxy_url:
    p_url = self._proxy_url
    # A bare "host:port" value is treated as a plain HTTP proxy.
    if "://" not in p_url:
        p_url = "http://" + p_url

    if self._proxy_username:
        parsed = urllib.parse.urlparse(p_url)
        # Explicitly supplied credentials replace any userinfo that is
        # already embedded in the proxy URL.
        host_port = parsed.netloc.rsplit("@", 1)[-1]

        # safe="" is required: quote() leaves "/" unescaped by default, and a
        # raw "/" inside the userinfo would terminate the URL authority and
        # corrupt the proxy host.
        user_part = urllib.parse.quote(self._proxy_username, safe="")
        if self._proxy_password:
            user_part += ":" + urllib.parse.quote(self._proxy_password, safe="")

        p_url = urllib.parse.urlunparse(parsed._replace(netloc=f"{user_part}@{host_port}"))

    # The same proxy endpoint is used for both http and https targets.
    self._session.proxies = {
        "http": p_url,
        "https": p_url,
    }

# We don't use `max_retries` from HTTPAdapter to align with a more production-ready
# retry strategy established in the Databricks SDK for Go. See _is_retryable and
# @retried for more details.
Expand Down
4 changes: 4 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ class Config:
serverless_compute_id: str = ConfigAttribute(env="DATABRICKS_SERVERLESS_COMPUTE_ID")
skip_verify: bool = ConfigAttribute()
http_timeout_seconds: float = ConfigAttribute()
proxy_url: str = ConfigAttribute(env="DATABRICKS_PROXY_URL")
proxy_username: str = ConfigAttribute(env="DATABRICKS_PROXY_USERNAME")
proxy_password: str = ConfigAttribute(env="DATABRICKS_PROXY_PASSWORD", sensitive=True)
proxy_auth_type: str = ConfigAttribute(env="DATABRICKS_PROXY_AUTH_TYPE")
debug_truncate_bytes: int = ConfigAttribute(env="DATABRICKS_DEBUG_TRUNCATE_BYTES")
debug_headers: bool = ConfigAttribute(env="DATABRICKS_DEBUG_HEADERS")
rate_limit: int = ConfigAttribute(env="DATABRICKS_RATE_LIMIT")
Expand Down
4 changes: 4 additions & 0 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def __init__(self, cfg: Config):
http_timeout_seconds=cfg.http_timeout_seconds,
extra_error_customizers=[_AddDebugErrorCustomizer(cfg)],
clock=cfg.clock,
proxy_url=cfg.proxy_url,
proxy_username=cfg.proxy_username,
proxy_password=cfg.proxy_password,
proxy_auth_type=cfg.proxy_auth_type,
)

@property
Expand Down
11 changes: 9 additions & 2 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def get_notebook_pat_token() -> Optional[str]:
host=cfg.host,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def inner() -> Dict[str, str]:
Expand Down Expand Up @@ -251,6 +252,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]:
use_header=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def inner() -> Dict[str, str]:
Expand All @@ -277,7 +279,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
elif cfg.azure_client_id:
client_id = cfg.azure_client_id
client_secret = cfg.azure_client_secret
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host)
oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds)
if not client_id:
client_id = "databricks-cli"
oidc_endpoints = cfg.databricks_oidc_endpoints
Expand All @@ -301,6 +303,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
redirect_url=redirect_url,
scopes=scopes,
profile=cfg.profile,
timeout=cfg.http_timeout_seconds,
)
credentials = token_cache.load()
if credentials:
Expand All @@ -320,6 +323,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]:
redirect_url=redirect_url,
client_secret=client_secret,
scopes=scopes,
timeout=cfg.http_timeout_seconds,
)
consent = oauth_client.initiate_consent()
if not consent:
Expand Down Expand Up @@ -367,6 +371,7 @@ def token_source_for(resource: str) -> oauth.TokenSource:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

_ensure_host_present(cfg, token_source_for)
Expand Down Expand Up @@ -489,6 +494,7 @@ def token_source_for(audience: str) -> oauth.TokenSource:
use_params=True,
disable_async=cfg.disable_async_token_refresh,
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down Expand Up @@ -551,7 +557,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
aad_endpoint = cfg.arm_environment.active_directory_endpoint
if not cfg.azure_tenant_id:
# detect Azure AD Tenant ID if it's not specified directly
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint
token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds).token_endpoint
cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0]

inner = oauth.ClientCredentials(
Expand All @@ -567,6 +573,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]:
disable_async=cfg.disable_async_token_refresh,
scopes=cfg.get_scopes_as_string(),
authorization_details=cfg.authorization_details,
timeout=cfg.http_timeout_seconds,
)

def refreshed_headers() -> Dict[str, str]:
Expand Down
29 changes: 25 additions & 4 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def retrieve_token(
use_params=False,
use_header=False,
headers=None,
timeout=60,
) -> Token:
logger.debug(f"Retrieving token for {client_id}")
if use_params:
Expand All @@ -206,7 +207,7 @@ def retrieve_token(
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
else:
auth = IgnoreNetrcAuth()
resp = requests.post(token_url, params, auth=auth, headers=headers)
resp = requests.post(token_url, params, auth=auth, headers=headers, timeout=timeout)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down Expand Up @@ -536,6 +537,7 @@ def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _Bas

def get_azure_entra_id_workspace_endpoints(
host: str,
timeout=60,
) -> Optional[OidcEndpoints]:
"""
Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks
Expand All @@ -545,7 +547,7 @@ def get_azure_entra_id_workspace_endpoints(
"""
# In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint
host = _fix_host_if_needed(host)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False)
res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False, timeout=timeout)
real_auth_url = res.headers.get("location")
if not real_auth_url:
return None
Expand All @@ -564,11 +566,13 @@ def __init__(
client_secret: str = None,
redirect_url: str = None,
disable_async: bool = True,
timeout: float = 60,
):
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._redirect_url = redirect_url
self._timeout = timeout
super().__init__(
token=token,
disable_async=disable_async,
Expand All @@ -584,13 +588,15 @@ def from_dict(
client_id: str,
client_secret: str = None,
redirect_url: str = None,
timeout: float = 60,
) -> "SessionCredentials":
return SessionCredentials(
token=Token.from_dict(raw["token"]),
token_endpoint=token_endpoint,
client_id=client_id,
client_secret=client_secret,
redirect_url=redirect_url,
timeout=timeout,
)

def auth_type(self):
Expand Down Expand Up @@ -626,6 +632,7 @@ def refresh(self) -> Token:
params=params,
use_params=True,
headers=headers,
timeout=self._timeout,
)


Expand All @@ -639,6 +646,7 @@ def __init__(
token_endpoint: str,
client_id: str,
client_secret: str = None,
timeout: float = 60,
) -> None:
self._verifier = verifier
self._state = state
Expand All @@ -647,6 +655,7 @@ def __init__(
self._token_endpoint = token_endpoint
self._client_id = client_id
self._client_secret = client_secret
self._timeout = timeout

def as_dict(self) -> dict:
return {
Expand All @@ -663,7 +672,7 @@ def authorization_url(self) -> str:
return self._authorization_url

@staticmethod
def from_dict(raw: dict, client_secret: str = None) -> "Consent":
def from_dict(raw: dict, client_secret: str = None, timeout: float = 60) -> "Consent":
return Consent(
raw["state"],
raw["verifier"],
Expand All @@ -672,6 +681,7 @@ def from_dict(raw: dict, client_secret: str = None) -> "Consent":
token_endpoint=raw["token_endpoint"],
client_id=raw["client_id"],
client_secret=client_secret,
timeout=timeout,
)

def launch_external_browser(self) -> SessionCredentials:
Expand Down Expand Up @@ -717,13 +727,15 @@ def exchange(self, code: str, state: str) -> SessionCredentials:
params=params,
headers=headers,
use_params=True,
timeout=self._timeout,
)
return SessionCredentials(
token,
self._token_endpoint,
self._client_id,
self._client_secret,
self._redirect_url,
timeout=self._timeout,
)
except ValueError as e:
if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e):
Expand Down Expand Up @@ -762,6 +774,7 @@ def __init__(
client_id: str,
scopes: List[str] = None,
client_secret: str = None,
timeout: float = 60,
):
if not scopes:
# Default for direct OAuthClient users (e.g., via from_host()).
Expand All @@ -775,6 +788,7 @@ def __init__(
self._client_secret = client_secret
self._oidc_endpoints = oidc_endpoints
self._scopes = scopes
self._timeout = timeout

@staticmethod
def from_host(
Expand Down Expand Up @@ -824,6 +838,7 @@ def initiate_consent(self) -> Consent:
token_endpoint=self._oidc_endpoints.token_endpoint,
client_id=self._client_id,
client_secret=self._client_secret,
timeout=self._timeout,
)

def __repr__(self) -> str:
Expand Down Expand Up @@ -851,6 +866,7 @@ class ClientCredentials(Refreshable):
use_header: bool = False
disable_async: bool = True
authorization_details: str = None
timeout: float = 60

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -871,6 +887,7 @@ def refresh(self) -> Token:
params,
use_params=self.use_params,
use_header=self.use_header,
timeout=self.timeout,
)


Expand All @@ -897,6 +914,7 @@ class PATOAuthTokenExchange(Refreshable):
scopes: str
authorization_details: str = None
disable_async: bool = True
timeout: float = 60

def __post_init__(self):
super().__init__(disable_async=self.disable_async)
Expand All @@ -913,7 +931,7 @@ def refresh(self) -> Token:
if self.authorization_details:
params["authorization_details"] = self.authorization_details

resp = requests.post(token_exchange_url, params)
resp = requests.post(token_exchange_url, params, timeout=self.timeout)
if not resp.ok:
if resp.headers["Content-Type"].startswith("application/json"):
err = resp.json()
Expand Down Expand Up @@ -947,6 +965,7 @@ def __init__(
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
profile: Optional[str] = None,
timeout: float = 60,
) -> None:
self._host = host
self._client_id = client_id
Expand All @@ -955,6 +974,7 @@ def __init__(
self._client_secret = client_secret
self._scopes = scopes or []
self._profile = profile
self._timeout = timeout

@property
def filename(self) -> str:
Expand Down Expand Up @@ -985,6 +1005,7 @@ def load(self) -> Optional[SessionCredentials]:
client_id=self._client_id,
client_secret=self._client_secret,
redirect_url=self._redirect_url,
timeout=self._timeout,
)
except Exception:
return None
Expand Down
Loading
Loading