diff --git a/pyproject.toml b/pyproject.toml
index 2cfcfc6..d5c32d2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -413,7 +413,7 @@ max-args = 10
 max-positional-arguments=10
 
 # Maximum number of attributes for a class (see R0902).
-max-attributes = 15
+max-attributes = 16
 
 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr = 5
diff --git a/src/databricks/labs/dqx/config.py b/src/databricks/labs/dqx/config.py
index c59476f..be3bfc7 100644
--- a/src/databricks/labs/dqx/config.py
+++ b/src/databricks/labs/dqx/config.py
@@ -17,6 +17,7 @@ class RunConfig:
     output_table: str | None = None  # output data table
     quarantine_table: str | None = None  # quarantined data table
     checks_file: str | None = "checks.yml"  # file containing quality rules / checks
+    checks_table: str | None = None  # table containing quality rules / checks
     profile_summary_stats_file: str | None = "profile_summary_stats.yml"  # file containing profile summary statistics
     override_clusters: dict[str, str] | None = None  # cluster configuration for jobs
     spark_conf: dict[str, str] | None = None  # extra spark configs
diff --git a/src/databricks/labs/dqx/engine.py b/src/databricks/labs/dqx/engine.py
index b9d3c22..8411352 100644
--- a/src/databricks/labs/dqx/engine.py
+++ b/src/databricks/labs/dqx/engine.py
@@ -3,12 +3,13 @@
 import functools as ft
 import inspect
 import itertools
+import warnings
 from pathlib import Path
 from collections.abc import Callable
 from typing import Any
 import yaml
 import pyspark.sql.functions as F
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 
 from databricks.labs.blueprint.installation import Installation
 from databricks.labs.dqx import row_checks
@@ -30,6 +31,7 @@
 from databricks.sdk import WorkspaceClient
 
 logger = logging.getLogger(__name__)
+COLLECT_LIMIT_WARNING = 500
 
 
 class DQEngineCore(DQEngineCoreBase):
@@ -142,6 +144,69 @@ def save_checks_in_local_file(checks: list[dict], filepath: str):
             msg = f"Checks file {filepath} missing"
             raise FileNotFoundError(msg) from None
 
