diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml
index 490c3d3..04d4e5e 100644
--- a/.github/workflows/scala.yml
+++ b/.github/workflows/scala.yml
@@ -8,9 +8,7 @@ on:
 jobs:
   build:
-
     runs-on: ubuntu-latest
-
     steps:
     - uses: actions/checkout@v2
     - name: Set up JDK 8
@@ -23,6 +21,7 @@ jobs:
     - name: Coverage Report
       run: sbt coverageReport
     - name: "Upload coverage to Codecov"
-      uses: "codecov/codecov-action@v2"
+      uses: "codecov/codecov-action@v3"
       with:
         fail_ci_if_error: true
+
diff --git a/.gitignore b/.gitignore
index 7f7ecc2..51a0286 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,10 @@
 /.idea/
 *.iml
-#local spark context data from unit tests
+
+# Local spark context data from unit tests
 spark-warehouse/
-#Build dirctory for maven/sbt
+# Build directory for maven/sbt
 target/
 project/project/
 project/target/
@@ -11,3 +12,8 @@ project/target/
 /target/
 /project/build.properties
 /src/main/scala/com/databricks/labs/validation/LocalTest.scala
+
+# Python Wrapper
+/python/.idea/
+/python/dist/
+
diff --git a/README.md b/README.md
index 9a04923..dc5e8d1 100644
--- a/README.md
+++ b/README.md
@@ -328,6 +328,12 @@ evaluation specs and results
 The summary report is meant to be just that, a summary of the failed rules. This will return only the records
 that failed and only the rules that failed for that record; thus, if the `summaryReport.isEmpty` then all rules passed.
+
+## Python Wrapper
+The Python Wrapper allows users to validate data quality of their PySpark DataFrames using Python.
+
+The Python Wrapper can be found under the `/python` directory. A quickstart notebook is also located under `/python/examples`.
+
 ## Next Steps
 Clearly, this is just a start. This is a small package and, as such, a GREAT place to start if you've never
 contributed to a project before. Please feel free to fork the repo and/or submit PRs. We'd love to see what
diff --git a/python/README.md b/python/README.md
new file mode 100644
index 0000000..296d1c6
--- /dev/null
+++ b/python/README.md
@@ -0,0 +1,55 @@
+# Python Connector for the DataFrame Rules Engine
+The Python Connector allows users to validate data quality of their PySpark DataFrames using Python.
+
+```python
+myRuleSet = RuleSet(df)
+myRuleSet.add(myRules)
+validation_results = myRuleSet.validate()
+```
+
+Currently, the Python Connector supports the following Rule types:
+1. List of Values (Strings _only_)
+2. Boolean Check
+3. User-defined Functions (must evaluate to a Boolean)
+
+
+### Boolean Check
+Validate that a column expression evaluates to True.
+```python
+# Ensure that the temperature is a valid reading
+valid_temp_rule = Rule("valid_temperature", F.col("temperature") > -100.0)
+```
+
+### List of Values (LOVs)
+Validate that a Column only contains values present in a List of Strings.
+
+```python
+# Create a List of Strings (LOS)
+building_sites = ["SiteA", "SiteB", "SiteC"]
+
+# Build a Rule that validates that a column only contains values from the LOS
+building_name_rule = Rule("Building_LOV_Rule",
+                          column=F.col("site_name"),
+                          valid_strings=building_sites)
+```
+
+### User-Defined Functions (UDFs)
+UDFs are great when you need to add custom business logic for validating dataset quality.
+Any user-defined function that returns a Boolean column expression can be used with the DataFrame Rules Engine, as in the example below.
+ +```python +# Create a UDF to validate date entry +def valid_date_udf(ts_column): + return ts_column.isNotNull() & F.year(ts_column).isNotNull() & F.month(ts_column).isNotNull() + +# Create a Rule that uses the UDF to validate data +valid_date_rule = Rule("valid_date_reading", valid_date_udf(F.col("reading_date"))) +``` + +## Building the project + +A Python `.whl` file can be generated by navigating to the `/python` directory and executing the following command : + +```bash +$ python3 -m build +``` diff --git a/python/examples/01_generate_sample_purchase_transactions.py b/python/examples/01_generate_sample_purchase_transactions.py new file mode 100644 index 0000000..76797d8 --- /dev/null +++ b/python/examples/01_generate_sample_purchase_transactions.py @@ -0,0 +1,60 @@ +# Databricks notebook source +catalog_name = "REPLACE_ME" +schema_name = "REPLACE_ME" + +# COMMAND ---------- + +import random +import datetime + +def generate_sample_data(): + """Generates mock transaction data that randomly adds bad data""" + + # randomly generate bad data + if bool(random.getrandbits(1)): + appl_id = None + acct_no = None + event_ts = None + cstone_last_updatetm = None + else: + appl_id = random.randint(1000000, 9999999) + acct_no = random.randint(1000000000000000, 9999999999999999) + event_ts = datetime.datetime.now() + cstone_last_updatetm = datetime.datetime.now() + + # randomly generate an MCC description + categories = ["dining", "transportation", "merchendise", "hotels", "airfare", "grocery stores/supermarkets/bakeries"] + random_index = random.randint(0, len(categories)-1) + category = categories[random_index] + + # randomly generate a transaction price + price = round(random.uniform(1.99, 9999.99), 2) + + data = [ + (appl_id, acct_no, event_ts, category, price, cstone_last_updatetm) + ] + df = spark.createDataFrame(data, + "appl_id int, acct_no long, event_ts timestamp, category string, price float, cstone_last_updatetm timestamp") + return df + +# COMMAND ---------- + +spark.sql(f"create schema if not exists {catalog_name}.{schema_name}") + +# COMMAND ---------- + +spark.sql(f""" +CREATE TABLE IF NOT EXISTS {catalog_name}.{schema_name}.purchase_transactions_bronze +(appl_id int, acct_no long, event_ts timestamp, category string, price float, cstone_last_updatetm timestamp) +USING DELTA +TBLPROPERTIES (delta.enableChangeDataFeed = true) +""") + +# COMMAND ---------- + +df = generate_sample_data() +df.write.insertInto(f"{catalog_name}.{schema_name}.purchase_transactions_bronze") + +# COMMAND ---------- + + diff --git a/python/examples/02_apply_purchase_transaction_rules.py b/python/examples/02_apply_purchase_transaction_rules.py new file mode 100644 index 0000000..a101794 --- /dev/null +++ b/python/examples/02_apply_purchase_transaction_rules.py @@ -0,0 +1,152 @@ +# Databricks notebook source +# MAGIC %run ./PythonWrapper + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Ingest new Data + +# COMMAND ---------- + +import datetime + +starting_time = datetime.datetime.now() - datetime.timedelta(minutes=5) + +catalog_name = "REPLACE_ME" +schema_name = "REPLACE_ME" + +# COMMAND ---------- + +# Read table changes from 5 mins ago +df = spark.read.format("delta") \ + .option("readChangeFeed", "true") \ + .option("startingTimestamp", starting_time) \ + .table(f"{catalog_name}.{schema_name}.purchase_transactions_bronze") +purchase_transactions_df = df.select("appl_id", "acct_no", "event_ts", "category", "price", "cstone_last_updatetm")\ + .where("_change_type='insert'") 
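+# Preview the newly inserted records before applying the data quality rules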
+purchase_transactions_df.display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Define Rules using Builder Pattern + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Sample Rules +# MAGIC +# MAGIC From a DQ rule point of view, we would be looking at following scenarios: +# MAGIC +# MAGIC - **event_ts**: Should have a timestamp for every day (timestamp format doesn’t matter) +# MAGIC - **cstone_last_updatetm**: Should have a timestamp for every day +# MAGIC - **acct_no**: No null values for this column +# MAGIC - **appl_id**: No null values for this column +# MAGIC - **Changes in string length** - for all columns +# MAGIC + +# COMMAND ---------- + +import pyspark.sql.functions as F + +# First, begin by defining your RuleSet by passing in your input DataFrame +myRuleSet = RuleSet(purchase_transactions_df) + +# Rule 1 - define a Rule that validates that the `acct_no` is never null +acct_num_rule = Rule("valid_acct_no_rule", F.col("acct_no").isNotNull()) +myRuleSet.add(acct_num_rule) + +# Rule 2 - add a Rule that validates that the `appl_id` is never null +appl_id_rule = Rule("valid_appl_id", F.col("appl_id").isNotNull()) +myRuleSet.add(appl_id_rule) + +# COMMAND ---------- + +# Rules can even be used in conjunction with User-Defined Functions +def valid_timestamp(ts_column): + return ts_column.isNotNull() & F.year(ts_column).isNotNull() & F.month(ts_column).isNotNull() + +# COMMAND ---------- + +# Rule 3 - enforce a valid `event_ts` timestamp +valid_event_ts_rule = Rule("valid_event_ts_rule", valid_timestamp(F.col("event_ts"))) +myRuleSet.add(valid_event_ts_rule) + +# Rule 4 - enforce a valid `cstone_last_updatetm` timestamp +valid_cstone_last_updatetm_rule = Rule("valid_cstone_last_updatetm_rule", valid_timestamp(F.col("cstone_last_updatetm"))) +myRuleSet.add(valid_cstone_last_updatetm_rule) + +# COMMAND ---------- + +# Rule 5 - validate string lengths +import pyspark.sql.functions as F +import datetime + +starting_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=5) +ending_timestamp = datetime.datetime.now() - datetime.timedelta(minutes=1) + +# Read table changes from 5 mins ago +df = spark.read.format("delta") \ + .option("readChangeFeed", "true") \ + .option("startingVersion", 0) \ + .option("endingVersion", 10) \ + .table(f"{catalog_name}.{schema_name}.purchase_transactions_bronze") +df_category = df.select("category").where("_change_type='insert'").agg(F.mean(F.length(F.col("category"))).alias("avg_category_len")) +avg_category_len = df_category.collect()[0]['avg_category_len'] +print(avg_category_len) + +# COMMAND ---------- + +def valid_category_len(category_column, avg_category_str_len): + return F.length(category_column) <= avg_category_str_len + +# Rule 5 - validate `category` string lengths +valid_str_length_rule = Rule("valid_category_str_length_rule", valid_category_len(F.col("category"), avg_category_len)) +myRuleSet.add(valid_str_length_rule) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Validate Rows + +# COMMAND ---------- + +from pyspark.sql import DataFrame + +# Finally, add the Rule to the RuleSet and validate! 
+summaryReport = myRuleSet.get_summary_report() +completeReport = myRuleSet.get_complete_report() + +# Display the summary validation report +display(summaryReport) + +# COMMAND ---------- + +# Display the complete validation report +display(completeReport) + +# COMMAND ---------- + +spark.sql(f""" + CREATE TABLE IF NOT EXISTS {catalog_name}.{schema_name}.purchase_transactions_validated + (appl_id int, acct_no long, event_ts timestamp, category string, price float, cstone_last_updatetm timestamp, failed_rules array) + USING DELTA + TBLPROPERTIES (delta.enableChangeDataFeed = true) +""") + +# COMMAND ---------- + +import pyspark.sql.functions as F +import pyspark.sql.types as T + +if summaryReport.count() > 0: + summaryReport.write.insertInto(f"{catalog_name}.{schema_name}.purchase_transactions_validated") +else: + string_array_type = T.ArrayType(T.StringType()) + purchase_transactions_df \ + .withColumn("failed_rules", F.array(F.array().cast(string_array_type))) \ + .write.insertInto(f"{catalog_name}.{schema_name}.purchase_transactions_validated") + +# COMMAND ---------- + + diff --git a/python/examples/PythonWrapper.py b/python/examples/PythonWrapper.py new file mode 100644 index 0000000..1bb96e1 --- /dev/null +++ b/python/examples/PythonWrapper.py @@ -0,0 +1,212 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Spark Singleton + +# COMMAND ---------- + +import pyspark +from pyspark.sql import SparkSession, DataFrame +from typing import List + + +class SparkSingleton: + """A singleton class which returns one Spark instance""" + __instance = None + + @classmethod + def get_instance(cls): + """Create a Spark instance. + :return: A Spark instance + """ + return (SparkSession.builder + .appName("DataFrame Rules Engine") + .getOrCreate()) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Rule Types + +# COMMAND ---------- + +class RuleType: + + ValidateExpr = "expr" + ValidateBounds = "bounds" + ValidateNumerics = "validNumerics" + ValidateStrings = "validStrings" + ValidateDateTime = "validDateTime" + ValidateComplex = "complex" + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Structures class + +# COMMAND ---------- + +class Bounds: + + def __init__(self, lower, upper, + lowerInclusive = False, + upperInclusive = False): + self.lower = lower + self.upper = upper + self.lowerInclusive = lowerInclusive + self.upperInclusive = upperInclusive + self._spark = SparkSingleton.get_instance() + self._jBounds = self._spark._jvm.com.databricks.labs.validation.utils.Structures.Bounds(lower, upper, lowerInclusive, upperInclusive) + + def validationLogic(self, col): + jCol = col._jc + return self._spark._jvm.com.databricks.labs.validation.utils.Structures.Bounds.validationLogic(jCol) + + +class MinMaxRuleDef: + + def __init__(self, + rule_name: str, + column: pyspark.sql.Column, + bounds: Bounds, + by: List[pyspark.sql.Column] = None): + self.rule_name = rule_name + self.column = column + self.bounds = bounds + self.by = by + + +class ValidationResults: + + def __init__(self, + complete_report: pyspark.sql.DataFrame, + summary_report: pyspark.sql.DataFrame): + self.complete_report = complete_report + self.summary_report = summary_report + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # Rule class + +# COMMAND ---------- + +class Rule: + """ + Definition of a rule + """ + def __init__(self, + rule_name: str, + column: pyspark.sql.Column, + boundaries: Bounds = None, + valid_expr: pyspark.sql.Column = None, + valid_strings: List[str] = None, + valid_numerics = None, + ignore_case: bool = False, + 
invert_match: bool = False): + + self._spark = SparkSingleton.get_instance() + self._column = column + self._boundaries = boundaries + self._valid_expr = valid_expr + self._valid_strings = valid_strings + self._valid_numerics = valid_numerics + self._is_implicit_bool = False + + # Determine the Rule type by parsing the input arguments + if valid_strings is not None and len(valid_strings) > 0: + j_valid_strings = Helpers.to_java_array(valid_strings, self._spark._sc) + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, column._jc, + j_valid_strings, + ignore_case, invert_match) + self._rule_type = RuleType.ValidateStrings + + elif valid_numerics is not None and len(valid_numerics) > 0: + j_valid_numerics = Helpers.to_java_array(valid_numerics, self._spark._sc) + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, + column._jc, + j_valid_numerics) + self._rule_type = RuleType.ValidateNumerics + else: + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, column._jc) + self._is_implicit_bool = True + self._rule_type = RuleType.ValidateExpr + + def to_string(self): + return self._jRule.toString() + + def boundaries(self): + return self._boundaries + + def valid_numerics(self): + return self._valid_numerics + + def valid_strings(self): + return self._valid_strings + + def valid_expr(self): + return self._valid_expr + + def is_implicit_bool(self): + return self._jRule.implicitBoolean + + def ignore_case(self): + return self._jRule.ignoreCase + + def invert_match(self): + return self._jRule.invertMatch + + def rule_name(self): + return self._jRule.ruleName + + def is_agg(self): + return self._jRule.isAgg + + def input_column_name(self): + return self._jRule.inputColumnName + + def rule_type(self): + return self._rule_type + + def to_java(self): + return self._jRule + + +# COMMAND ---------- + +# MAGIC %md +# MAGIC # RuleSet class + +# COMMAND ---------- + +class RuleSet(): + + def __init__(self, df): + self.spark = SparkSingleton.get_instance() + self._df = df + self._jdf = df._jdf + self._jRuleSet = self.spark._jvm.com.databricks.labs.validation.RuleSet.apply(self._jdf) + + def add(self, rule): + self._jRuleSet.add(rule.to_java()) + + def get_df(self): + return self._df + + def to_java(self): + return self._jRuleSet + + def validate(self): + validation_results = self._jRuleSet.validate(1) + return validation_results + + def get_complete_report(self): + jCompleteReport = self._jRuleSet.getCompleteReport() + return DataFrame(jCompleteReport, self.spark._sc) + + def get_summary_report(self): + jSummaryReport = self._jRuleSet.getSummaryReport() + return DataFrame(jSummaryReport, self.spark._sc) + +# COMMAND ---------- + + diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 0000000..31df80d --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,31 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/databricks"] + +[project] +name = "dataframe-rules-engine" +version = "0.0.1" +description = "An extensible Rules Engine for custom Apache Spark Dataframe / Dataset validation." 
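+# NOTE: the wheel is built from the /python directory with `python3 -m build` (see python/README.md).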
+authors = [ + { name="Daniel Tomes", email="daniel.tomes@databricks.com" }, + { name="Will Girten", email="will.girten@databricks.com" }, +] +keywords = ["Spark", "Rules", "Validation"] +readme = "README.md" +requires-python = ">=3.9" +dependencies = [ + "pyspark" +] +classifiers = [ + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "License :: Other/Proprietary License", + "Operating System :: OS Independent", +] + +[project.urls] +Homepage = "https://github.com/databrickslabs/dataframe-rules-engine" +Issues = "https://github.com/databrickslabs/dataframe-rules-engine/issues" diff --git a/python/src/databricks/__init__.py b/python/src/databricks/__init__.py new file mode 100644 index 0000000..a2d7d14 --- /dev/null +++ b/python/src/databricks/__init__.py @@ -0,0 +1,3 @@ +from .labs.validation.rule import Rule, RuleType +from .labs.validation.rule_set import RuleSet +from .labs.validation.structures import * diff --git a/python/src/databricks/labs/__init__.py b/python/src/databricks/labs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/src/databricks/labs/validation/__init__.py b/python/src/databricks/labs/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/src/databricks/labs/validation/local_spark_singleton.py b/python/src/databricks/labs/validation/local_spark_singleton.py new file mode 100644 index 0000000..1cb38c9 --- /dev/null +++ b/python/src/databricks/labs/validation/local_spark_singleton.py @@ -0,0 +1,8 @@ +from pyspark.sql import SparkSession + + +class SparkSingleton: + + @classmethod + def get_instance(cls): + return SparkSession.builder.getOrCreate() diff --git a/python/src/databricks/labs/validation/rule.py b/python/src/databricks/labs/validation/rule.py new file mode 100644 index 0000000..ce8e22d --- /dev/null +++ b/python/src/databricks/labs/validation/rule.py @@ -0,0 +1,91 @@ +import pyspark +from typing import List + +from databricks.labs.validation.local_spark_singleton import SparkSingleton +from databricks.labs.validation.rule_type import RuleType +from databricks.labs.validation.structures import Bounds +from databricks.labs.validation.utils.helpers import Helpers + + +class Rule: + """ + Definition of a rule + """ + # TODO: Fix type hint for valid_numerics + def __init__(self, + rule_name: str, + column: pyspark.sql.Column, + boundaries: Bounds = None, + valid_expr: pyspark.sql.Column = None, + valid_strings: List[str] = None, + valid_numerics = None, + ignore_case: bool = False, + invert_match: bool = False): + + self._spark = SparkSingleton.get_instance() + self._column = column + self._boundaries = boundaries + self._valid_expr = valid_expr + self._valid_strings = valid_strings + self._valid_numerics = valid_numerics + self._is_implicit_bool = False + self._rule_name = rule_name + + # Determine the Rule type by parsing the input arguments + if valid_strings is not None and len(valid_strings) > 0: + j_valid_strings = Helpers.to_java_array(valid_strings, self._spark._sc) + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, column._jc, + j_valid_strings, + ignore_case, invert_match) + self._rule_type = RuleType.ValidateStrings + + elif valid_numerics is not None and len(valid_numerics) > 0: + j_valid_numerics = Helpers.to_java_array(valid_numerics, self._spark._sc) + print(j_valid_numerics) + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, + column._jc, + j_valid_numerics) + self._rule_type = 
RuleType.ValidateNumerics + else: + self._jRule = self._spark._jvm.com.databricks.labs.validation.Rule.apply(rule_name, column._jc) + self._is_implicit_bool = True + self._rule_type = RuleType.ValidateExpr + + def to_string(self): + return self._jRule.toString() + + def boundaries(self): + return self._boundaries + + def valid_numerics(self): + return self._valid_numerics + + def valid_strings(self): + return self._valid_strings + + def valid_expr(self): + return self._valid_expr + + def is_implicit_bool(self): + return self._jRule.implicitBoolean + + def ignore_case(self): + return bool(self._jRule.ignoreCase) + + def invert_match(self): + return bool(self._jRule.invertMatch) + + def rule_name(self): + return self._rule_name + + def is_agg(self): + return self._jRule.isAgg + + def input_column_name(self): + return self._jRule.inputColumnName.toString() + + def rule_type(self): + return self._rule_type + + def to_java(self): + return self._jRule diff --git a/python/src/databricks/labs/validation/rule_set.py b/python/src/databricks/labs/validation/rule_set.py new file mode 100644 index 0000000..e5c7f0c --- /dev/null +++ b/python/src/databricks/labs/validation/rule_set.py @@ -0,0 +1,32 @@ +from pyspark.sql import DataFrame + +from databricks.labs.validation.local_spark_singleton import SparkSingleton +from databricks.labs.validation.structures import ValidationResults, MinMaxRuleDef + + +class RuleSet(): + + def __init__(self, df): + self.spark = SparkSingleton.get_instance() + self._df = df + self._jdf = df._jdf + self._jRuleSet = self.spark._jvm.com.databricks.labs.validation.RuleSet.apply(self._jdf) + + def add(self, rule): + self._jRuleSet.add(rule.to_java()) + + def addMinMaxRule(self, minMaxRuleDef): + pass + + def get_df(self): + return self._df + + def to_java(self): + return self._jRuleSet + + def validate(self): + jValidationResults = self._jRuleSet.validate(1) + complete_report = DataFrame(jValidationResults.completeReport(), self.spark) + summary_report = DataFrame(jValidationResults.summaryReport(), self.spark) + validation_results = ValidationResults(complete_report, summary_report) + return validation_results diff --git a/python/src/databricks/labs/validation/rule_type.py b/python/src/databricks/labs/validation/rule_type.py new file mode 100644 index 0000000..7642c12 --- /dev/null +++ b/python/src/databricks/labs/validation/rule_type.py @@ -0,0 +1,9 @@ + +class RuleType: + + ValidateExpr = "expr" + ValidateBounds = "bounds" + ValidateNumerics = "validNumerics" + ValidateStrings = "validStrings" + ValidateDateTime = "validDateTime" + ValidateComplex = "complex" diff --git a/python/src/databricks/labs/validation/structures.py b/python/src/databricks/labs/validation/structures.py new file mode 100644 index 0000000..191208e --- /dev/null +++ b/python/src/databricks/labs/validation/structures.py @@ -0,0 +1,51 @@ +import pyspark +from typing import List + +from databricks.labs.validation.local_spark_singleton import SparkSingleton + + +class Bounds: + + def __init__(self, lower, upper, + lowerInclusive=False, + upperInclusive=False): + self.lower = lower + self.upper = upper + self.lowerInclusive = lowerInclusive + self.upperInclusive = upperInclusive + self._spark = SparkSingleton.get_instance() + self._jBounds = self._spark._jvm.com.databricks.labs.validation.utils.Structures.Bounds(lower, upper, + lowerInclusive, + upperInclusive) + + def validationLogic(self, col): + jCol = col._jc + return 
self._spark._jvm.com.databricks.labs.validation.utils.Structures.Bounds.validationLogic(jCol) + + +class MinMaxRuleDef: + + def __init__(self, + rule_name: str, + column: pyspark.sql.Column, + bounds: Bounds, + by: List[pyspark.sql.Column] = None): + self.rule_name = rule_name + self.column = column + self.bounds = bounds + self.by = by + + +class ValidationResults: + + def __init__(self, + complete_report: pyspark.sql.DataFrame, + summary_report: pyspark.sql.DataFrame): + self.complete_report = complete_report + self.summary_report = summary_report + + def get_complete_report(self): + return self.complete_report + + def get_summary_report(self): + return self.summary_report diff --git a/python/src/databricks/labs/validation/utils/__init__.py b/python/src/databricks/labs/validation/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/src/databricks/labs/validation/utils/helpers.py b/python/src/databricks/labs/validation/utils/helpers.py new file mode 100644 index 0000000..c7fc646 --- /dev/null +++ b/python/src/databricks/labs/validation/utils/helpers.py @@ -0,0 +1,12 @@ +class Helpers: + + @staticmethod + def to_java_array(py_array, sc): + if isinstance(py_array[0], str): + java_string_class = sc._jvm.java.lang.String + java_array = sc._gateway.new_array(java_string_class, len(py_array)) + for i in range(len(py_array)): + java_array[i] = py_array[i] + else: + raise Exception("Only List of Strings is currently supported.") + return java_array diff --git a/python/tests/local_spark_singleton.py b/python/tests/local_spark_singleton.py new file mode 100644 index 0000000..4864612 --- /dev/null +++ b/python/tests/local_spark_singleton.py @@ -0,0 +1,34 @@ +from pyspark.sql import SparkSession +from pyspark import SparkConf +import os + + +class SparkSingleton: + """A singleton class which returns one Spark instance""" + __instance = None + + @classmethod + def get_instance(cls): + """Create a Spark instance. 
+ :return: A Spark instance + """ + config = SparkConf().setAll([("spark.driver.extraClassPath", + os.environ["RULES_ENGINE_JAR"])]) + spark = (SparkSession.builder + .config(conf=config) + .appName("DataFrame Rules Engine") + .getOrCreate()) + spark.sparkContext.setLogLevel("ERROR") + return spark + + @classmethod + def get_local_instance(cls): + config = SparkConf().setAll([("spark.driver.extraClassPath", + os.environ["RULES_ENGINE_JAR"])]) + spark = (SparkSession.builder + .config(conf=config) + .master("local[*]") + .appName("DataFrame Rules Engine") + .getOrCreate()) + spark.sparkContext.setLogLevel("ERROR") + return spark diff --git a/python/tests/test_rule.py b/python/tests/test_rule.py new file mode 100644 index 0000000..e63fb0b --- /dev/null +++ b/python/tests/test_rule.py @@ -0,0 +1,30 @@ +import unittest +import pyspark.sql +import pyspark.sql.functions as F + +from src.databricks.labs.validation.rule import Rule +from src.databricks.labs.validation.rule_type import RuleType +from tests.local_spark_singleton import SparkSingleton + + +class TestRule(unittest.TestCase): + + def setUp(self): + self.spark = SparkSingleton.get_instance() + + + def test_string_lov_rule(self): + """Tests that a list of String values rule can be instantiated correctly.""" + + # Ensure that a rule with a list of valid strings can be validated + building_sites = ["SiteA", "SiteB", "SiteC"] + building_name_rule = Rule("Building_LOV_Rule", column=F.col("site_name"), + valid_strings=building_sites) + + # Ensure that all attributes are set correctly for Integers + assert building_name_rule.rule_name() == "Building_LOV_Rule", "Rule name is not set as expected." + assert building_name_rule.rule_type() == RuleType.ValidateStrings, "The rule type is not set as expected." + assert not building_name_rule.ignore_case() + + def tearDown(self): + self.spark.stop() diff --git a/python/tests/test_rule_set.py b/python/tests/test_rule_set.py new file mode 100644 index 0000000..0b328cc --- /dev/null +++ b/python/tests/test_rule_set.py @@ -0,0 +1,145 @@ +import unittest + +from src.databricks import RuleSet, Rule +from tests.local_spark_singleton import SparkSingleton + +import pyspark.sql.functions as F + + +def valid_date_udf(ts_column): + return ts_column.isNotNull() & F.year(ts_column).isNotNull() & F.month(ts_column).isNotNull() + + +class TestRuleSet(unittest.TestCase): + + def setUp(self): + self.spark = SparkSingleton.get_instance() + + def test_create_ruleset_from_dataframe(self): + test_data = [ + (1.0, 2.0, 3.0), + (4.0, 5.0, 6.0), + (7.0, 8.0, 9.0) + ] + test_df = self.spark.createDataFrame(test_data, schema="retail_price float, scan_price float, cost float") + test_rule_set = RuleSet(test_df) + + # Ensure that the RuleSet DataFrame is set properly + assert test_rule_set.get_df().exceptAll(test_df).count() == 0 + + def test_list_of_strings(self): + iot_readings = [ + (1001, "zone_a", 50.1), + (1002, "zone_b", 25.4), + (1003, "zone_c", None) + ] + valid_zones = ["zone_a", "zone_b", "zone_c", "zone_d"] + df = self.spark.createDataFrame(iot_readings).toDF("device_id", "zone_id", "temperature") + rule_set = RuleSet(df) + + # Add a list of strings + valid_zones_rule = Rule("valid_zones", F.col("zone_id"), valid_strings=valid_zones) + rule_set.add(valid_zones_rule) + + # Ensure that the summary report contains no failed rules + validation_summary = rule_set.validate().get_summary_report() + assert validation_summary.where(F.col("failed_rules").isNotNull()).count() == 0 + + # Add a row that _should_ fail + 
new_iot_reading = [ + (1004, "zone_z", 30.1) + ] + new_reading_df = self.spark.createDataFrame(new_iot_reading).toDF("device_id", "zone_id", "temperature") + combined_df = df.union(new_reading_df) + new_rule_set = RuleSet(combined_df) + new_rule_set.add(valid_zones_rule) + new_validation_summary = new_rule_set.validate().get_summary_report() + + # Ensure that the added reading should fail due to an invalid zone id string + assert new_validation_summary.where(F.col("failed_rules").isNotNull()).count() == 1 + + def test_list_of_numerics(self): + iot_readings = [ + (1001, "zone_a", 50.1), + (1002, "zone_b", 25.4), + (1003, "zone_c", None) + ] + valid_device_ids = [1001, 1002, 1003, 1004, 1005] + df = self.spark.createDataFrame(iot_readings).toDF("device_id", "zone_id", "temperature") + rule_set = RuleSet(df) + + # Add a list of numerical values + valid_device_ids_rule = Rule("valid_device_id", F.col("device_id"), valid_numerics=valid_device_ids) + rule_set.add(valid_device_ids_rule) + + # Ensure that the summary report contains no failed rules + validation_summary = rule_set.validate().get_summary_report() + assert validation_summary.where(F.col("failed_rules").isNotNull()).count() == 0 + + def test_boolean_rules(self): + iot_readings = [ + (1001, "zone_a", 50.1), + (1002, "zone_b", 25.4), + (1003, "zone_c", None) + ] + df = self.spark.createDataFrame(iot_readings).toDF("device_id", "zone_id", "temperature") + rule_set = RuleSet(df) + + # Add a rule that `device_id` is not null + not_null_rule = Rule("valid_device_id", F.col("device_id").isNotNull()) + rule_set.add(not_null_rule) + + # Add a rule that `temperature` is > -100.0 degrees + valid_temp_rule = Rule("valid_temp", F.col("temperature") > -100.0) + rule_set.add(valid_temp_rule) + + validation_summary = rule_set.validate().get_summary_report() + assert validation_summary.where(F.col("failed_rules").isNotNull()).count() == 0 + + def test_udf_rules(self): + iot_readings = [ + (1001, "zone_a", 50.1, "2024-04-25"), + (1002, "zone_b", 25.4, "2024-04-24"), + (1003, "zone_c", None, "2024-04-24") + ] + df = self.spark.createDataFrame(iot_readings).toDF("device_id", "zone_id", "temperature", "reading_date_str") + df = df.withColumn("reading_date", F.col("reading_date_str").cast("date")).drop("reading_date_str") + rule_set = RuleSet(df) + + # Ensure that UDFs can be used to validate data quality + valid_reading_date_rule = Rule("valid_reading_date", valid_date_udf(F.col("reading_date"))) + rule_set.add(valid_reading_date_rule) + + validation_summary = rule_set.validate().get_summary_report() + assert validation_summary.where(F.col("failed_rules").isNotNull()).count() == 0 + + def test_add_rules(self): + iot_readings = [ + (1001, "zone_a", 50.1), + (1002, "zone_b", 25.4), + (1003, "zone_c", None) + ] + df = self.spark.createDataFrame(iot_readings).toDF("device_id", "zone_id", "temperature") + rule_set = RuleSet(df) + + # Test boolean rule + temp_rule = Rule("valid_temp", F.col("temperature").isNotNull()) + rule_set.add(temp_rule) + + # Ensure that the RuleSet DF can be set/gotten correctly + rule_set_df = rule_set.get_df() + assert rule_set_df.count() == 3 + assert "device_id" in rule_set_df.columns + assert "zone_id" in rule_set_df.columns + assert "temperature" in rule_set_df.columns + + # Add a list of strings + valid_zones_rule = Rule("valid_zones", F.col("zone_id"), valid_strings=["zone_a", "zone_b", "zone_c"]) + rule_set.add(valid_zones_rule) + + # Ensure that the summary report contains failed rules + validation_summary = 
rule_set.validate().get_summary_report() + assert validation_summary.where(F.col("failed_rules").isNotNull()).count() == 1 + + def tearDown(self): + self.spark.stop() diff --git a/python/tests/test_structures.py b/python/tests/test_structures.py new file mode 100644 index 0000000..6879ff9 --- /dev/null +++ b/python/tests/test_structures.py @@ -0,0 +1,39 @@ +import unittest + +from src.databricks.labs.validation.structures import MinMaxRuleDef, Bounds +from tests.local_spark_singleton import SparkSingleton + +import pyspark.sql.functions as F + + +class TestStructures(unittest.TestCase): + + def setUp(self): + self.spark = SparkSingleton.get_instance() + + def test_get_returns(self): + + # Test Bounds + sku_price_bounds = Bounds(1.0, 1000.0) + assert sku_price_bounds.lower == 1.0 + assert sku_price_bounds.upper == 1000.0 + assert not sku_price_bounds.lowerInclusive + assert not sku_price_bounds.upperInclusive + sku_price_bounds_inclusive = Bounds(1.0, 1000.0, lowerInclusive=True, upperInclusive=True) + assert sku_price_bounds_inclusive.lowerInclusive + assert sku_price_bounds_inclusive.upperInclusive + + # Test MinMax Definitions + min_max_no_agg = MinMaxRuleDef("valid_sku_prices", F.col("sku_price"), bounds=sku_price_bounds) + assert min_max_no_agg.rule_name == "valid_sku_prices", "Invalid rule name for MinMax definition." + assert min_max_no_agg.bounds.lower == 1.0 + assert min_max_no_agg.bounds.upper == 1000.0 + + min_max_w_agg = MinMaxRuleDef("valid_sku_prices_agg", F.col("sku_price"), bounds=sku_price_bounds, + by=[F.col("store_id"), F.col("product_id")]) + assert min_max_w_agg.rule_name == "valid_sku_prices_agg", "Invalid rule name for MinMax definition!" + assert min_max_w_agg.bounds.lower == 1.0 + assert min_max_w_agg.bounds.upper == 1000.0 + + def tearDown(self): + self.spark.stop()
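+
+
+# Optional: allow this test module to be run directly with `python`, assuming the
+# RULES_ENGINE_JAR environment variable points to the built rules-engine jar
+# (see tests/local_spark_singleton.py).
+if __name__ == "__main__":
+    unittest.main()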