
Commit 2b04529

fix added
1 parent 84bf2c9 commit 2b04529

2 files changed: 188 additions & 102 deletions

src/databricks/labs/dqx/check_funcs.py

Lines changed: 166 additions & 102 deletions
@@ -36,6 +36,44 @@
 HAS_LLM_DETECTOR = False
 
 
+def _get_search_columns(df: DataFrame, columns: list[str | Column]) -> list[str]:
+    """Determine search space for PK detection."""
+    if not columns:
+        return df.columns
+    return get_columns_as_strings(columns, allow_simple_expressions_only=True)
+
+
+def _create_pk_detector(
+    temp_table_name: str, llm_options: dict[str, Any], spark: SparkSession
+) -> "DatabricksPrimaryKeyDetector":
+    """Create and configure the primary key detector."""
+    return DatabricksPrimaryKeyDetector(
+        table=temp_table_name,
+        endpoint=llm_options["llm_pk_detection_endpoint"],
+        validate_duplicates=True,
+        spark_session=spark,
+        max_retries=3,
+        show_live_reasoning=False,
+    )
+
+
+def _validate_and_filter_detected_columns(
+    pk_result: dict[str, Any], search_columns: list[str], columns: list[str | Column]
+) -> list[str]:
+    """Validate PK detection result and filter columns."""
+    if not pk_result.get("success", False):
+        raise RuntimeError(f"LLM-based primary key detection failed: {pk_result.get('error', 'Unknown error')}")
+
+    detected_columns = pk_result["primary_key_columns"]
+
+    if columns:
+        detected_columns = [col for col in detected_columns if col in search_columns]
+        if not detected_columns:
+            raise RuntimeError(f"No primary key columns detected within the provided columns: {search_columns}")
+
+    return detected_columns
+
+
 def _detect_matching_keys_with_llm(
     df: DataFrame,
     columns: list[str | Column],
@@ -66,49 +104,22 @@ def _detect_matching_keys_with_llm(
             "Install with: pip install databricks-labs-dqx[llm]"
         )
 
+    search_columns = _get_search_columns(df, columns)
     temp_table_name = f"temp_pk_detection_{unique_id}"
     df.createOrReplaceTempView(temp_table_name)
 
     try:
-        # Determine search space for PK detection
-        if not columns:
-            search_columns = df.columns
-        else:
-            search_columns = get_columns_as_strings(columns, allow_simple_expressions_only=True)
-
-        # Create detector with options
-        detector = DatabricksPrimaryKeyDetector(
-            table=temp_table_name,
-            endpoint=llm_options["llm_pk_detection_endpoint"],
-            validate_duplicates=True,  # Always validate for duplicates
-            spark_session=spark,
-            max_retries=3,  # Fixed to 3 retries for optimal performance
-            show_live_reasoning=False,
-        )
-
+        detector = _create_pk_detector(temp_table_name, llm_options, spark)
         pk_result = detector.detect_primary_keys()
-
-        if not pk_result.get("success", False):
-            raise RuntimeError(f"LLM-based primary key detection failed: {pk_result.get('error', 'Unknown error')}")
-
-        detected_columns = pk_result["primary_key_columns"]
-
-        # Filter detected columns to search space if columns were provided
-        if columns:
-            detected_columns = [col for col in detected_columns if col in search_columns]
-            if not detected_columns:
-                raise RuntimeError(f"No primary key columns detected within the provided columns: {search_columns}")
-
+        detected_columns = _validate_and_filter_detected_columns(pk_result, search_columns, columns)
         return detected_columns
-
     except Exception as e:
         raise RuntimeError(f"Failed to detect primary keys using LLM: {str(e)}") from e
     finally:
-        # Clean up temporary table
         try:
             spark.sql(f"DROP VIEW IF EXISTS {temp_table_name}")
         except Exception:
-            pass  # Ignore cleanup errors
+            pass
 
 
 _IPV4_OCTET = r"(25[0-5]|2[0-4]\d|1\d{2}|[1-9]?\d)"
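
The refactor above moves the detection flow into small helpers. As a quick illustration, the sketch below exercises _validate_and_filter_detected_columns with a hand-built detection result; the sample column names and values are hypothetical, the dict keys mirror those used in the diff, and it assumes a build of databricks-labs-dqx that contains this commit.

from databricks.labs.dqx.check_funcs import _validate_and_filter_detected_columns

# Hypothetical detector output: detection succeeded and proposed two key columns.
pk_result = {"success": True, "primary_key_columns": ["customer_id", "order_id"]}

# When the caller supplied candidate columns, detected keys are filtered to that search space.
keys = _validate_and_filter_detected_columns(
    pk_result,
    search_columns=["customer_id", "order_id", "amount"],
    columns=["customer_id", "order_id", "amount"],  # non-empty, so filtering applies
)
print(keys)  # ['customer_id', 'order_id']

# A result with "success": False would raise RuntimeError carrying the detector's error message.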
@@ -1477,20 +1488,105 @@ def is_aggr_not_equal(
     )
 
 
+def _prepare_compare_columns(
+    df: DataFrame,
+    ref_df: DataFrame,
+    pk_column_names: list[str],
+    exclude_column_names: list[str],
+) -> tuple[list[str], list[str]]:
+    """Prepare columns for comparison and determine skipped columns."""
+    map_type_columns = {field.name for field in df.schema.fields if isinstance(field.dataType, types.MapType)}
+    compare_columns = [
+        col
+        for col in df.columns
+        if (
+            col in ref_df.columns
+            and col not in pk_column_names
+            and col not in exclude_column_names
+            and col not in map_type_columns
+        )
+    ]
+    skipped_columns = [col for col in df.columns if col not in compare_columns and col not in pk_column_names]
+    return compare_columns, skipped_columns
+
+
+def _build_output_columns(
+    pk_column_names: list[str],
+    ref_pk_column_names: list[str],
+    compare_columns: list[str],
+    skipped_columns: list[str],
+    condition_col: str,
+) -> list[Column]:
+    """Build the output column list for the comparison result."""
+    coalesced_pk_columns = [
+        F.coalesce(F.col(f"df.{col}"), F.col(f"ref_df.{ref_col}")).alias(col)
+        for col, ref_col in zip(pk_column_names, ref_pk_column_names)
+    ]
+    return [
+        *coalesced_pk_columns,
+        *[F.col(f"df.{col}").alias(col) for col in compare_columns],
+        *[F.col(f"df.{col}").alias(col) for col in skipped_columns],
+        F.col(condition_col),
+    ]
+
+
+def _extract_tolerances(abs_tol: float | None, rel_tol: float | None) -> tuple[float, float]:
+    """Extract and validate tolerance values from individual parameters."""
+    abs_tolerance = 0.0 if abs_tol is None else abs_tol
+    rel_tolerance = 0.0 if rel_tol is None else rel_tol
+    if abs_tolerance < 0 or rel_tolerance < 0:
+        raise InvalidParameterError("Absolute and/or relative tolerances if provided must be non-negative")
+    return abs_tolerance, rel_tolerance
+
+
+def _prepare_column_names(
+    columns: list[str | Column],
+    ref_columns: list[str | Column],
+    exclude_columns: list[str | Column] | None,
+) -> tuple[list[str], list[str], list[str], str]:
+    """Convert column inputs to strings and create check alias."""
+    pk_column_names = get_columns_as_strings(columns, allow_simple_expressions_only=True)
+    ref_pk_column_names = get_columns_as_strings(ref_columns, allow_simple_expressions_only=True)
+    exclude_column_names = (
+        get_columns_as_strings(exclude_columns, allow_simple_expressions_only=True) if exclude_columns else []
+    )
+    check_alias = normalize_col_str(f"datasets_diff_pk_{'_'.join(pk_column_names)}_ref_{'_'.join(ref_pk_column_names)}")
+    return pk_column_names, ref_pk_column_names, exclude_column_names, check_alias
+
+
+def _create_temp_column_names(unique_id: str) -> dict[str, str]:
+    """Create temporary column names for comparison."""
+    return {
+        "condition": f"__compare_status_{unique_id}",
+        "row_missing": f"__row_missing_{unique_id}",
+        "row_extra": f"__row_extra_{unique_id}",
+        "columns_changed": f"__columns_changed_{unique_id}",
+        "filter": f"__filter_{uuid.uuid4().hex}",
+    }
+
+
+def _initialize_llm_options(llm_opts: dict[str, Any] | None) -> tuple[bool, dict[str, Any]]:
+    """Initialize LLM options and determine if LLM detection is enabled."""
+    enable_detection = llm_opts is not None and llm_opts.get("enable", True)
+    if llm_opts is None:
+        llm_opts = {"llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct"}
+    return enable_detection, llm_opts
+
+
 @register_rule("dataset")
 def compare_datasets(
     columns: list[str | Column],
     ref_columns: list[str | Column],
     ref_df_name: str | None = None,
     ref_table: str | None = None,
+    *,
     check_missing_records: bool | None = False,
     exclude_columns: list[str | Column] | None = None,
     null_safe_row_matching: bool | None = True,
     null_safe_column_value_matching: bool | None = True,
     row_filter: str | None = None,
     abs_tolerance: float | None = None,
     rel_tolerance: float | None = None,
-    enable_llm_matching_key_detection: bool = False,
     llm_matching_key_detection_options: dict[str, Any] | None = None,
 ) -> tuple[Column, Callable]:
     """
@@ -1550,12 +1646,12 @@ def compare_datasets(
         With tolerance=0.01 (1%):
             100 vs 101 → equal (diff = 1, tolerance = 1)
             2.001 vs 2.0099 → equal
-        enable_llm_matching_key_detection: Enable automatic matching key detection using LLM. When enabled,
-            the function will automatically detect primary/matching keys from the provided columns (or all
-            columns if columns is empty). Requires LLM dependencies to be installed.
         llm_matching_key_detection_options: Optional dictionary of options for LLM-based matching key detection.
+            When provided, enables automatic matching key detection using LLM to detect primary/matching keys
+            from the provided columns (or all columns if columns is empty). Requires LLM dependencies.
             Supported options:
             - llm_pk_detection_endpoint: Databricks Model Serving endpoint (defaults to foundational model)
+            - enable: Set to True to enable LLM detection (default: True when options dict is provided)
 
     Returns:
         Tuple[Column, Callable]:
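
To make the docstring change concrete, the sketch below shows the calling convention it describes: instead of a separate boolean flag, passing llm_matching_key_detection_options now opts in to LLM-based key detection. Column and endpoint names are taken from the YAML examples added in this commit; the call assumes an active SparkSession (e.g. a Databricks notebook), and running the returned closure additionally needs the reference DataFrame and the LLM extras installed.

from databricks.labs.dqx.check_funcs import compare_datasets

condition, apply_fn = compare_datasets(
    columns=["customer_id", "order_id"],        # candidate matching keys (search space)
    ref_columns=["customer_id", "order_id"],
    ref_df_name="ref_df_key",
    llm_matching_key_detection_options={
        "llm_pk_detection_endpoint": "my-custom-endpoint",
        "enable": True,                         # default when the dict is provided
    },
)
# apply_fn(df, spark, {"ref_df_key": ref_df}) would return df annotated with the
# internal comparison-status column that the check condition is built from.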
@@ -1571,101 +1667,69 @@ def compare_datasets(
         - if *abs_tolerance* or *rel_tolerance* is negative.
     """
     _validate_ref_params(columns, ref_columns, ref_df_name, ref_table)
-
-    abs_tolerance = 0.0 if abs_tolerance is None else abs_tolerance
-    rel_tolerance = 0.0 if rel_tolerance is None else rel_tolerance
-    if abs_tolerance < 0 or rel_tolerance < 0:
-        raise InvalidParameterError("Absolute and/or relative tolerances if provided must be non-negative")
-
-    # Initialize options if not provided and capture in closure
-    if llm_matching_key_detection_options is None:
-        llm_matching_key_detection_options = {
-            "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct",
-        }
-    _llm_matching_key_detection_options = llm_matching_key_detection_options
-
-    # convert all input columns to strings
-    pk_column_names = get_columns_as_strings(columns, allow_simple_expressions_only=True)
-    ref_pk_column_names = get_columns_as_strings(ref_columns, allow_simple_expressions_only=True)
-    exclude_column_names = (
-        get_columns_as_strings(exclude_columns, allow_simple_expressions_only=True) if exclude_columns else []
+    abs_tolerance, rel_tolerance = _extract_tolerances(abs_tolerance, rel_tolerance)
+    pk_column_names, ref_pk_column_names, exclude_column_names, check_alias = _prepare_column_names(
+        columns, ref_columns, exclude_columns
     )
-    check_alias = normalize_col_str(f"datasets_diff_pk_{'_'.join(pk_column_names)}_ref_{'_'.join(ref_pk_column_names)}")
 
+    enable_llm_detection, _llm_options = _initialize_llm_options(llm_matching_key_detection_options)
     unique_id = uuid.uuid4().hex
-    condition_col = f"__compare_status_{unique_id}"
-    row_missing_col = f"__row_missing_{unique_id}"
-    row_extra_col = f"__row_extra_{unique_id}"
-    columns_changed_col = f"__columns_changed_{unique_id}"
-    filter_col = f"__filter_{uuid.uuid4().hex}"
+    internal_column_names = _create_temp_column_names(unique_id)
 
     def apply(df: DataFrame, spark: SparkSession, ref_dfs: dict[str, DataFrame]) -> DataFrame:
         nonlocal columns, ref_columns
 
         ref_df = _get_ref_df(ref_df_name, ref_table, ref_dfs, spark)
 
-        # Auto-detect matching keys using LLM if enabled
-        if enable_llm_matching_key_detection:
-            detected_columns = _detect_matching_keys_with_llm(
-                df, columns, _llm_matching_key_detection_options, spark, unique_id
-            )
+        if enable_llm_detection:
+            detected_columns = _detect_matching_keys_with_llm(df, columns, _llm_options, spark, unique_id)
             columns = detected_columns  # type: ignore[assignment]
             ref_columns = detected_columns  # type: ignore[assignment]
 
-        # map type columns must be skipped as they cannot be compared with eqNullSafe
-        map_type_columns = {field.name for field in df.schema.fields if isinstance(field.dataType, types.MapType)}
-
-        # columns to compare: present in both df and ref_df, not in PK, not excluded, not map type
-        compare_columns = [
-            col
-            for col in df.columns
-            if (
-                col in ref_df.columns
-                and col not in pk_column_names
-                and col not in exclude_column_names
-                and col not in map_type_columns
-            )
-        ]
-
-        # determine skipped columns: present in df, not compared, and not PK
-        skipped_columns = [col for col in df.columns if col not in compare_columns and col not in pk_column_names]
-
-        # apply filter before aliasing to avoid ambiguity
-        df = df.withColumn(filter_col, F.expr(row_filter) if row_filter else F.lit(True))
-
+        compare_columns, skipped_columns = _prepare_compare_columns(df, ref_df, pk_column_names, exclude_column_names)
+        df = df.withColumn(internal_column_names["filter"], F.expr(row_filter) if row_filter else F.lit(True))
         df = df.alias("df")
         ref_df = ref_df.alias("ref_df")
 
         results = _match_rows(
             df, ref_df, pk_column_names, ref_pk_column_names, check_missing_records, null_safe_row_matching
         )
-        results = _add_row_diffs(results, pk_column_names, ref_pk_column_names, row_missing_col, row_extra_col)
+        results = _add_row_diffs(
+            results,
+            pk_column_names,
+            ref_pk_column_names,
+            internal_column_names["row_missing"],
+            internal_column_names["row_extra"],
+        )
         results = _add_column_diffs(
-            results, compare_columns, columns_changed_col, null_safe_column_value_matching, abs_tolerance, rel_tolerance
+            results,
+            compare_columns,
+            internal_column_names["columns_changed"],
+            null_safe_column_value_matching,
+            abs_tolerance,
+            rel_tolerance,
         )
         results = _add_compare_condition(
-            results, condition_col, row_missing_col, row_extra_col, columns_changed_col, filter_col
+            results,
+            internal_column_names["condition"],
+            internal_column_names["row_missing"],
+            internal_column_names["row_extra"],
+            internal_column_names["columns_changed"],
+            internal_column_names["filter"],
         )
 
-        # in a full outer join, rows may be missing from either side, we take the first non-null value
-        coalesced_pk_columns = [
-            F.coalesce(F.col(f"df.{col}"), F.col(f"ref_df.{ref_col}")).alias(col)
-            for col, ref_col in zip(pk_column_names, ref_pk_column_names)
-        ]
-
-        # make sure original columns + condition column are present in the output
-        return results.select(
-            *coalesced_pk_columns,
-            *[F.col(f"df.{col}").alias(col) for col in compare_columns],
-            *[F.col(f"df.{col}").alias(col) for col in skipped_columns],
-            F.col(condition_col),
+        output_columns = _build_output_columns(
+            pk_column_names, ref_pk_column_names, compare_columns, skipped_columns, internal_column_names["condition"]
         )
+        return results.select(*output_columns)
 
-    condition = F.col(condition_col).isNotNull()
+    condition = F.col(internal_column_names["condition"]).isNotNull()
 
     return (
         make_condition(
-            condition=condition, message=F.when(condition, F.to_json(F.col(condition_col))), alias=check_alias
+            condition=condition,
+            message=F.when(condition, F.to_json(F.col(internal_column_names["condition"]))),
+            alias=check_alias,
         ),
         apply,
     )
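
The column-selection logic that used to live inline in apply() is now testable on its own. A minimal local sketch of _prepare_compare_columns, assuming pyspark and a build containing this commit are installed (the schema and column names are made up for illustration):

from pyspark.sql import SparkSession
from pyspark.sql import types as T

from databricks.labs.dqx.check_funcs import _prepare_compare_columns

spark = SparkSession.builder.master("local[1]").appName("compare-columns-demo").getOrCreate()

schema = T.StructType([
    T.StructField("customer_id", T.StringType()),
    T.StructField("amount", T.DoubleType()),
    T.StructField("tags", T.MapType(T.StringType(), T.StringType())),  # map columns cannot be compared
])
df = spark.createDataFrame([], schema)
ref_df = spark.createDataFrame([], schema)

compare_cols, skipped_cols = _prepare_compare_columns(
    df, ref_df, pk_column_names=["customer_id"], exclude_column_names=[]
)
print(compare_cols)  # ['amount'] — present in both sides, not a key, not excluded, not a map
print(skipped_cols)  # ['tags'] — map type, so carried through unchanged rather than compared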

src/databricks/labs/dqx/llm/resources/yaml_checks_examples.yml

Lines changed: 22 additions & 0 deletions
@@ -663,6 +663,28 @@
       check_missing_records: true
       null_safe_row_matching: true
       null_safe_column_value_matching: true
+- criticality: error
+  check:
+    function: compare_datasets
+    arguments:
+      columns: []
+      ref_columns: []
+      ref_df_name: ref_df_key
+      enable_llm_matching_key_detection: true
+- criticality: error
+  check:
+    function: compare_datasets
+    arguments:
+      columns:
+        - customer_id
+        - order_id
+      ref_columns:
+        - customer_id
+        - order_id
+      ref_df_name: ref_df_key
+      enable_llm_matching_key_detection: true
+      llm_matching_key_detection_options:
+        llm_pk_detection_endpoint: my-custom-endpoint
 - criticality: error
   check:
     function: is_data_fresh_per_time_window
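
The new YAML entries are metadata-form checks. The fragment below, copied verbatim from the first added example, just confirms how such an entry parses into the function name and arguments dict that the checks engine works with; it assumes PyYAML is available.

import yaml

check_yaml = """
- criticality: error
  check:
    function: compare_datasets
    arguments:
      columns: []
      ref_columns: []
      ref_df_name: ref_df_key
      enable_llm_matching_key_detection: true
"""

checks = yaml.safe_load(check_yaml)
print(checks[0]["check"]["function"])   # compare_datasets
print(checks[0]["check"]["arguments"])  # {'columns': [], 'ref_columns': [], 'ref_df_name': 'ref_df_key', ...}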
