Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@

from databricks.labs.lakebridge.connections.credential_manager import (
cred_file as creds,
CredentialManager,
create_credential_manager,
)
from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
from databricks.labs.lakebridge.connections.env_getter import EnvGetter
from databricks.labs.lakebridge.assessments import CONNECTOR_REQUIRED

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,8 +42,8 @@ def __init__(
def _configure_credentials(self) -> str:
pass

@staticmethod
def _test_connection(source: str, cred_manager: CredentialManager):
def _test_connection(self, source: str):
cred_manager = create_credential_manager(self._credential_file)
config = cred_manager.get_credentials(source)

try:
Expand All @@ -67,9 +65,7 @@ def run(self):
logger.info(f"{source.capitalize()} details and credentials received.")
if CONNECTOR_REQUIRED.get(self._source_name, True):
if self.prompts.confirm(f"Do you want to test the connection to {source}?"):
cred_manager = create_credential_manager("lakebridge", EnvGetter())
if cred_manager:
self._test_connection(source, cred_manager)
self._test_connection(source)
logger.info(f"{source.capitalize()} Assessment Configuration Completed")


Expand Down
4 changes: 2 additions & 2 deletions src/databricks/labs/lakebridge/assessments/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from databricks.labs.lakebridge.connections.database_manager import DatabaseManager
from databricks.labs.lakebridge.connections.credential_manager import (
create_credential_manager,
cred_file,
)
from databricks.labs.lakebridge.connections.env_getter import EnvGetter
from databricks.labs.lakebridge.assessments import (
PRODUCT_NAME,
PRODUCT_PATH_PREFIX,
Expand Down Expand Up @@ -62,7 +62,7 @@ def profile(
def _setup_extractor(platform: str) -> DatabaseManager | None:
if not CONNECTOR_REQUIRED[platform]:
return None
cred_manager = create_credential_manager(PRODUCT_NAME, EnvGetter())
cred_manager = create_credential_manager(cred_file(PRODUCT_NAME))
connect_config = cred_manager.get_credentials(platform)
return DatabaseManager(platform, connect_config)

Expand Down
80 changes: 68 additions & 12 deletions src/databricks/labs/lakebridge/connections/credential_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from collections.abc import Callable
from functools import partial
from pathlib import Path
import logging
from typing import Protocol
import base64

import yaml

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound

from databricks.labs.lakebridge.connections.env_getter import EnvGetter


Expand All @@ -14,6 +20,12 @@ class SecretProvider(Protocol):
def get_secret(self, key: str) -> str:
pass

def get_secret_or_none(self, key: str) -> str | None:
try:
return self.get_secret(key)
except KeyError:
return None


class LocalSecretProvider(SecretProvider):
def get_secret(self, key: str) -> str:
Expand All @@ -32,18 +44,52 @@ def get_secret(self, key: str) -> str:
return key


class DatabricksSecretProvider:
class DatabricksSecretProvider(SecretProvider):
def __init__(self, ws: WorkspaceClient):
self._ws = ws

def get_databricks_secret(self, scope: str, key: str) -> str:
return self.get_secret(f"{scope}/{key}")

def get_secret(self, key: str) -> str:
raise NotImplementedError("Databricks secret vault not implemented")
"""Get the secret value given a secret scope & secret key.

:param key: key in the format 'scope/secret'
:return: The decoded UTF-8 secret value.

Raises:
NotFound: The secret could not be found.
UnicodeDecodeError: The secret value was not Base64-encoded UTF-8.
"""
scope, key_only = key.split(sep="/")
assert scope and key_only, "Secret name must be in the format 'scope/secret'"

try:
secret = self._ws.secrets.get_secret(scope, key_only)
assert secret.value is not None
return base64.b64decode(secret.value).decode("utf-8")
except NotFound as e:
raise KeyError(f'Secret does not exist with scope: {scope} and key: {key_only} : {e}') from e
except UnicodeDecodeError as e:
raise UnicodeDecodeError(
"utf-8",
key_only.encode(),
0,
1,
f"Secret {key} has Base64 bytes that cannot be decoded to utf-8 string: {e}.",
) from e


class CredentialManager:
def __init__(self, credentials: dict, secret_providers: dict[str, SecretProvider]):
SecretProviderFactory = Callable[[], SecretProvider]

def __init__(self, credentials: dict, secret_providers: dict[str, SecretProviderFactory]):
self._credentials = credentials
self._default_vault = self._credentials.get('secret_vault_type', 'local').lower()
self._provider = secret_providers.get(self._default_vault)
if not self._provider:
provider_factory = secret_providers.get(self._default_vault)
if not provider_factory:
raise ValueError(f"Unsupported secret vault type: {self._default_vault}")
self._provider = provider_factory()

def get_credentials(self, source: str) -> dict:
if source not in self._credentials:
Expand Down Expand Up @@ -76,14 +122,24 @@ def _load_credentials(path: Path) -> dict:
raise FileNotFoundError(f"Credentials file not found at {path}") from e


def create_credential_manager(product_name: str, env_getter: EnvGetter) -> CredentialManager:
creds_path = cred_file(product_name)
creds = _load_credentials(creds_path)
def create_databricks_secret_provider() -> DatabricksSecretProvider:
ws = WorkspaceClient()
return DatabricksSecretProvider(ws)


def create_credential_manager(creds_or_path: dict | Path | str) -> CredentialManager:
if isinstance(creds_or_path, str):
creds_or_path = Path(creds_or_path)
if isinstance(creds_or_path, Path):
creds = _load_credentials(creds_or_path)
else:
creds = creds_or_path

secret_providers = {
'local': LocalSecretProvider(),
'env': EnvSecretProvider(env_getter),
'databricks': DatabricksSecretProvider(),
# Lazily initialize secret providers
secret_providers: dict[str, CredentialManager.SecretProviderFactory] = {
'local': LocalSecretProvider,
'env': partial(EnvSecretProvider, EnvGetter()),
'databricks': create_databricks_secret_provider,
}

return CredentialManager(creds, secret_providers)
2 changes: 1 addition & 1 deletion src/databricks/labs/lakebridge/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _prompt_for_new_reconcile_installation(self) -> ReconcileConfig:
report_type = self._prompts.choice(
"Select the report type:", [report_type.value for report_type in ReconReportType]
)
scope_name = self._prompts.question(
scope_name = self._prompts.question( # TODO deprecate
f"Enter Secret scope name to store `{data_source.capitalize()}` connection details / secrets",
default=f"remorph_{data_source}",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from pyspark.sql import DataFrame

from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
from sqlglot import Dialect

from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

Expand All @@ -36,20 +34,18 @@ def _get_schema_query(catalog: str, schema: str, table: str):
return re.sub(r'\s+', ' ', query)


class DatabricksDataSource(DataSource, SecretsMixin):
class DatabricksDataSource(DataSource):
_IDENTIFIER_DELIMITER = "`"

def __init__(
self,
engine: Dialect,
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
):
self._engine = engine
self._spark = spark
self._ws = ws
self._secret_scope = secret_scope

def read_data(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
import dataclasses


@dataclasses.dataclass()
class NormalizedIdentifier:
ansi_normalized: str
source_normalized: str


class DialectUtils:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
class JDBCReaderMixin:
_spark: SparkSession

# TODO update the url
def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None = None):
driver_class = {
"oracle": "oracle.jdbc.OracleDriver",
Expand Down

This file was deleted.

17 changes: 9 additions & 8 deletions src/databricks/labs/lakebridge/reconcile/connectors/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
from pyspark.sql.functions import col
from sqlglot import Dialect

from databricks.labs.lakebridge.connections.credential_manager import DatabricksSecretProvider
from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils, NormalizedIdentifier
from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)


class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
class OracleDataSource(DataSource, JDBCReaderMixin):
_DRIVER = "oracle"
_IDENTIFIER_DELIMITER = "\""
_SCHEMA_QUERY = """select column_name, case when (data_precision is not null
Expand All @@ -40,17 +39,19 @@ def __init__(
spark: SparkSession,
ws: WorkspaceClient,
secret_scope: str,
secrets: DatabricksSecretProvider, # only Databricks secrets are supported currently
):
self._engine = engine
self._spark = spark
self._ws = ws
self._secret_scope = secret_scope
self._secrets = secrets

@property
def get_jdbc_url(self) -> str:
return (
f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._get_secret('host')}"
f":{self._get_secret('port')}/{self._get_secret('database')}"
f"jdbc:{OracleDataSource._DRIVER}:thin:@//{self._secrets.get_databricks_secret(self._secret_scope, 'host')}"
f":{self._secrets.get_databricks_secret(self._secret_scope, 'port')}/{self._secrets.get_databricks_secret(self._secret_scope, 'database')}"
)

def read_data(
Expand Down Expand Up @@ -108,8 +109,8 @@ def _get_timestamp_options() -> dict[str, str]:
}

def reader(self, query: str) -> DataFrameReader:
user = self._get_secret('user')
password = self._get_secret('password')
user = self._secrets.get_databricks_secret(self._secret_scope, 'user')
password = self._secrets.get_databricks_secret(self._secret_scope, 'password')
logger.debug(f"Using user: {user} to connect to Oracle")
return self._get_jdbc_reader(
query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
Expand Down
49 changes: 0 additions & 49 deletions src/databricks/labs/lakebridge/reconcile/connectors/secrets.py

This file was deleted.

Loading
Loading