Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions python/hsfs/storage_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3365,13 +3365,45 @@ def __init__(
default_catalog: str | None = None,
aws_region: str | None = None,
arguments: list[dict[str, Any]] | dict[str, Any] | None = None,
auth_method: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
oauth_endpoint: str | None = None,
account_id: str | None = None,
account_host: str | None = None,
has_access_token: bool | None = None,
has_client_secret: bool | None = None,
**kwargs: Any,
) -> None:
super().__init__(id, name, description, featurestore_id)
self._workspace_url = workspace_url
self._access_token = access_token
self._default_catalog = default_catalog
self._aws_region = aws_region
# auth_method defaults to 'PAT' for back-compat with connectors created
# before OAuth support landed; oauth_endpoint defaults to 'WORKSPACE'
# when caller asks for OAUTH_M2M without specifying one.
if auth_method is None:
self._auth_method = "PAT"
else:
self._auth_method = auth_method
self._client_id = client_id
self._client_secret = client_secret
if self._auth_method == "OAUTH_M2M" and oauth_endpoint is None:
self._oauth_endpoint = "WORKSPACE"
else:
self._oauth_endpoint = oauth_endpoint
self._account_id = account_id
self._account_host = account_host
# has_access_token / has_client_secret are server-emitted booleans that
# let callers tell whether a secret is on file without exposing it.
# They are never sent back on write (the backend ignores them).
self._has_access_token = (
bool(has_access_token) if has_access_token is not None else (access_token is not None)
)
self._has_client_secret = (
bool(has_client_secret) if has_client_secret is not None else (client_secret is not None)
)
if isinstance(arguments, list):
# Match the other connectors in this file: tolerate name-only entries
# and skip entries without a name. Backend serialises these as a list
Expand Down Expand Up @@ -3418,6 +3450,74 @@ def arguments(self) -> dict[str, Any]:
"""Additional Unity Catalog connection arguments passed through to the Arrow Flight server."""
return self._arguments

@public
@property
def auth_method(self) -> str:
    """How this connector authenticates against the Databricks workspace.

    The value is either "PAT" or "OAUTH_M2M". A connector built without an
    explicit method, such as one that predates OAuth support, reports "PAT".
    """
    return self._auth_method

@public
@property
def client_id(self) -> str | None:
    """Client ID of the Databricks service principal.

    Only relevant when [`auth_method`][hsfs.storage_connector.UnityCatalogConnector.auth_method]
    is "OAUTH_M2M"; `None` when no client ID was supplied.
    """
    return self._client_id

@public
@property
def client_secret(self) -> str | None:
    """Client secret of the Databricks service principal.

    The backend treats the secret as write-only and never includes it in a
    response, so this is `None` unless the connector was constructed locally
    with the secret passed in. To check whether a secret is stored, use
    [`has_client_secret`][hsfs.storage_connector.UnityCatalogConnector.has_client_secret].
    """
    return self._client_secret

@public
@property
def oauth_endpoint(self) -> str | None:
    """Which OAuth token endpoint is used, either "WORKSPACE" or "ACCOUNT".

    Only meaningful when [`auth_method`][hsfs.storage_connector.UnityCatalogConnector.auth_method]
    is "OAUTH_M2M"; in that case it falls back to "WORKSPACE" if no endpoint
    was given. For other auth methods it is `None` unless one was passed
    explicitly.
    """
    return self._oauth_endpoint

@public
@property
def account_id(self) -> str | None:
    """ID of the Databricks account.

    Only relevant when [`oauth_endpoint`][hsfs.storage_connector.UnityCatalogConnector.oauth_endpoint]
    is "ACCOUNT"; `None` when no account ID was supplied.
    """
    return self._account_id

@public
@property
def account_host(self) -> str | None:
    """Host of the Databricks account console.

    Only relevant when [`oauth_endpoint`][hsfs.storage_connector.UnityCatalogConnector.oauth_endpoint]
    is "ACCOUNT"; `None` when no account host was supplied.
    """
    return self._account_host

@public
@property
def has_access_token(self) -> bool:
    """Whether a personal access token is stored for this connector.

    Reads from the server never return the token itself, so this flag is the
    way to check for its presence without revealing the secret. When the flag
    was not supplied at construction time, it reflects whether an access token
    was passed to the constructor.
    """
    return self._has_access_token

@public
@property
def has_client_secret(self) -> bool:
    """Whether a client secret is stored for this connector.

    Reads from the server never return the secret itself, so this flag is the
    way to check for its presence without revealing it. When the flag was not
    supplied at construction time, it reflects whether a client secret was
    passed to the constructor.
    """
    return self._has_client_secret

