-
Notifications
You must be signed in to change notification settings - Fork 35
Load and save checks from a Delta table #339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a1c768e
2a8397e
a9b8e42
1cc2a91
81a9af0
a4ae4bd
c85e575
20f5e42
0e2cb04
c842bf5
0882b19
3ca046f
60a2b6e
88bb6e3
b3c9392
3bcfc93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,12 +3,13 @@ | |||||||||||||||||||||||
import functools as ft | ||||||||||||||||||||||||
import inspect | ||||||||||||||||||||||||
import itertools | ||||||||||||||||||||||||
import warnings | ||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||
from collections.abc import Callable | ||||||||||||||||||||||||
from typing import Any | ||||||||||||||||||||||||
import yaml | ||||||||||||||||||||||||
import pyspark.sql.functions as F | ||||||||||||||||||||||||
from pyspark.sql import DataFrame | ||||||||||||||||||||||||
from pyspark.sql import DataFrame, SparkSession | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from databricks.labs.blueprint.installation import Installation | ||||||||||||||||||||||||
from databricks.labs.dqx import row_checks | ||||||||||||||||||||||||
|
@@ -30,6 +31,7 @@ | |||||||||||||||||||||||
from databricks.sdk import WorkspaceClient | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||||||||||||||
COLLECT_LIMIT_WARNING = 500 | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class DQEngineCore(DQEngineCoreBase): | ||||||||||||||||||||||||
|
@@ -142,6 +144,69 @@ def save_checks_in_local_file(checks: list[dict], filepath: str): | |||||||||||||||||||||||
msg = f"Checks file {filepath} missing" | ||||||||||||||||||||||||
raise FileNotFoundError(msg) from None | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def build_quality_rules_from_dataframe(df: DataFrame) -> list[dict]: | ||||||||||||||||||||||||
"""Build checks from a Spark DataFrame based on check specifications, i.e. function name plus arguments. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
:param df: Spark DataFrame with data quality check rules. Each row should define a check. Rows should | ||||||||||||||||||||||||
have the following columns: | ||||||||||||||||||||||||
* `name` - name that will be given to a resulting column. Autogenerated if not provided | ||||||||||||||||||||||||
* `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), | ||||||||||||||||||||||||
and `warn` (data is going into both dataframes) | ||||||||||||||||||||||||
* `check_function` - DQX check function used in the check | ||||||||||||||||||||||||
* `arguments` - Map of keyword arguments passed into the check function (e.g. `col_name`) | ||||||||||||||||||||||||
* `filter` - Expression for filtering data quality checks | ||||||||||||||||||||||||
:return: list of data quality check specifications as a Python dictionary | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
num_check_rows = df.count() | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Calling df.count() before df.collect() triggers two separate Spark jobs. Consider collecting rows once (e.g., rows = df.collect()) and using len(rows) to avoid the extra scan.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||
if num_check_rows > COLLECT_LIMIT_WARNING: | ||||||||||||||||||||||||
warnings.warn( | ||||||||||||||||||||||||
f"Collecting large number of rows from Spark DataFrame: {num_check_rows}", | ||||||||||||||||||||||||
category=UserWarning, | ||||||||||||||||||||||||
stacklevel=2, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
checks = [] | ||||||||||||||||||||||||
for row in df.collect(): | ||||||||||||||||||||||||
check = {"name": row.name, "criticality": row.criticality, "check": row.check.asDict()} | ||||||||||||||||||||||||
if row.filter is not None: | ||||||||||||||||||||||||
check["filter"] = row.filter | ||||||||||||||||||||||||
checks.append(check) | ||||||||||||||||||||||||
return checks | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def build_dataframe_from_quality_rules(checks: list[dict], spark: SparkSession | None = None) -> DataFrame: | ||||||||||||||||||||||||
"""Build a Spark DataFrame from a set of check specifications, i.e. function name plus arguments. | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
:param checks: list of check specifications as Python dictionaries. Each check consists of the following fields: | ||||||||||||||||||||||||
* `check` - Column expression to evaluate. This expression should return string value if it's evaluated to true - | ||||||||||||||||||||||||
it will be used as an error/warning message, or `null` if it's evaluated to `false` | ||||||||||||||||||||||||
* `name` - name that will be given to a resulting column. Autogenerated if not provided | ||||||||||||||||||||||||
* `criticality` (optional) - possible values are `error` (data going only into "bad" dataframe), | ||||||||||||||||||||||||
and `warn` (data is going into both dataframes) | ||||||||||||||||||||||||
:param spark: Optional SparkSession to use for DataFrame operations | ||||||||||||||||||||||||
:return: Spark DataFrame with data quality check rules | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
if spark is None: | ||||||||||||||||||||||||
spark = SparkSession.builder.getOrCreate() | ||||||||||||||||||||||||
schema = "name STRING, criticality STRING, check STRUCT<function STRING, arguments MAP<STRING, STRING>>, filter STRING" | ||||||||||||||||||||||||
dq_rule_checks = DQEngineCore.build_checks_by_metadata(checks) | ||||||||||||||||||||||||
dq_rule_rows = [] | ||||||||||||||||||||||||
for dq_rule_check in dq_rule_checks: | ||||||||||||||||||||||||
arguments = dq_rule_check.check_func_kwargs | ||||||||||||||||||||||||
if isinstance(dq_rule_check, DQColSetRule): | ||||||||||||||||||||||||
arguments["col_names"] = dq_rule_check.columns | ||||||||||||||||||||||||
if isinstance(dq_rule_check, DQColRule): | ||||||||||||||||||||||||
arguments["col_name"] = dq_rule_check.col_name | ||||||||||||||||||||||||
dq_rule_rows.append( | ||||||||||||||||||||||||
[ | ||||||||||||||||||||||||
dq_rule_check.name, | ||||||||||||||||||||||||
dq_rule_check.criticality, | ||||||||||||||||||||||||
{"function": dq_rule_check.check_func.__name__, "arguments": arguments}, | ||||||||||||||||||||||||
dq_rule_check.filter, | ||||||||||||||||||||||||
] | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
return spark.createDataFrame(dq_rule_rows, schema) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def build_checks_by_metadata(checks: list[dict], custom_checks: dict[str, Any] | None = None) -> list[DQColRule]: | ||||||||||||||||||||||||
"""Build checks based on check specification, i.e. function name plus arguments. | ||||||||||||||||||||||||
|
@@ -590,6 +655,20 @@ def load_checks_from_installation( | |||||||||||||||||||||||
raise ValueError(f"Invalid or no checks in workspace file: {installation.install_folder()}/{filename}") | ||||||||||||||||||||||||
return parsed_checks | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def load_checks_from_table(self, table_name: str, spark: SparkSession | None = None) -> list[dict]: | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
Load checks (dq rules) from a Delta table in the workspace. | ||||||||||||||||||||||||
:param table_name: Unity catalog or Hive metastore table name | ||||||||||||||||||||||||
:param spark: Optional SparkSession | ||||||||||||||||||||||||
:return: List of dq rules or raise an error if checks file is missing or is invalid. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
logger.info(f"Loading quality rules (checks) from table {table_name}") | ||||||||||||||||||||||||
if not self.ws.tables.exists(table_name).table_exists: | ||||||||||||||||||||||||
raise NotFound(f"Table {table_name} does not exist in the workspace") | ||||||||||||||||||||||||
if spark is None: | ||||||||||||||||||||||||
spark = SparkSession.builder.getOrCreate() | ||||||||||||||||||||||||
return DQEngine._load_checks_from_table(table_name, spark=spark) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def save_checks_in_local_file(checks: list[dict], path: str): | ||||||||||||||||||||||||
return DQEngineCore.save_checks_in_local_file(checks, path) | ||||||||||||||||||||||||
|
@@ -633,6 +712,17 @@ def save_checks_in_workspace_file(self, checks: list[dict], workspace_path: str) | |||||||||||||||||||||||
workspace_path, yaml.safe_dump(checks).encode('utf-8'), format=ImportFormat.AUTO, overwrite=True | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def save_checks_in_table(checks: list[dict], table_name: str, mode: str = "append"): | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
Save checks to a Delta table in the workspace. | ||||||||||||||||||||||||
:param checks: list of dq rules to save | ||||||||||||||||||||||||
:param table_name: Unity catalog or Hive metastore table name | ||||||||||||||||||||||||
:param mode: Output mode for writing checks to Delta (e.g. `append` or `overwrite`) | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
logger.info(f"Saving quality rules (checks) to table {table_name}") | ||||||||||||||||||||||||
DQEngine._save_checks_in_table(checks, table_name, mode) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def load_run_config( | ||||||||||||||||||||||||
self, run_config_name: str | None = "default", assume_user: bool = True, product_name: str = "dqx" | ||||||||||||||||||||||||
) -> RunConfig: | ||||||||||||||||||||||||
|
@@ -670,3 +760,15 @@ def _load_checks_from_file(installation: Installation, filename: str) -> list[di | |||||||||||||||||||||||
except NotFound: | ||||||||||||||||||||||||
msg = f"Checks file {filename} missing" | ||||||||||||||||||||||||
raise NotFound(msg) from None | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def _load_checks_from_table(table_name: str, spark: SparkSession | None = None) -> list[dict]: | ||||||||||||||||||||||||
if spark is None: | ||||||||||||||||||||||||
spark = SparkSession.builder.getOrCreate() | ||||||||||||||||||||||||
rules_df = spark.read.table(table_name) | ||||||||||||||||||||||||
return DQEngineCore.build_quality_rules_from_dataframe(rules_df) | ||||||||||||||||||||||||
Comment on lines
+766
to
+769
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] This static method duplicates functionality in
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||
|
||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||
def _save_checks_in_table(checks: list[dict], table_name: str, mode: str): | ||||||||||||||||||||||||
rules_df = DQEngineCore.build_dataframe_from_quality_rules(checks) | ||||||||||||||||||||||||
rules_df.write.saveAsTable(name=table_name, mode=mode) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import pytest | ||
from databricks.labs.dqx.engine import DQEngine | ||
from databricks.sdk.errors import NotFound | ||
|
||
|
||
TEST_CHECKS = [ | ||
{ | ||
"name": "column_is_not_null", | ||
"criticality": "error", | ||
"check": {"function": "is_not_null", "arguments": {"col_name": "col_1"}}, | ||
}, | ||
{ | ||
"name": "column_not_less_than", | ||
"criticality": "warn", | ||
"check": {"function": "is_not_less_than", "arguments": {"col_name": "col_1", "limit": "0"}}, | ||
}, | ||
] | ||
|
||
|
||
def test_load_checks_when_checks_table_does_not_exist(installation_ctx, make_schema, make_random, spark): | ||
client = installation_ctx.workspace_client | ||
catalog_name = "main" | ||
schema_name = make_schema(catalog_name=catalog_name).name | ||
table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}" | ||
|
||
with pytest.raises(NotFound, match=f"Table {table_name} does not exist in the workspace"): | ||
engine = DQEngine(client) | ||
engine.load_checks_from_table(table_name, spark) | ||
|
||
|
||
def test_load_checks_from_table(installation_ctx, make_schema, make_random, spark): | ||
client = installation_ctx.workspace_client | ||
ghanse marked this conversation as resolved.
Show resolved
Hide resolved
|
||
catalog_name = "main" | ||
schema_name = make_schema(catalog_name=catalog_name).name | ||
table_name = f"{catalog_name}.{schema_name}.{make_random(6).lower()}" | ||
|
||
engine = DQEngine(client) | ||
engine.save_checks_in_table(TEST_CHECKS, table_name) | ||
checks = engine.load_checks_from_table(table_name, spark) | ||
assert checks == TEST_CHECKS, "Checks were not loaded correctly" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring refers to a
check_function
column, but the actual DataFrame schema uses acheck
column. Update the description to match the implementation.Copilot uses AI. Check for mistakes.