 HAS_LLM_DETECTOR = False
 
 
+def _get_search_columns(df: DataFrame, columns: list[str | Column]) -> list[str]:
+    """Determine search space for PK detection."""
+    if not columns:
+        return df.columns
+    return get_columns_as_strings(columns, allow_simple_expressions_only=True)
+
+
+def _create_pk_detector(
+    temp_table_name: str, llm_options: dict[str, Any], spark: SparkSession
+) -> "DatabricksPrimaryKeyDetector":
+    """Create and configure the primary key detector."""
+    return DatabricksPrimaryKeyDetector(
+        table=temp_table_name,
+        endpoint=llm_options["llm_pk_detection_endpoint"],
+        validate_duplicates=True,
+        spark_session=spark,
+        max_retries=3,
+        show_live_reasoning=False,
+    )
+
+
+def _validate_and_filter_detected_columns(
+    pk_result: dict[str, Any], search_columns: list[str], columns: list[str | Column]
+) -> list[str]:
+    """Validate PK detection result and filter columns."""
+    if not pk_result.get("success", False):
+        raise RuntimeError(f"LLM-based primary key detection failed: {pk_result.get('error', 'Unknown error')}")
+
+    detected_columns = pk_result["primary_key_columns"]
+
+    if columns:
+        detected_columns = [col for col in detected_columns if col in search_columns]
+        if not detected_columns:
+            raise RuntimeError(f"No primary key columns detected within the provided columns: {search_columns}")
+
+    return detected_columns
+
+
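For reviewers, a minimal sketch of how the new _validate_and_filter_detected_columns helper narrows an LLM result to the caller-provided search space; the result dict and column names below are invented for illustration only:

# Illustrative only: a fabricated detector result, not real LLM output.
pk_result = {"success": True, "primary_key_columns": ["order_id", "line_no", "load_ts"]}

# The caller restricted the search space to two columns, so "load_ts" is dropped.
keys = _validate_and_filter_detected_columns(
    pk_result, search_columns=["order_id", "line_no"], columns=["order_id", "line_no"]
)
assert keys == ["order_id", "line_no"]

# A failed detection raises instead of returning keys, e.g.:
# _validate_and_filter_detected_columns({"success": False, "error": "timeout"}, [], [])  -> RuntimeError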
 def _detect_matching_keys_with_llm(
     df: DataFrame,
     columns: list[str | Column],
@@ -66,49 +104,22 @@ def _detect_matching_keys_with_llm(
66104 "Install with: pip install databricks-labs-dqx[llm]"
67105 )
68106
107+ search_columns = _get_search_columns (df , columns )
69108 temp_table_name = f"temp_pk_detection_{ unique_id } "
70109 df .createOrReplaceTempView (temp_table_name )
71110
72111 try :
73- # Determine search space for PK detection
74- if not columns :
75- search_columns = df .columns
76- else :
77- search_columns = get_columns_as_strings (columns , allow_simple_expressions_only = True )
78-
79- # Create detector with options
80- detector = DatabricksPrimaryKeyDetector (
81- table = temp_table_name ,
82- endpoint = llm_options ["llm_pk_detection_endpoint" ],
83- validate_duplicates = True , # Always validate for duplicates
84- spark_session = spark ,
85- max_retries = 3 , # Fixed to 3 retries for optimal performance
86- show_live_reasoning = False ,
87- )
88-
112+ detector = _create_pk_detector (temp_table_name , llm_options , spark )
89113 pk_result = detector .detect_primary_keys ()
90-
91- if not pk_result .get ("success" , False ):
92- raise RuntimeError (f"LLM-based primary key detection failed: { pk_result .get ('error' , 'Unknown error' )} " )
93-
94- detected_columns = pk_result ["primary_key_columns" ]
95-
96- # Filter detected columns to search space if columns were provided
97- if columns :
98- detected_columns = [col for col in detected_columns if col in search_columns ]
99- if not detected_columns :
100- raise RuntimeError (f"No primary key columns detected within the provided columns: { search_columns } " )
101-
114+ detected_columns = _validate_and_filter_detected_columns (pk_result , search_columns , columns )
102115 return detected_columns
103-
104116 except Exception as e :
105117 raise RuntimeError (f"Failed to detect primary keys using LLM: { str (e )} " ) from e
106118 finally :
107- # Clean up temporary table
108119 try :
109120 spark .sql (f"DROP VIEW IF EXISTS { temp_table_name } " )
110121 except Exception :
111- pass # Ignore cleanup errors
122+ pass
112123
113124
114125_IPV4_OCTET = r"(25[0-5]|2[0-4]\d|1\d{2}|[1-9]?\d)"
@@ -1477,20 +1488,105 @@ def is_aggr_not_equal(
     )
 
 
+def _prepare_compare_columns(
+    df: DataFrame,
+    ref_df: DataFrame,
+    pk_column_names: list[str],
+    exclude_column_names: list[str],
+) -> tuple[list[str], list[str]]:
+    """Prepare columns for comparison and determine skipped columns."""
+    map_type_columns = {field.name for field in df.schema.fields if isinstance(field.dataType, types.MapType)}
+    compare_columns = [
+        col
+        for col in df.columns
+        if (
+            col in ref_df.columns
+            and col not in pk_column_names
+            and col not in exclude_column_names
+            and col not in map_type_columns
+        )
+    ]
+    skipped_columns = [col for col in df.columns if col not in compare_columns and col not in pk_column_names]
+    return compare_columns, skipped_columns
+
+
+def _build_output_columns(
+    pk_column_names: list[str],
+    ref_pk_column_names: list[str],
+    compare_columns: list[str],
+    skipped_columns: list[str],
+    condition_col: str,
+) -> list[Column]:
+    """Build the output column list for the comparison result."""
+    coalesced_pk_columns = [
+        F.coalesce(F.col(f"df.{col}"), F.col(f"ref_df.{ref_col}")).alias(col)
+        for col, ref_col in zip(pk_column_names, ref_pk_column_names)
+    ]
+    return [
+        *coalesced_pk_columns,
+        *[F.col(f"df.{col}").alias(col) for col in compare_columns],
+        *[F.col(f"df.{col}").alias(col) for col in skipped_columns],
+        F.col(condition_col),
+    ]
+
+
+def _extract_tolerances(abs_tol: float | None, rel_tol: float | None) -> tuple[float, float]:
+    """Extract and validate tolerance values from individual parameters."""
+    abs_tolerance = 0.0 if abs_tol is None else abs_tol
+    rel_tolerance = 0.0 if rel_tol is None else rel_tol
+    if abs_tolerance < 0 or rel_tolerance < 0:
+        raise InvalidParameterError("Absolute and/or relative tolerances if provided must be non-negative")
+    return abs_tolerance, rel_tolerance
+
+
+def _prepare_column_names(
+    columns: list[str | Column],
+    ref_columns: list[str | Column],
+    exclude_columns: list[str | Column] | None,
+) -> tuple[list[str], list[str], list[str], str]:
+    """Convert column inputs to strings and create check alias."""
+    pk_column_names = get_columns_as_strings(columns, allow_simple_expressions_only=True)
+    ref_pk_column_names = get_columns_as_strings(ref_columns, allow_simple_expressions_only=True)
+    exclude_column_names = (
+        get_columns_as_strings(exclude_columns, allow_simple_expressions_only=True) if exclude_columns else []
+    )
+    check_alias = normalize_col_str(f"datasets_diff_pk_{'_'.join(pk_column_names)}_ref_{'_'.join(ref_pk_column_names)}")
+    return pk_column_names, ref_pk_column_names, exclude_column_names, check_alias
+
+
+def _create_temp_column_names(unique_id: str) -> dict[str, str]:
+    """Create temporary column names for comparison."""
+    return {
+        "condition": f"__compare_status_{unique_id}",
+        "row_missing": f"__row_missing_{unique_id}",
+        "row_extra": f"__row_extra_{unique_id}",
+        "columns_changed": f"__columns_changed_{unique_id}",
+        "filter": f"__filter_{uuid.uuid4().hex}",
+    }
+
+
+def _initialize_llm_options(llm_opts: dict[str, Any] | None) -> tuple[bool, dict[str, Any]]:
+    """Initialize LLM options and determine if LLM detection is enabled."""
+    enable_detection = llm_opts is not None and llm_opts.get("enable", True)
+    if llm_opts is None:
+        llm_opts = {"llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct"}
+    return enable_detection, llm_opts
+
+
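A quick sketch of the enable/disable semantics encoded in _initialize_llm_options above; the values are illustrative and the endpoint string is simply the default from the helper itself:

# No options dict at all: LLM detection stays off, but the default options are still returned.
assert _initialize_llm_options(None) == (
    False,
    {"llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct"},
)

# Passing an options dict enables detection unless "enable" is explicitly False.
opts = {"llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct"}
assert _initialize_llm_options(opts) == (True, opts)
assert _initialize_llm_options({**opts, "enable": False})[0] is False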
 @register_rule("dataset")
 def compare_datasets(
     columns: list[str | Column],
     ref_columns: list[str | Column],
     ref_df_name: str | None = None,
     ref_table: str | None = None,
+    *,
     check_missing_records: bool | None = False,
     exclude_columns: list[str | Column] | None = None,
     null_safe_row_matching: bool | None = True,
     null_safe_column_value_matching: bool | None = True,
     row_filter: str | None = None,
     abs_tolerance: float | None = None,
     rel_tolerance: float | None = None,
-    enable_llm_matching_key_detection: bool = False,
     llm_matching_key_detection_options: dict[str, Any] | None = None,
 ) -> tuple[Column, Callable]:
     """
@@ -1550,12 +1646,12 @@ def compare_datasets(
             With tolerance=0.01 (1%):
                 100 vs 101 → equal (diff = 1, tolerance = 1)
                 2.001 vs 2.0099 → equal
-        enable_llm_matching_key_detection: Enable automatic matching key detection using LLM. When enabled,
-            the function will automatically detect primary/matching keys from the provided columns (or all
-            columns if columns is empty). Requires LLM dependencies to be installed.
         llm_matching_key_detection_options: Optional dictionary of options for LLM-based matching key detection.
+            When provided, enables automatic matching key detection using LLM to detect primary/matching keys
+            from the provided columns (or all columns if columns is empty). Requires LLM dependencies.
             Supported options:
             - llm_pk_detection_endpoint: Databricks Model Serving endpoint (defaults to foundational model)
+            - enable: Set to True to enable LLM detection (default: True when options dict is provided)
 
     Returns:
         Tuple[Column, Callable]:
@@ -1571,101 +1667,69 @@ def compare_datasets(
         - if *abs_tolerance* or *rel_tolerance* is negative.
     """
     _validate_ref_params(columns, ref_columns, ref_df_name, ref_table)
-
-    abs_tolerance = 0.0 if abs_tolerance is None else abs_tolerance
-    rel_tolerance = 0.0 if rel_tolerance is None else rel_tolerance
-    if abs_tolerance < 0 or rel_tolerance < 0:
-        raise InvalidParameterError("Absolute and/or relative tolerances if provided must be non-negative")
-
-    # Initialize options if not provided and capture in closure
-    if llm_matching_key_detection_options is None:
-        llm_matching_key_detection_options = {
-            "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct",
-        }
-    _llm_matching_key_detection_options = llm_matching_key_detection_options
-
-    # convert all input columns to strings
-    pk_column_names = get_columns_as_strings(columns, allow_simple_expressions_only=True)
-    ref_pk_column_names = get_columns_as_strings(ref_columns, allow_simple_expressions_only=True)
-    exclude_column_names = (
-        get_columns_as_strings(exclude_columns, allow_simple_expressions_only=True) if exclude_columns else []
+    abs_tolerance, rel_tolerance = _extract_tolerances(abs_tolerance, rel_tolerance)
+    pk_column_names, ref_pk_column_names, exclude_column_names, check_alias = _prepare_column_names(
+        columns, ref_columns, exclude_columns
     )
-    check_alias = normalize_col_str(f"datasets_diff_pk_{'_'.join(pk_column_names)}_ref_{'_'.join(ref_pk_column_names)}")
 
+    enable_llm_detection, _llm_options = _initialize_llm_options(llm_matching_key_detection_options)
     unique_id = uuid.uuid4().hex
-    condition_col = f"__compare_status_{unique_id}"
-    row_missing_col = f"__row_missing_{unique_id}"
-    row_extra_col = f"__row_extra_{unique_id}"
-    columns_changed_col = f"__columns_changed_{unique_id}"
-    filter_col = f"__filter_{uuid.uuid4().hex}"
+    internal_column_names = _create_temp_column_names(unique_id)
 
     def apply(df: DataFrame, spark: SparkSession, ref_dfs: dict[str, DataFrame]) -> DataFrame:
         nonlocal columns, ref_columns
 
         ref_df = _get_ref_df(ref_df_name, ref_table, ref_dfs, spark)
 
-        # Auto-detect matching keys using LLM if enabled
-        if enable_llm_matching_key_detection:
-            detected_columns = _detect_matching_keys_with_llm(
-                df, columns, _llm_matching_key_detection_options, spark, unique_id
-            )
+        if enable_llm_detection:
+            detected_columns = _detect_matching_keys_with_llm(df, columns, _llm_options, spark, unique_id)
             columns = detected_columns  # type: ignore[assignment]
             ref_columns = detected_columns  # type: ignore[assignment]
 
-        # map type columns must be skipped as they cannot be compared with eqNullSafe
-        map_type_columns = {field.name for field in df.schema.fields if isinstance(field.dataType, types.MapType)}
-
-        # columns to compare: present in both df and ref_df, not in PK, not excluded, not map type
-        compare_columns = [
-            col
-            for col in df.columns
-            if (
-                col in ref_df.columns
-                and col not in pk_column_names
-                and col not in exclude_column_names
-                and col not in map_type_columns
-            )
-        ]
-
-        # determine skipped columns: present in df, not compared, and not PK
-        skipped_columns = [col for col in df.columns if col not in compare_columns and col not in pk_column_names]
-
-        # apply filter before aliasing to avoid ambiguity
-        df = df.withColumn(filter_col, F.expr(row_filter) if row_filter else F.lit(True))
-
+        compare_columns, skipped_columns = _prepare_compare_columns(df, ref_df, pk_column_names, exclude_column_names)
+        df = df.withColumn(internal_column_names["filter"], F.expr(row_filter) if row_filter else F.lit(True))
         df = df.alias("df")
         ref_df = ref_df.alias("ref_df")
 
         results = _match_rows(
             df, ref_df, pk_column_names, ref_pk_column_names, check_missing_records, null_safe_row_matching
         )
-        results = _add_row_diffs(results, pk_column_names, ref_pk_column_names, row_missing_col, row_extra_col)
+        results = _add_row_diffs(
+            results,
+            pk_column_names,
+            ref_pk_column_names,
+            internal_column_names["row_missing"],
+            internal_column_names["row_extra"],
+        )
         results = _add_column_diffs(
-            results, compare_columns, columns_changed_col, null_safe_column_value_matching, abs_tolerance, rel_tolerance
+            results,
+            compare_columns,
+            internal_column_names["columns_changed"],
+            null_safe_column_value_matching,
+            abs_tolerance,
+            rel_tolerance,
         )
         results = _add_compare_condition(
-            results, condition_col, row_missing_col, row_extra_col, columns_changed_col, filter_col
+            results,
+            internal_column_names["condition"],
+            internal_column_names["row_missing"],
+            internal_column_names["row_extra"],
+            internal_column_names["columns_changed"],
+            internal_column_names["filter"],
         )
 
-        # in a full outer join, rows may be missing from either side, we take the first non-null value
-        coalesced_pk_columns = [
-            F.coalesce(F.col(f"df.{col}"), F.col(f"ref_df.{ref_col}")).alias(col)
-            for col, ref_col in zip(pk_column_names, ref_pk_column_names)
-        ]
-
-        # make sure original columns + condition column are present in the output
-        return results.select(
-            *coalesced_pk_columns,
-            *[F.col(f"df.{col}").alias(col) for col in compare_columns],
-            *[F.col(f"df.{col}").alias(col) for col in skipped_columns],
-            F.col(condition_col),
+        output_columns = _build_output_columns(
+            pk_column_names, ref_pk_column_names, compare_columns, skipped_columns, internal_column_names["condition"]
         )
+        return results.select(*output_columns)
 
-    condition = F.col(condition_col).isNotNull()
+    condition = F.col(internal_column_names["condition"]).isNotNull()
 
     return (
         make_condition(
-            condition=condition, message=F.when(condition, F.to_json(F.col(condition_col))), alias=check_alias
+            condition=condition,
+            message=F.when(condition, F.to_json(F.col(internal_column_names["condition"]))),
+            alias=check_alias,
         ),
         apply,
     )
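For context, a minimal usage sketch of the refactored check with LLM-based matching key detection. The DataFrames, the "ref" name, and the order_id column are illustrative placeholders; only the keyword-only call style, the options dict, and the returned (condition, apply) pair come from this diff:

# Illustrative only: df and ref_df are any two Spark DataFrames sharing a schema.
condition, apply_fn = compare_datasets(
    columns=["order_id"],            # hypothetical matching key
    ref_columns=["order_id"],
    ref_df_name="ref",
    check_missing_records=True,      # keyword-only after this change
    llm_matching_key_detection_options={
        "llm_pk_detection_endpoint": "databricks-meta-llama-3-1-8b-instruct",
        # omit "enable" or set it to False to skip LLM detection
    },
)

# The apply closure annotates the input with the internal compare-status column;
# the returned condition flags rows where that status is populated (i.e. a difference was found).
annotated = apply_fn(df, spark, {"ref": ref_df})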