26 changes: 11 additions & 15 deletions src/databricks/labs/lakebridge/reconcile/compare.py
@@ -1,12 +1,12 @@
import logging
from functools import reduce
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, expr, lit

from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
from databricks.labs.lakebridge.reconcile.exception import ColumnMismatchException
from databricks.labs.lakebridge.reconcile.recon_capture import (
ReconIntermediatePersist,
AbstractReconIntermediatePersist,
)
from databricks.labs.lakebridge.reconcile.recon_output_config import (
DataReconcileOutput,
@@ -58,8 +58,7 @@ def reconcile_data(
target: DataFrame,
key_columns: list[str],
report_type: str,
spark: SparkSession,
path: str,
persistence: AbstractReconIntermediatePersist,
) -> DataReconcileOutput:
source_alias = "src"
target_alias = "tgt"
@@ -78,9 +77,8 @@
)
)

# Write unmatched df to volume
df = ReconIntermediatePersist(spark, path).write_and_read_unmatched_df_with_volumes(df)
logger.debug(f"Unmatched data was written to {path} successfully")
df = persistence.write_and_read_df_with_volumes(df)
# Checkpoint after joining source and target to apply backpressure

mismatch = _get_mismatch_data(df, source_alias, target_alias) if report_type in {"all", "data"} else None

@@ -170,6 +168,7 @@ def capture_mismatch_data_and_columns(source: DataFrame, target: DataFrame, key_

check_columns = [column for column in source_columns if column not in unnormalized_key_columns]
mismatch_df = _get_mismatch_df(source_df, target_df, unnormalized_key_columns, check_columns)
# TODO write `mismatch_df` to delta
mismatch_columns = _get_mismatch_columns(mismatch_df, check_columns)
return MismatchOutput(mismatch_df, mismatch_columns)

@@ -395,12 +394,14 @@ def reconcile_agg_data_per_rule(
missing_in_src = joined_df_with_rule_cols.filter(_agg_conditions(rule_select_columns, "missing_in_src")).select(
*rule_target_columns
)
# TODO write `missing_in_src` to delta

# Data missing in Target DataFrame
rule_source_columns = set(source_columns).intersection([mapping.source_name for mapping in rule_select_columns])
missing_in_tgt = joined_df_with_rule_cols.filter(_agg_conditions(rule_select_columns, "missing_in_tgt")).select(
*rule_source_columns
)
# TODO write `missing_in_tgt` to delta

mismatch_count = 0
if mismatch:
@@ -422,8 +423,7 @@ def join_aggregate_data(
source: DataFrame,
target: DataFrame,
key_columns: list[str] | None,
spark: SparkSession,
path: str,
persistence: AbstractReconIntermediatePersist,
) -> DataFrame:
# TODO: Integrate with reconcile_data function

@@ -450,9 +450,5 @@
joined_cols = source.columns + target.columns
normalized_joined_cols = [DialectUtils.ansi_normalize_identifier(col) for col in joined_cols]
joined_df = df.select(*normalized_joined_cols)

# Write the joined df to volume path
joined_volume_df = ReconIntermediatePersist(spark, path).write_and_read_unmatched_df_with_volumes(joined_df).cache()
logger.warning(f"Unmatched data is written to {path} successfully")

return joined_volume_df
persisted = persistence.write_and_read_df_with_volumes(joined_df)
return persisted
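
With the `spark`/`path` pair replaced by a `persistence` collaborator, `reconcile_data` and `join_aggregate_data` no longer need a real Volumes path to run. A minimal sketch of a pass-through test double, assuming only the `AbstractReconIntermediatePersist` interface introduced in this PR (`base_dir` plus `write_and_read_df_with_volumes`):

import tempfile
from pathlib import Path

from pyspark.sql import DataFrame

from databricks.labs.lakebridge.reconcile.recon_capture import AbstractReconIntermediatePersist


class PassThroughPersist(AbstractReconIntermediatePersist):
    """Hypothetical test double: skips the intermediate write/read round trip."""

    @property
    def base_dir(self) -> Path:
        # Never actually written to; only satisfies the abstract property.
        return Path(tempfile.gettempdir())

    def write_and_read_df_with_volumes(self, df: DataFrame) -> DataFrame:
        # Hand the DataFrame straight back instead of persisting it.
        return df

A test could then pass `PassThroughPersist()` as the `persistence` argument to `reconcile_data` without touching a Unity Catalog Volume or the local filesystem.
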
87 changes: 55 additions & 32 deletions src/databricks/labs/lakebridge/reconcile/recon_capture.py
@@ -1,10 +1,12 @@
import logging
import tempfile
import uuid
from datetime import datetime
from functools import reduce
from functools import reduce, cached_property
from pathlib import Path

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, collect_list, create_map, lit
from pyspark.sql.types import StringType, StructField, StructType
from pyspark.errors import PySparkException
from sqlglot import Dialect

@@ -14,7 +16,6 @@
from databricks.labs.lakebridge.reconcile.exception import (
WriteToTableException,
ReadAndWriteWithVolumeException,
CleanFromVolumeException,
)
from databricks.labs.lakebridge.reconcile.recon_output_config import (
DataReconcileOutput,
@@ -38,44 +39,66 @@
_RECON_AGGREGATE_DETAILS_TABLE_NAME = "aggregate_details"


class ReconIntermediatePersist:
class AbstractReconIntermediatePersist:
@property
def base_dir(self) -> Path:
raise NotImplementedError

def __init__(self, spark: SparkSession, path: str):
self.spark = spark
self.path = path

def _write_unmatched_df_to_volumes(
def write_and_read_df_with_volumes(
self,
unmatched_df: DataFrame,
) -> None:
unmatched_df.write.format("parquet").mode("overwrite").save(self.path)
df: DataFrame,
) -> DataFrame:
raise NotImplementedError

def _read_unmatched_df_from_volumes(self) -> DataFrame:
return self.spark.read.format("parquet").load(self.path)

def clean_unmatched_df_from_volume(self):
try:
# TODO: for now we are overwriting the intermediate cache path. We should delete the volume in future
# workspace_client.dbfs.get_status(path)
# workspace_client.dbfs.delete(path, recursive=True)
empty_df = self.spark.createDataFrame([], schema=StructType([StructField("empty", StringType(), True)]))
empty_df.write.format("parquet").mode("overwrite").save(self.path)
logger.debug(f"Unmatched DF cleaned up from {self.path} successfully.")
except PySparkException as e:
message = f"Error cleaning up unmatched DF from {self.path} volumes --> {e}"
logger.error(message)
raise CleanFromVolumeException(message) from e
class ReconIntermediatePersist(AbstractReconIntermediatePersist):
def __init__(self, spark: SparkSession, metadata_config: ReconcileMetadataConfig):
self._spark = spark
self._metadata_config = metadata_config
self._format = "delta" if self._is_databricks else "parquet"
self._base_dir = self._get_uc_volume_path if self._is_databricks else tempfile.gettempdir()

@cached_property
def _is_databricks(self) -> bool:
is_db = any(k.startswith("spark.databricks") for k in self._spark.conf.getAll.keys())
logger.info(f"Running on Databricks check completed with result: {is_db}")
return is_db
Contributor comment on lines +61 to +65:

I would create this method outside the class.

Suggested change
-    @cached_property
-    def _is_databricks(self) -> bool:
-        is_db = any(k.startswith("spark.databricks") for k in self._spark.conf.getAll.keys())
-        logger.info(f"Running on Databricks check completed with result: {is_db}")
-        return is_db
+@lru_cache(maxsize=1)
+def is_databricks(spark: SparkSession) -> bool:
+    is_db = any(k.startswith("spark.databricks") for k in spark.conf.getAll.keys())
+    logger.info(f"Running on Databricks check completed with result: {is_db}")
+    return is_db
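
A note on the suggestion's semantics: `functools.lru_cache` memoizes by argument, so with `maxsize=1` the detection runs once per `SparkSession` rather than once per `ReconIntermediatePersist` instance, and only the most recently seen session stays cached. A toy illustration of that behaviour (not lakebridge code):

from functools import lru_cache


@lru_cache(maxsize=1)
def expensive_check(session_id: str) -> bool:
    print(f"checking {session_id}")  # executed only on cache misses
    return session_id.startswith("db-")


expensive_check("db-123")  # prints "checking db-123"
expensive_check("db-123")  # cache hit, no print
expensive_check("local")   # new argument evicts the previous entry (maxsize=1)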


def write_and_read_unmatched_df_with_volumes(
@property
def base_dir(self) -> Path:
return Path(self._base_dir)

@property
def _get_uc_volume_path(self):
return (
f"/Volumes/"
f"{self._metadata_config.catalog}/"
f"{self._metadata_config.schema}/"
f"{self._metadata_config.volume}"
)

def _write_df_to_volumes(self, df: DataFrame, path: str) -> None:
logger.debug(f"Writing DF on {self._format} to path: {path}")
df.write.format(self._format).save(path)
logger.info(f"Wrote DF on {self._format}")

def _read_df_from_volumes(self, path) -> DataFrame:
logger.debug(f"Reading DF on {self._format} from path: {path}")
df = self._spark.read.format(self._format).load(path)
logger.info(f"Read DF on {self._format}")
return df

def write_and_read_df_with_volumes(
self,
unmatched_df: DataFrame,
df: DataFrame,
) -> DataFrame:
path = str(self.base_dir / uuid.uuid4().hex)
try:
self._write_unmatched_df_to_volumes(unmatched_df)
return self._read_unmatched_df_from_volumes()
self._write_df_to_volumes(df, path)
return self._read_df_from_volumes(path)
except PySparkException as e:
message = f"Exception in reading or writing unmatched DF with volumes {self.path} --> {e}"
logger.error(message)
message = f"Exception in reading or writing DF at: {path}"
logger.exception(message)
raise ReadAndWriteWithVolumeException(message) from e


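Taken together, the rewritten `ReconIntermediatePersist` chooses Delta plus a Unity Catalog Volume path on Databricks and Parquet under the system temp directory elsewhere, and writes each DataFrame to a fresh UUID-named subdirectory before reading it back. A rough usage sketch; the `ReconcileMetadataConfig` import path and constructor below are assumptions, check the real definitions:

from pyspark.sql import SparkSession

# Import path assumed from this diff, not verified against the package layout.
from databricks.labs.lakebridge.config import ReconcileMetadataConfig
from databricks.labs.lakebridge.reconcile.recon_capture import ReconIntermediatePersist

spark = SparkSession.builder.getOrCreate()
metadata = ReconcileMetadataConfig(catalog="main", schema="recon", volume="cache")  # assumed constructor

persist = ReconIntermediatePersist(spark, metadata)
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# Off Databricks this round-trips as Parquet under the temp dir;
# on Databricks it lands in /Volumes/main/recon/cache/<uuid> as Delta.
materialized = persist.write_and_read_df_with_volumes(df)
materialized.show()
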
17 changes: 8 additions & 9 deletions src/databricks/labs/lakebridge/reconcile/reconciliation.py
@@ -7,7 +7,6 @@
DatabaseConfig,
ReconcileMetadataConfig,
)
from databricks.labs.lakebridge.reconcile import utils
from databricks.labs.lakebridge.reconcile.compare import (
capture_mismatch_data_and_columns,
reconcile_data,
@@ -28,6 +27,7 @@
from databricks.labs.lakebridge.reconcile.query_builder.threshold_query import (
ThresholdQueryBuilder,
)
from databricks.labs.lakebridge.reconcile.recon_capture import AbstractReconIntermediatePersist
from databricks.labs.lakebridge.reconcile.recon_config import (
Schema,
Table,
@@ -59,6 +59,7 @@ def __init__(
source_engine: Dialect,
spark: SparkSession,
metadata_config: ReconcileMetadataConfig,
intermediate_persist: AbstractReconIntermediatePersist,
):
self._source = source
self._target = target
@@ -69,6 +70,7 @@
self._source_engine = source_engine
self._spark = spark
self._metadata_config = metadata_config
self.intermediate_persist = intermediate_persist

@property
def source(self) -> DataSource:
@@ -143,14 +145,12 @@ def _get_reconcile_output(
options=table_conf.jdbc_reader_options,
)

volume_path = utils.generate_volume_path(table_conf, self._metadata_config)
return reconcile_data(
source=src_data,
target=tgt_data,
key_columns=table_conf.join_columns,
report_type=self._report_type,
spark=self._spark,
path=volume_path,
persistence=self.intermediate_persist,
)

def _get_reconcile_aggregate_output(
Expand Down Expand Up @@ -230,8 +230,6 @@ def _get_reconcile_aggregate_output(
self._target,
).build_queries()

volume_path = utils.generate_volume_path(table_conf, self._metadata_config)

table_agg_output: list[AggregateQueryOutput] = []

# Iterate over the grouped aggregates and reconcile the data
@@ -266,8 +264,7 @@
source=src_data,
target=tgt_data,
key_columns=src_query_with_rules.group_by_columns,
spark=self._spark,
path=f"{volume_path}{src_query_with_rules.group_by_columns_as_str}",
persistence=self.intermediate_persist,
)
except DataSourceRuntimeException as e:
data_source_exception = e
@@ -370,7 +367,8 @@ def _get_mismatch_data(

# Uses pre-calculated `mismatch_count` from `reconcile_output.mismatch_count` to avoid recomputing `mismatch` for RandomSampler.
mismatch_sampler = SamplerFactory.get_sampler(sampling_options)
df = mismatch_sampler.sample(mismatch, mismatch_count, key_columns, sampling_model_target).cache()
df = mismatch_sampler.sample(mismatch, mismatch_count, key_columns, sampling_model_target)
# TODO write `df` to delta

src_mismatch_sample_query = src_sampler.build_query(df)
tgt_mismatch_sample_query = tgt_sampler.build_query(df)
@@ -456,6 +454,7 @@ def _compute_threshold_comparison(self, table_conf: Table, src_schema: list[Sche
["`" + DialectUtils.unnormalize_identifier(name) + "_match` = 'Failed'" for name in threshold_columns]
)
mismatched_df = threshold_result.filter(failed_where_cond)
# TODO write `mismatched_df` to delta
mismatched_count = mismatched_df.count()
threshold_df = None
if mismatched_count > 0:
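
The recurring `# TODO write ... to delta` markers sketch the follow-up work: materializing these intermediate DataFrames (mismatch samples, threshold failures) to Delta instead of relying on Spark's `.cache()`, which this PR also removes from the sampler and aggregate outputs. Purely as an illustration of what such a helper could look like (not part of this PR; requires a Delta-enabled Spark session):

from pyspark.sql import DataFrame


def checkpoint_to_delta(df: DataFrame, path: str) -> DataFrame:
    """Illustrative sketch: persist an intermediate DataFrame as Delta and read it back."""
    df.write.format("delta").mode("overwrite").save(path)
    # DataFrame.sparkSession is available in recent PySpark versions.
    return df.sparkSession.read.format("delta").load(path)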