Skip to content

Commit 40733dd

Browse files
committed
Persist the joined DataFrame in aggregate reconciliation after the source and target join
1 parent 72980db commit 40733dd

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

src/databricks/labs/lakebridge/reconcile/compare.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
77
from databricks.labs.lakebridge.reconcile.exception import ColumnMismatchException
8-
from databricks.labs.lakebridge.reconcile.recon_capture import AbstractReconIntermediatePersist
8+
from databricks.labs.lakebridge.reconcile.recon_capture import (
9+
AbstractReconIntermediatePersist,
10+
)
911
from databricks.labs.lakebridge.reconcile.recon_output_config import (
1012
DataReconcileOutput,
1113
MismatchOutput,
@@ -56,7 +58,7 @@ def reconcile_data(
5658
target: DataFrame,
5759
key_columns: list[str],
5860
report_type: str,
59-
inter_persist: AbstractReconIntermediatePersist,
61+
persistence: AbstractReconIntermediatePersist,
6062
) -> DataReconcileOutput:
6163
source_alias = "src"
6264
target_alias = "tgt"
@@ -75,7 +77,7 @@ def reconcile_data(
7577
)
7678
)
7779

78-
df = inter_persist.write_and_read_df_with_volumes(df)
80+
df = persistence.write_and_read_df_with_volumes(df)
7981
# Checkpoint after joining source and target to apply backpressure
8082

8183
mismatch = _get_mismatch_data(df, source_alias, target_alias) if report_type in {"all", "data"} else None
@@ -414,7 +416,12 @@ def reconcile_agg_data_per_rule(
414416
return rule_reconcile_output
415417

416418

417-
def join_aggregate_data(source: DataFrame, target: DataFrame, key_columns: list[str] | None) -> DataFrame:
419+
def join_aggregate_data(
420+
source: DataFrame,
421+
target: DataFrame,
422+
key_columns: list[str] | None,
423+
persistence: AbstractReconIntermediatePersist,
424+
) -> DataFrame:
418425
# TODO: Integrate with reconcile_data function
419426

420427
source_alias = "src"
@@ -439,5 +446,5 @@ def join_aggregate_data(source: DataFrame, target: DataFrame, key_columns: list[
439446

440447
joined_cols = source.columns + target.columns
441448
normalized_joined_cols = [DialectUtils.ansi_normalize_identifier(col) for col in joined_cols]
442-
joined_df = df.select(*normalized_joined_cols)
449+
joined_df = persistence.write_and_read_df_with_volumes(df.select(*normalized_joined_cols))
443450
return joined_df

src/databricks/labs/lakebridge/reconcile/reconciliation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _get_reconcile_output(
150150
target=tgt_data,
151151
key_columns=table_conf.join_columns,
152152
report_type=self._report_type,
153-
inter_persist=self.intermediate_persist,
153+
persistence=self.intermediate_persist,
154154
)
155155

156156
def _get_reconcile_aggregate_output(
@@ -264,6 +264,7 @@ def _get_reconcile_aggregate_output(
264264
source=src_data,
265265
target=tgt_data,
266266
key_columns=src_query_with_rules.group_by_columns,
267+
persistence=self.intermediate_persist,
267268
)
268269
except DataSourceRuntimeException as e:
269270
data_source_exception = e

tests/integration/reconcile/test_aggregates_reconcile.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6-
from unittest.mock import patch
76

87
import pytest
98
from pyspark.testing import assertDataFrameEqual

tests/integration/reconcile/test_compare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_compare_data_for_report_all(
4545
target=target,
4646
key_columns=["s_suppkey", "s_nationkey"],
4747
report_type="all",
48-
inter_persist=FakeReconIntermediatePersist(),
48+
persistence=FakeReconIntermediatePersist(),
4949
)
5050
expected = DataReconcileOutput(
5151
mismatch_count=1,
@@ -97,7 +97,7 @@ def test_compare_data_for_report_hash(mock_spark, tmp_path: Path):
9797
target=target,
9898
key_columns=["s_suppkey", "s_nationkey"],
9999
report_type="hash",
100-
inter_persist=FakeReconIntermediatePersist(),
100+
persistence=FakeReconIntermediatePersist(),
101101
)
102102
expected = DataReconcileOutput(
103103
missing_in_src=missing_in_src,
@@ -280,7 +280,7 @@ def test_compare_data_special_column_names(mock_spark, tmp_path: Path):
280280
target=target,
281281
key_columns=["`s``supp#`", "`s_nation#`"],
282282
report_type="all",
283-
inter_persist=FakeReconIntermediatePersist(),
283+
persistence=FakeReconIntermediatePersist(),
284284
)
285285
expected = DataReconcileOutput(
286286
mismatch_count=1,

0 commit comments

Comments
 (0)