
Load and save checks from a Delta table #339


Open · wants to merge 16 commits into main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -413,7 +413,7 @@ max-args = 10
max-positional-arguments=10

# Maximum number of attributes for a class (see R0902).
max-attributes = 15
max-attributes = 16

# Maximum number of boolean expressions in an if statement (see R0916).
max-bool-expr = 5
1 change: 1 addition & 0 deletions src/databricks/labs/dqx/config.py
@@ -17,6 +17,7 @@ class RunConfig:
output_table: str | None = None # output data table
quarantine_table: str | None = None # quarantined data table
checks_file: str | None = "checks.yml" # file containing quality rules / checks
checks_table: str | None = None # table containing quality rules / checks
profile_summary_stats_file: str | None = "profile_summary_stats.yml" # file containing profile summary statistics
override_clusters: dict[str, str] | None = None # cluster configuration for jobs
spark_conf: dict[str, str] | None = None # extra spark configs
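
The new checks_table field lets a run config point at a Delta table of quality rules instead of the default checks.yml file. A minimal sketch, assuming RunConfig can be instantiated directly with keyword arguments as shown in the dataclass above (the table names are illustrative):

from databricks.labs.dqx.config import RunConfig

# Illustrative run config: quality rules are read from a Delta table
# rather than from the checks_file default.
run_config = RunConfig(
    output_table="main.dqx.output",
    quarantine_table="main.dqx.quarantine",
    checks_table="main.dqx.checks",  # new field introduced in this PR
)
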
104 changes: 103 additions & 1 deletion src/databricks/labs/dqx/engine.py
@@ -3,12 +3,13 @@
import functools as ft
import inspect
import itertools
import warnings
from pathlib import Path
from collections.abc import Callable
from typing import Any
import yaml
import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession

from databricks.labs.blueprint.installation import Installation
from databricks.labs.dqx import row_checks
@@ -30,6 +31,7 @@
from databricks.sdk import WorkspaceClient

logger = logging.getLogger(__name__)
COLLECT_LIMIT_WARNING = 500


class DQEngineCore(DQEngineCoreBase):
@@ -142,6 +144,69 @@ def save_checks_in_local_file(checks: list[dict], filepath: str):
msg = f"Checks file {filepath} missing"
raise FileNotFoundError(msg) from None

@staticmethod
def build_quality_rules_from_dataframe(df: DataFrame) -> list[dict]:
"""Build checks from a Spark DataFrame based on check specifications, i.e. function name plus arguments.

:param df: Spark DataFrame with data quality check rules. Each row should define a check. Rows should
have the following columns:
* `name` - name that will be given to a resulting column. Autogenerated if not provided
* `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe),
and `warn` (data is going into both dataframes)
* `check_function` - DQX check function used in the check

Copilot AI commented on May 21, 2025:

The docstring refers to a check_function column, but the actual DataFrame schema uses a check column. Update the description to match the implementation.

Suggested change:
- * `check_function` - DQX check function used in the check
+ * `check` - DQX check function used in the check

* `arguments` - Map of keyword arguments passed into the check function (e.g. `col_name`)
* `filter` - Expression for filtering data quality checks
:return: list of data quality check specifications as a Python dictionary
"""
num_check_rows = df.count()

Copilot AI commented on May 21, 2025:

Calling df.count() before df.collect() triggers two separate Spark jobs. Consider collecting rows once (e.g., rows = df.collect()) and using len(rows) to avoid the extra scan.

Suggested change:
- num_check_rows = df.count()
+ rows = df.collect()
+ num_check_rows = len(rows)

if num_check_rows > COLLECT_LIMIT_WARNING:
warnings.warn(
f"Collecting large number of rows from Spark DataFrame: {num_check_rows}",
category=UserWarning,
stacklevel=2,
)
checks = []
for row in df.collect():
check = {"name": row.name, "criticality": row.criticality, "check": row.check.asDict()}
if row.filter is not None:
check["filter"] = row.filter
checks.append(check)
return checks

@staticmethod
def build_dataframe_from_quality_rules(checks: list[dict], spark: SparkSession | None = None) -> DataFrame:
"""Build a Spark DataFrame from a set of check specifications, i.e. function name plus arguments.

:param checks: list of check specifications as Python dictionaries. Each check consists of the following fields:
* `check` - Column expression to evaluate. This expression should return a string value if it evaluates to true -
it will be used as an error/warning message, or `null` if it evaluates to `false`
* `name` - name that will be given to a resulting column. Autogenerated if not provided
* `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe),
and `warn` (data is going into both dataframes)
:param spark: Optional SparkSession to use for DataFrame operations
:return: Spark DataFrame with data quality check rules
"""
if spark is None:
spark = SparkSession.builder.getOrCreate()
schema = "name STRING, criticality STRING, check STRUCT<function STRING, arguments MAP<STRING, STRING>>, filter STRING"
dq_rule_checks = DQEngineCore.build_checks_by_metadata(checks)
dq_rule_rows = []
for dq_rule_check in dq_rule_checks:
arguments = dq_rule_check.check_func_kwargs
if isinstance(dq_rule_check, DQColSetRule):
arguments["col_names"] = dq_rule_check.columns
if isinstance(dq_rule_check, DQColRule):
arguments["col_name"] = dq_rule_check.col_name
dq_rule_rows.append(
[
dq_rule_check.name,
dq_rule_check.criticality,
{"function": dq_rule_check.check_func.__name__, "arguments": arguments},
dq_rule_check.filter,
]
)
return spark.createDataFrame(dq_rule_rows, schema)
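
Together with build_quality_rules_from_dataframe above, this gives a round trip between the metadata (dict) form and the tabular form. A minimal sketch, assuming an active SparkSession named spark and the built-in is_not_null check; the rule itself is illustrative:

from databricks.labs.dqx.engine import DQEngineCore

checks = [
    {
        "name": "col_1_is_not_null",
        "criticality": "error",
        "check": {"function": "is_not_null", "arguments": {"col_name": "col_1"}},
    },
]
# Serialize the metadata checks into a DataFrame with the schema above ...
df = DQEngineCore.build_dataframe_from_quality_rules(checks, spark=spark)
# ... and rebuild the original metadata from that DataFrame.
assert DQEngineCore.build_quality_rules_from_dataframe(df) == checks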

@staticmethod
def build_checks_by_metadata(checks: list[dict], custom_checks: dict[str, Any] | None = None) -> list[DQColRule]:
"""Build checks based on check specification, i.e. function name plus arguments.
@@ -590,6 +655,20 @@ def load_checks_from_installation(
raise ValueError(f"Invalid or no checks in workspace file: {installation.install_folder()}/{filename}")
return parsed_checks

def load_checks_from_table(self, table_name: str, spark: SparkSession | None = None) -> list[dict]:
"""
Load checks (dq rules) from a Delta table in the workspace.
:param table_name: Unity catalog or Hive metastore table name
:param spark: Optional SparkSession
:return: List of dq rules, or raise an error if the checks table is missing or invalid.
"""
logger.info(f"Loading quality rules (checks) from table {table_name}")
if not self.ws.tables.exists(table_name).table_exists:
raise NotFound(f"Table {table_name} does not exist in the workspace")
if spark is None:
spark = SparkSession.builder.getOrCreate()
return DQEngine._load_checks_from_table(table_name, spark=spark)
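
A minimal usage sketch for loading table-backed checks, assuming a workspace client, an input DataFrame named input_df to validate, and the apply_checks_by_metadata_and_split method available elsewhere in DQEngine; the table name is illustrative:

from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.engine import DQEngine

ws = WorkspaceClient()
engine = DQEngine(ws)

# The table must already contain rules, e.g. written by save_checks_in_table below.
checks = engine.load_checks_from_table("main.dqx.checks")

# Split the input data into valid and quarantined rows based on the loaded rules.
good_df, bad_df = engine.apply_checks_by_metadata_and_split(input_df, checks)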

@staticmethod
def save_checks_in_local_file(checks: list[dict], path: str):
return DQEngineCore.save_checks_in_local_file(checks, path)
@@ -633,6 +712,17 @@ def save_checks_in_workspace_file(self, checks: list[dict], workspace_path: str)
workspace_path, yaml.safe_dump(checks).encode('utf-8'), format=ImportFormat.AUTO, overwrite=True
)

@staticmethod
def save_checks_in_table(checks: list[dict], table_name: str, mode: str = "append"):
"""
Save checks to a Delta table in the workspace.
:param checks: list of dq rules to save
:param table_name: Unity catalog or Hive metastore table name
:param mode: Output mode for writing checks to Delta (e.g. `append` or `overwrite`)
"""
logger.info(f"Saving quality rules (checks) to table {table_name}")
DQEngine._save_checks_in_table(checks, table_name, mode)
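
A short sketch of persisting metadata checks, assuming the same check format as above; mode follows Spark writer semantics (append by default, overwrite to replace existing rules), and the table name is illustrative:

from databricks.labs.dqx.engine import DQEngine

checks = [
    {
        "name": "col_1_is_not_null",
        "criticality": "error",
        "check": {"function": "is_not_null", "arguments": {"col_name": "col_1"}},
    },
]
# Replace any previously stored rules in the target table.
DQEngine.save_checks_in_table(checks, "main.dqx.checks", mode="overwrite")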

def load_run_config(
self, run_config_name: str | None = "default", assume_user: bool = True, product_name: str = "dqx"
) -> RunConfig:
@@ -670,3 +760,15 @@ def _load_checks_from_file(installation: Installation, filename: str) -> list[dict]:
except NotFound:
msg = f"Checks file {filename} missing"
raise NotFound(msg) from None

@staticmethod
def _load_checks_from_table(table_name: str, spark: SparkSession | None = None) -> list[dict]:
if spark is None:
spark = SparkSession.builder.getOrCreate()
rules_df = spark.read.table(table_name)
return DQEngineCore.build_quality_rules_from_dataframe(rules_df)

Copilot AI commented on lines +766 to +769, May 21, 2025:

[nitpick] This static method duplicates functionality in DQEngineCore. Consider consolidating with the core implementation to reduce code duplication.

Suggested change:
- if spark is None:
-     spark = SparkSession.builder.getOrCreate()
- rules_df = spark.read.table(table_name)
- return DQEngineCore.build_quality_rules_from_dataframe(rules_df)
+ """
+ Load checks from a Delta table in the workspace.
+ :param table_name: Unity catalog or Hive metastore table name
+ :param spark: Optional SparkSession instance
+ :return: List of quality rules (checks)
+ """
+ return DQEngineCore.load_checks_from_table(table_name, spark)


@staticmethod
def _save_checks_in_table(checks: list[dict], table_name: str, mode: str):
rules_df = DQEngineCore.build_dataframe_from_quality_rules(checks)
rules_df.write.saveAsTable(name=table_name, mode=mode)
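
Because the saved table follows the schema defined in build_dataframe_from_quality_rules, stored rules can also be inspected directly with Spark (table name illustrative, spark assumed to be an active session):

rules_df = spark.read.table("main.dqx.checks")
rules_df.select("name", "criticality", "check.function", "check.arguments", "filter").show(truncate=False)
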
19 changes: 3 additions & 16 deletions tests/integration/conftest.py
@@ -80,12 +80,7 @@ def config(self) -> WorkspaceConfig:
class MockInstallationContext(MockRuntimeContext):
__test__ = False

def __init__(
self,
env_or_skip_fixture,
ws,
check_file,
):
def __init__(self, env_or_skip_fixture, ws, check_file):
super().__init__(env_or_skip_fixture, ws)
self.check_file = check_file

@@ -170,16 +165,8 @@ def workspace_installation(self) -> WorkspaceInstallation:


@pytest.fixture
def installation_ctx(
ws,
env_or_skip,
check_file="checks.yml",
) -> Generator[MockInstallationContext, None, None]:
ctx = MockInstallationContext(
env_or_skip,
ws,
check_file,
)
def installation_ctx(ws, env_or_skip, check_file="checks.yml") -> Generator[MockInstallationContext, None, None]:
ctx = MockInstallationContext(env_or_skip, ws, check_file)
yield ctx.replace(workspace_client=ws)
ctx.workspace_installation.uninstall()

40 changes: 40 additions & 0 deletions tests/integration/test_load_checks_from_table.py
@@ -0,0 +1,40 @@
import pytest
from databricks.labs.dqx.engine import DQEngine
from databricks.sdk.errors import NotFound


TEST_CHECKS = [
{
"name": "column_is_not_null",
"criticality": "error",
"check": {"function": "is_not_null", "arguments": {"col_name": "col_1"}},
},
{
"name": "column_not_less_than",
"criticality": "warn",
"check": {"function": "is_not_less_than", "arguments": {"col_name": "col_1", "limit": "0"}},
},
]


def test_load_checks_when_checks_table_does_not_exist(installation_ctx, make_schema, make_random, spark):
client = installation_ctx.workspace_client
catalog_name = "main"
schema_name = make_schema(catalog_name=catalog_name).name
table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"

with pytest.raises(NotFound, match=f"Table {table_name} does not exist in the workspace"):
engine = DQEngine(client)
engine.load_checks_from_table(table_name, spark)


def test_load_checks_from_table(installation_ctx, make_schema, make_random, spark):
client = installation_ctx.workspace_client
catalog_name = "main"
schema_name = make_schema(catalog_name=catalog_name).name
table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"

engine = DQEngine(client)
engine.save_checks_in_table(TEST_CHECKS, table_name)
checks = engine.load_checks_from_table(table_name, spark)
assert checks == TEST_CHECKS, "Checks were not loaded correctly"
23 changes: 23 additions & 0 deletions tests/unit/test_build_rules.py
@@ -441,3 +441,26 @@ def test_deprecated_warning_dqrule_class():
def test_deprecated_warning_dqrulecolset_class():
with pytest.warns(DeprecationWarning, match="DQRuleColSet is deprecated and will be removed in a future version"):
DQRuleColSet(criticality="error", check_func=is_not_null, columns=["col1"])


def test_build_quality_rules_from_dataframe(spark_local):
test_checks = [
{
"name": "column_is_not_null",
"criticality": "error",
"check": {"function": "is_not_null", "arguments": {"col_name": "test_col"}},
},
{
"name": "column_is_not_null_or_empty",
"criticality": "warn",
"check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "test_col"}},
},
{
"name": "column_not_less_than",
"criticality": "warn",
"check": {"function": "is_not_less_than", "arguments": {"col_name": "test_col", "limit": "5"}},
},
]
df = DQEngineCore.build_dataframe_from_quality_rules(test_checks, spark=spark_local)
checks = DQEngineCore.build_quality_rules_from_dataframe(df)
assert checks == test_checks, "The loaded checks do not match the expected checks."