+    @staticmethod
+    def build_quality_rules_from_dataframe(df: DataFrame) -> list[dict]:
+        """Build checks from a Spark DataFrame based on check specifications, i.e. function name plus arguments.
+
+        :param df: Spark DataFrame with data quality check rules. Each row should define a check. Rows should
+        have the following columns:
+        * `name` - name that will be given to a resulting column. Autogenerated if not provided
+        * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe),
+        and `warn` (data going into both dataframes)
+        * `check` - struct with the check to apply: the check `function` name and a map of keyword
+        `arguments` passed into the check function (e.g. `col_name`)
+        * `filter` (optional) - Spark SQL expression for filtering the rows the check is applied to
+        :return: list of data quality check specifications as Python dictionaries
+        """
+        num_check_rows = df.count()
+        if num_check_rows > COLLECT_LIMIT_WARNING:
+            warnings.warn(
+                f"Collecting a large number of rows from Spark DataFrame: {num_check_rows}",
+                category=UserWarning,
+                stacklevel=2,
+            )
+        checks = []
+        for row in df.collect():
+            check = {"name": row.name, "criticality": row.criticality, "check": row.check.asDict()}
+            if row.filter is not None:
+                check["filter"] = row.filter
+            checks.append(check)
+        return checks
+
+    @staticmethod
+    def build_dataframe_from_quality_rules(checks: list[dict], spark: SparkSession | None = None) -> DataFrame:
+        """Build a Spark DataFrame from a set of check specifications, i.e. function name plus arguments.
+
+        :param checks: list of check specifications as Python dictionaries. Each check consists of the following fields:
+        * `check` - check specification with the check `function` name and a map of keyword `arguments`
+        passed into the check function
+        * `name` - name that will be given to a resulting column. Autogenerated if not provided
+        * `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe),
+        and `warn` (data going into both dataframes)
+        :param spark: Optional SparkSession to use for DataFrame operations
+        :return: Spark DataFrame with data quality check rules
+        """
+        if spark is None:
+            spark = SparkSession.builder.getOrCreate()
+        schema = "name STRING, criticality STRING, check STRUCT<function STRING, arguments MAP<STRING, STRING>>, filter STRING"
+        dq_rule_checks = DQEngineCore.build_checks_by_metadata(checks)
+        dq_rule_rows = []
+        for dq_rule_check in dq_rule_checks:
+            arguments = dq_rule_check.check_func_kwargs
+            if isinstance(dq_rule_check, DQColSetRule):
+                arguments["col_names"] = dq_rule_check.columns
+            if isinstance(dq_rule_check, DQColRule):
+                arguments["col_name"] = dq_rule_check.col_name
+            dq_rule_rows.append(
+                [
+                    dq_rule_check.name,
+                    dq_rule_check.criticality,
+                    {"function": dq_rule_check.check_func.__name__, "arguments": arguments},
+                    dq_rule_check.filter,
+                ]
+            )
+        return spark.createDataFrame(dq_rule_rows, schema)
+
     @staticmethod
     def build_checks_by_metadata(checks: list[dict], custom_checks: dict[str, Any] | None = None) -> list[DQColRule]:
         """Build checks based on check specification, i.e. function name plus arguments.
@@ -590,6 +655,20 @@ def load_checks_from_installation(
             raise ValueError(f"Invalid or no checks in workspace file: {installation.install_folder()}/{filename}")
         return parsed_checks
 
+    def load_checks_from_table(self, table_name: str, spark: SparkSession | None = None) -> list[dict]:
+        """
+        Load checks (dq rules) from a Delta table in the workspace.
+        :param table_name: Unity Catalog or Hive metastore table name
+        :param spark: Optional SparkSession
+        :return: List of dq rules, or raise an error if the table is missing or invalid.
+ """ + logger.info(f"Loading quality rules (checks) from table {table_name}") + if not self.ws.tables.exists(table_name).table_exists: + raise NotFound(f"Table {table_name} does not exist in the workspace") + if spark is None: + spark = SparkSession.builder.getOrCreate() + return DQEngine._load_checks_from_table(table_name, spark=spark) + @staticmethod def save_checks_in_local_file(checks: list[dict], path: str): return DQEngineCore.save_checks_in_local_file(checks, path) @@ -633,6 +712,17 @@ def save_checks_in_workspace_file(self, checks: list[dict], workspace_path: str) workspace_path, yaml.safe_dump(checks).encode('utf-8'), format=ImportFormat.AUTO, overwrite=True ) + @staticmethod + def save_checks_in_table(checks: list[dict], table_name: str, mode: str = "append"): + """ + Save checks to a Delta table in the workspace. + :param checks: list of dq rules to save + :param table_name: Unity catalog or Hive metastore table name + :param mode: Output mode for writing checks to Delta (e.g. `append` or `overwrite`) + """ + logger.info(f"Saving quality rules (checks) to table {table_name}") + DQEngine._save_checks_in_table(checks, table_name, mode) + def load_run_config( self, run_config_name: str | None = "default", assume_user: bool = True, product_name: str = "dqx" ) -> RunConfig: @@ -670,3 +760,15 @@ def _load_checks_from_file(installation: Installation, filename: str) -> list[di except NotFound: msg = f"Checks file {filename} missing" raise NotFound(msg) from None + + @staticmethod + def _load_checks_from_table(table_name: str, spark: SparkSession | None = None) -> list[dict]: + if spark is None: + spark = SparkSession.builder.getOrCreate() + rules_df = spark.read.table(table_name) + return DQEngineCore.build_quality_rules_from_dataframe(rules_df) + + @staticmethod + def _save_checks_in_table(checks: list[dict], table_name: str, mode: str): + rules_df = DQEngineCore.build_dataframe_from_quality_rules(checks) + rules_df.write.saveAsTable(name=table_name, mode=mode) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a21ddbf..a995bc8 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -80,12 +80,7 @@ def config(self) -> WorkspaceConfig: class MockInstallationContext(MockRuntimeContext): __test__ = False - def __init__( - self, - env_or_skip_fixture, - ws, - check_file, - ): + def __init__(self, env_or_skip_fixture, ws, check_file): super().__init__(env_or_skip_fixture, ws) self.check_file = check_file @@ -170,16 +165,8 @@ def workspace_installation(self) -> WorkspaceInstallation: @pytest.fixture -def installation_ctx( - ws, - env_or_skip, - check_file="checks.yml", -) -> Generator[MockInstallationContext, None, None]: - ctx = MockInstallationContext( - env_or_skip, - ws, - check_file, - ) +def installation_ctx(ws, env_or_skip, check_file="checks.yml") -> Generator[MockInstallationContext, None, None]: + ctx = MockInstallationContext(env_or_skip, ws, check_file) yield ctx.replace(workspace_client=ws) ctx.workspace_installation.uninstall() diff --git a/tests/integration/test_load_checks_from_table.py b/tests/integration/test_load_checks_from_table.py new file mode 100644 index 0000000..da79cd5 --- /dev/null +++ b/tests/integration/test_load_checks_from_table.py @@ -0,0 +1,40 @@ +import pytest +from databricks.labs.dqx.engine import DQEngine +from databricks.sdk.errors import NotFound + + +TEST_CHECKS = [ + { + "name": "column_is_not_null", + "criticality": "error", + "check": {"function": "is_not_null", "arguments": {"col_name": 
"col_1"}}, + }, + { + "name": "column_not_less_than", + "criticality": "warn", + "check": {"function": "is_not_less_than", "arguments": {"col_name": "col_1", "limit": "0"}}, + }, +] + + +def test_load_checks_when_checks_table_does_not_exist(installation_ctx, make_schema, make_random, spark): + client = installation_ctx.workspace_client + catalog_name = "main" + schema_name = make_schema(catalog_name=catalog_name).name + table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}" + + with pytest.raises(NotFound, match=f"Table {table_name} does not exist in the workspace"): + engine = DQEngine(client) + engine.load_checks_from_table(table_name, spark) + + +def test_load_checks_from_table(installation_ctx, make_schema, make_random, spark): + client = installation_ctx.workspace_client + catalog_name = "main" + schema_name = make_schema(catalog_name=catalog_name).name + table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}" + + engine = DQEngine(client) + engine.save_checks_in_table(TEST_CHECKS, table_name) + checks = engine.load_checks_from_table(table_name, spark) + assert checks == TEST_CHECKS, "Checks were not loaded correctly" diff --git a/tests/unit/test_build_rules.py b/tests/unit/test_build_rules.py index 0d1d63f..4f5628e 100644 --- a/tests/unit/test_build_rules.py +++ b/tests/unit/test_build_rules.py @@ -441,3 +441,26 @@ def test_deprecated_warning_dqrule_class(): def test_deprecated_warning_dqrulecolset_class(): with pytest.warns(DeprecationWarning, match="DQRuleColSet is deprecated and will be removed in a future version"): DQRuleColSet(criticality="error", check_func=is_not_null, columns=["col1"]) + + +def test_build_quality_rules_from_dataframe(spark_local): + test_checks = [ + { + "name": "column_is_not_null", + "criticality": "error", + "check": {"function": "is_not_null", "arguments": {"col_name": "test_col"}}, + }, + { + "name": "column_is_not_null_or_empty", + "criticality": "warn", + "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "test_col"}}, + }, + { + "name": "column_not_less_than", + "criticality": "warn", + "check": {"function": "is_not_less_than", "arguments": {"col_name": "test_col", "limit": "5"}}, + }, + ] + df = DQEngineCore.build_dataframe_from_quality_rules(test_checks, spark=spark_local) + checks = DQEngineCore.build_quality_rules_from_dataframe(df) + assert checks == test_checks, "The loaded checks do not match the expected checks."