From b5943090cf07ced5b7c3bb188a6efd4ae7952b74 Mon Sep 17 00:00:00 2001 From: Gourab Singha Date: Tue, 23 Jun 2026 00:47:27 +0530 Subject: [PATCH] fix: propagate http_timeout_seconds to oauth requests to prevent hangs --- databricks/sdk/credentials_provider.py | 11 ++++++++-- databricks/sdk/oauth.py | 29 ++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index dbe8bde1e..5e320854f 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -222,6 +222,7 @@ def get_notebook_pat_token() -> Optional[str]: host=cfg.host, scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, + timeout=cfg.http_timeout_seconds, ) def inner() -> Dict[str, str]: @@ -251,6 +252,7 @@ def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: use_header=True, disable_async=cfg.disable_async_token_refresh, authorization_details=cfg.authorization_details, + timeout=cfg.http_timeout_seconds, ) def inner() -> Dict[str, str]: @@ -277,7 +279,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: elif cfg.azure_client_id: client_id = cfg.azure_client_id client_secret = cfg.azure_client_secret - oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host) + oidc_endpoints = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds) if not client_id: client_id = "databricks-cli" oidc_endpoints = cfg.databricks_oidc_endpoints @@ -301,6 +303,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: redirect_url=redirect_url, scopes=scopes, profile=cfg.profile, + timeout=cfg.http_timeout_seconds, ) credentials = token_cache.load() if credentials: @@ -320,6 +323,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: redirect_url=redirect_url, client_secret=client_secret, scopes=scopes, + timeout=cfg.http_timeout_seconds, ) consent = oauth_client.initiate_consent() if not consent: @@ -367,6 +371,7 @@ def token_source_for(resource: str) -> oauth.TokenSource: disable_async=cfg.disable_async_token_refresh, scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, + timeout=cfg.http_timeout_seconds, ) _ensure_host_present(cfg, token_source_for) @@ -489,6 +494,7 @@ def token_source_for(audience: str) -> oauth.TokenSource: use_params=True, disable_async=cfg.disable_async_token_refresh, authorization_details=cfg.authorization_details, + timeout=cfg.http_timeout_seconds, ) def refreshed_headers() -> Dict[str, str]: @@ -551,7 +557,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: aad_endpoint = cfg.arm_environment.active_directory_endpoint if not cfg.azure_tenant_id: # detect Azure AD Tenant ID if it's not specified directly - token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host).token_endpoint + token_endpoint = get_azure_entra_id_workspace_endpoints(cfg.host, timeout=cfg.http_timeout_seconds).token_endpoint cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, "").split("/")[0] inner = oauth.ClientCredentials( @@ -567,6 +573,7 @@ def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: disable_async=cfg.disable_async_token_refresh, scopes=cfg.get_scopes_as_string(), authorization_details=cfg.authorization_details, + timeout=cfg.http_timeout_seconds, ) def refreshed_headers() -> Dict[str, str]: diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e5d034ae2..ee4b09cff 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -194,6 +194,7 @@ def retrieve_token( use_params=False, use_header=False, headers=None, + timeout=60, ) -> Token: logger.debug(f"Retrieving token for {client_id}") if use_params: @@ -206,7 +207,7 @@ def retrieve_token( auth = requests.auth.HTTPBasicAuth(client_id, client_secret) else: auth = IgnoreNetrcAuth() - resp = requests.post(token_url, params, auth=auth, headers=headers) + resp = requests.post(token_url, params, auth=auth, headers=headers, timeout=timeout) if not resp.ok: if resp.headers["Content-Type"].startswith("application/json"): err = resp.json() @@ -536,6 +537,7 @@ def get_unified_endpoints(host: str, account_id: str, client: _BaseClient = _Bas def get_azure_entra_id_workspace_endpoints( host: str, + timeout=60, ) -> Optional[OidcEndpoints]: """ Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks @@ -545,7 +547,7 @@ def get_azure_entra_id_workspace_endpoints( """ # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint host = _fix_host_if_needed(host) - res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False) + res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False, timeout=timeout) real_auth_url = res.headers.get("location") if not real_auth_url: return None @@ -564,11 +566,13 @@ def __init__( client_secret: str = None, redirect_url: str = None, disable_async: bool = True, + timeout: float = 60, ): self._token_endpoint = token_endpoint self._client_id = client_id self._client_secret = client_secret self._redirect_url = redirect_url + self._timeout = timeout super().__init__( token=token, disable_async=disable_async, @@ -584,6 +588,7 @@ def from_dict( client_id: str, client_secret: str = None, redirect_url: str = None, + timeout: float = 60, ) -> "SessionCredentials": return SessionCredentials( token=Token.from_dict(raw["token"]), @@ -591,6 +596,7 @@ def from_dict( client_id=client_id, client_secret=client_secret, redirect_url=redirect_url, + timeout=timeout, ) def auth_type(self): @@ -626,6 +632,7 @@ def refresh(self) -> Token: params=params, use_params=True, headers=headers, + timeout=self._timeout, ) @@ -639,6 +646,7 @@ def __init__( token_endpoint: str, client_id: str, client_secret: str = None, + timeout: float = 60, ) -> None: self._verifier = verifier self._state = state @@ -647,6 +655,7 @@ def __init__( self._token_endpoint = token_endpoint self._client_id = client_id self._client_secret = client_secret + self._timeout = timeout def as_dict(self) -> dict: return { @@ -663,7 +672,7 @@ def authorization_url(self) -> str: return self._authorization_url @staticmethod - def from_dict(raw: dict, client_secret: str = None) -> "Consent": + def from_dict(raw: dict, client_secret: str = None, timeout: float = 60) -> "Consent": return Consent( raw["state"], raw["verifier"], @@ -672,6 +681,7 @@ def from_dict(raw: dict, client_secret: str = None) -> "Consent": token_endpoint=raw["token_endpoint"], client_id=raw["client_id"], client_secret=client_secret, + timeout=timeout, ) def launch_external_browser(self) -> SessionCredentials: @@ -717,6 +727,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: params=params, headers=headers, use_params=True, + timeout=self._timeout, ) return SessionCredentials( token, @@ -724,6 +735,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: self._client_id, self._client_secret, self._redirect_url, + timeout=self._timeout, ) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): @@ -762,6 +774,7 @@ def __init__( client_id: str, scopes: List[str] = None, client_secret: str = None, + timeout: float = 60, ): if not scopes: # Default for direct OAuthClient users (e.g., via from_host()). @@ -775,6 +788,7 @@ def __init__( self._client_secret = client_secret self._oidc_endpoints = oidc_endpoints self._scopes = scopes + self._timeout = timeout @staticmethod def from_host( @@ -824,6 +838,7 @@ def initiate_consent(self) -> Consent: token_endpoint=self._oidc_endpoints.token_endpoint, client_id=self._client_id, client_secret=self._client_secret, + timeout=self._timeout, ) def __repr__(self) -> str: @@ -851,6 +866,7 @@ class ClientCredentials(Refreshable): use_header: bool = False disable_async: bool = True authorization_details: str = None + timeout: float = 60 def __post_init__(self): super().__init__(disable_async=self.disable_async) @@ -871,6 +887,7 @@ def refresh(self) -> Token: params, use_params=self.use_params, use_header=self.use_header, + timeout=self.timeout, ) @@ -897,6 +914,7 @@ class PATOAuthTokenExchange(Refreshable): scopes: str authorization_details: str = None disable_async: bool = True + timeout: float = 60 def __post_init__(self): super().__init__(disable_async=self.disable_async) @@ -913,7 +931,7 @@ def refresh(self) -> Token: if self.authorization_details: params["authorization_details"] = self.authorization_details - resp = requests.post(token_exchange_url, params) + resp = requests.post(token_exchange_url, params, timeout=self.timeout) if not resp.ok: if resp.headers["Content-Type"].startswith("application/json"): err = resp.json() @@ -947,6 +965,7 @@ def __init__( client_secret: Optional[str] = None, scopes: Optional[List[str]] = None, profile: Optional[str] = None, + timeout: float = 60, ) -> None: self._host = host self._client_id = client_id @@ -955,6 +974,7 @@ def __init__( self._client_secret = client_secret self._scopes = scopes or [] self._profile = profile + self._timeout = timeout @property def filename(self) -> str: @@ -985,6 +1005,7 @@ def load(self) -> Optional[SessionCredentials]: client_id=self._client_id, client_secret=self._client_secret, redirect_url=self._redirect_url, + timeout=self._timeout, ) except Exception: return None