@public
def connector_options(self) -> dict[str, Any]:
"""Return UC connector options shaped for external library use."""
Expand Down
51 changes: 50 additions & 1 deletion python/tests/fixtures/storage_connector_fixtures.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@
"name": "test_unity_catalog",
"storageConnectorType": "UNITY_CATALOG",
"workspace_url": "https://test.cloud.databricks.com",
"access_token": "dapi-test-token",
"authMethod": "PAT",
"hasAccessToken": true,
"default_catalog": "test_catalog",
"aws_region": "us-west-2",
"arguments": [{"name": "arg1", "value": "val1"}]
Expand Down Expand Up @@ -185,6 +186,54 @@
},
"headers": null
},
"get_unity_catalog_oauth_workspace": {
"response": {
"type": "featurestoreUnityCatalogConnectorDTO",
"featurestoreId": 67,
"id": 1,
"name": "test_unity_catalog_oauth",
"storageConnectorType": "UNITY_CATALOG",
"workspace_url": "https://test.cloud.databricks.com",
"authMethod": "OAUTH_M2M",
"oauthEndpoint": "WORKSPACE",
"clientId": "test-sp-client-id",
"hasClientSecret": true,
"default_catalog": "test_catalog"
},
"method": "GET",
"path_params": [],
"query_params": null,
"headers": null
},
"get_unity_catalog_oauth_account": {
"response": {
"type": "featurestoreUnityCatalogConnectorDTO",
"featurestoreId": 67,
"id": 2,
"name": "test_unity_catalog_account",
"storageConnectorType": "UNITY_CATALOG",
"workspace_url": "https://test.cloud.databricks.com",
"authMethod": "OAUTH_M2M",
"oauthEndpoint": "ACCOUNT",
"clientId": "test-sp-client-id",
"hasClientSecret": true,
"accountId": "12345678-1234-1234-1234-1234567890ab",
"accountHost": "accounts.cloud.databricks.com"
},
"method": "GET",
"path_params": [
"project",
"119",
"featurestores",
67,
"storageconnectors",
"test_unity_catalog"
],
"query_params": {
"temporaryCredentials": true
},
"headers": null
},
"get_redshift_basic_info": {
"response": {
"type": "featurestoreRedshiftConnectorDTO",
Expand Down
74 changes: 73 additions & 1 deletion python/tests/test_storage_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,51 @@ def test_from_response_json(self, backend_fixtures):
assert sc.description == "Unity Catalog connector description"
assert sc.type == storage_connector.StorageConnector.UNITY_CATALOG
assert sc.workspace_url == "https://test.cloud.databricks.com"
assert sc.access_token == "dapi-test-token"
# access_token itself is write-only on the backend; the server never
# returns it on GET. hasAccessToken signals that one is on file.
assert sc.access_token is None
assert sc.has_access_token is True
assert sc.auth_method == "PAT"
assert sc.default_catalog == "test_catalog"
assert sc.aws_region == "us-west-2"
assert sc.arguments == {"arg1": "val1"}

def test_from_response_json_oauth_workspace(self, backend_fixtures):
    """An OAuth M2M response using the workspace endpoint is parsed correctly."""
    # Arrange
    connector_fixtures = backend_fixtures["storage_connector"]
    payload = connector_fixtures["get_unity_catalog_oauth_workspace"]["response"]

    # Act
    connector = storage_connector.StorageConnector.from_response_json(payload)

    # Assert: the OAuth settings come through, but the secret itself does not;
    # only the "secret on file" flag is reported.
    assert connector.auth_method == "OAUTH_M2M"
    assert connector.oauth_endpoint == "WORKSPACE"
    assert connector.client_id == "test-sp-client-id"
    assert connector.client_secret is None
    assert connector.has_client_secret is True
    # Account-level fields are absent from a workspace-endpoint response.
    assert connector.account_id is None
    assert connector.account_host is None

def test_from_response_json_oauth_account(self, backend_fixtures):
    """An OAuth M2M response using the account endpoint is parsed correctly."""
    # Arrange
    connector_fixtures = backend_fixtures["storage_connector"]
    payload = connector_fixtures["get_unity_catalog_oauth_account"]["response"]

    # Act
    connector = storage_connector.StorageConnector.from_response_json(payload)

    # Assert: the OAuth settings come through, but the secret itself does not;
    # only the "secret on file" flag is reported.
    assert connector.auth_method == "OAUTH_M2M"
    assert connector.oauth_endpoint == "ACCOUNT"
    assert connector.client_id == "test-sp-client-id"
    assert connector.client_secret is None
    assert connector.has_client_secret is True
    # Account-level fields are populated from the response.
    assert connector.account_id == "12345678-1234-1234-1234-1234567890ab"
    assert connector.account_host == "accounts.cloud.databricks.com"

def test_from_response_json_basic_info(self, backend_fixtures):
# Arrange
json = backend_fixtures["storage_connector"]["get_unity_catalog_basic_info"][
Expand Down Expand Up @@ -205,6 +245,38 @@ def test_spark_options_not_supported(self):
with pytest.raises(NotImplementedError):
sc.spark_options()

def test_legacy_construction_defaults_pat(self):
    """A connector built without any auth_method behaves as a PAT connector."""
    # Arguments as a pre-OAuth caller would pass them: no auth_method and no
    # OAuth fields, only a workspace URL and a token.
    legacy_kwargs = {
        "id": 1,
        "name": "uc",
        "featurestore_id": 1,
        "workspace_url": "https://ws.cloud.databricks.com",
        "access_token": "dapi-xyz",
    }

    connector = storage_connector.UnityCatalogConnector(**legacy_kwargs)

    assert connector.auth_method == "PAT"
    assert connector.oauth_endpoint is None
    assert connector.client_id is None
    # The presence flags are derived from the secrets that were passed in.
    assert connector.has_access_token is True
    assert connector.has_client_secret is False

def test_oauth_construction_defaults_workspace_endpoint(self):
    """OAUTH_M2M without an explicit oauth_endpoint falls back to WORKSPACE."""
    # oauth_endpoint is deliberately omitted so the default is exercised.
    oauth_kwargs = {
        "id": 1,
        "name": "uc",
        "featurestore_id": 1,
        "workspace_url": "https://ws.cloud.databricks.com",
        "auth_method": "OAUTH_M2M",
        "client_id": "cid",
        "client_secret": "csec",
    }

    connector = storage_connector.UnityCatalogConnector(**oauth_kwargs)

    assert connector.oauth_endpoint == "WORKSPACE"
    # A locally constructed connector keeps the secret it was given and
    # reports it as present.
    assert connector.client_secret == "csec"
    assert connector.has_client_secret is True


class TestRedshiftConnector:
def test_from_response_json(self, backend_fixtures):
Expand Down
Loading