From e725a9a1485fde96744981974335486e56858cef Mon Sep 17 00:00:00 2001 From: ilicmarkodb Date: Tue, 16 Sep 2025 17:53:07 +0200 Subject: [PATCH 1/2] temp --- .../src/main/java/io/delta/kernel/Scan.java | 2 +- .../java/io/delta/kernel/Transaction.java | 4 +- .../delta/kernel/expressions/Predicate.java | 46 +- .../io/delta/kernel/internal/ScanImpl.java | 8 +- .../kernel/internal/actions/AddFile.java | 2 +- .../GenerateIcebergCompatActionUtils.java | 2 +- .../skipping/DataSkippingPredicate.java | 59 +- .../internal/skipping/DataSkippingUtils.java | 111 +++- .../internal/skipping/PartitionPredicate.java | 22 + .../{util => skipping}/PartitionUtils.java | 25 +- .../internal/skipping/StatsSchemaHelper.java | 102 +++- .../kernel/internal/util/ExpressionUtils.java | 20 + .../kernel/internal/util/SchemaUtils.java | 2 +- .../io/delta/kernel/utils/PartitionUtils.java | 2 +- .../skipping/DataSkippingUtilsSuite.scala | 515 ++++++++++++++++++ .../PartitionUtilsSuite.scala | 4 +- .../skipping/StatsSchemaHelperSuite.scala | 432 +++++++++++++++ .../util/DataSkippingUtilsSuite.scala | 162 ------ .../io/delta/kernel/test/TestUtils.scala | 56 ++ .../DefaultExpressionEvaluator.java | 222 +++++++- .../expressions/ExpressionVisitor.java | 12 +- .../io/delta/kernel/defaults/ScanSuite.scala | 343 ++++++++++++ 22 files changed, 1889 insertions(+), 264 deletions(-) create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionPredicate.java rename kernel/kernel-api/src/main/java/io/delta/kernel/internal/{util => skipping}/PartitionUtils.java (96%) create mode 100644 kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala rename kernel/kernel-api/src/test/scala/io/delta/kernel/internal/{util => skipping}/PartitionUtilsSuite.scala (99%) create mode 100644 kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala delete mode 100644 
kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala create mode 100644 kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java index fcbc7b773ae..643ee502704 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java @@ -29,8 +29,8 @@ import io.delta.kernel.internal.deletionvectors.DeletionVectorUtils; import io.delta.kernel.internal.deletionvectors.RoaringBitmapArray; import io.delta.kernel.internal.rowtracking.MaterializedRowTrackingColumn; +import io.delta.kernel.internal.skipping.PartitionUtils; import io.delta.kernel.internal.util.ColumnMapping.ColumnMappingMode; -import io.delta.kernel.internal.util.PartitionUtils; import io.delta.kernel.internal.util.Tuple2; import io.delta.kernel.types.MetadataColumnSpec; import io.delta.kernel.types.StructField; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java index 5c70786f81c..8093a1f0c7a 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java @@ -19,9 +19,9 @@ import static io.delta.kernel.internal.DeltaErrors.partitionColumnMissingInData; import static io.delta.kernel.internal.TransactionImpl.getStatisticsColumns; import static io.delta.kernel.internal.data.TransactionStateRow.*; +import static io.delta.kernel.internal.skipping.PartitionUtils.getTargetDirectory; +import static io.delta.kernel.internal.skipping.PartitionUtils.validateAndSanitizePartitionValues; import static io.delta.kernel.internal.util.ColumnMapping.blockIfColumnMappingEnabled; -import static io.delta.kernel.internal.util.PartitionUtils.getTargetDirectory; -import static 
io.delta.kernel.internal.util.PartitionUtils.validateAndSanitizePartitionValues; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import static io.delta.kernel.internal.util.SchemaUtils.findColIndex; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java index 6c273ec8fd0..2a64d697124 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/expressions/Predicate.java @@ -119,6 +119,7 @@ public class Predicate extends ScalarExpression { public Predicate(String name, List children) { super(name, children); + checkArguments(name, children); collationIdentifier = Optional.empty(); } @@ -184,8 +185,51 @@ public boolean equals(Object o) { return this.hashCode() == o.hashCode(); } + private void checkArguments(String name, List children) { + boolean isPredicateSupported = + CONSTANT_OPERATORS.contains(this.name) + || UNARY_OPERATORS.contains(this.name) + || BINARY_OPERATORS.contains(this.name); + if (!isPredicateSupported) { + throw new IllegalArgumentException( + String.format( + "Predicate operator '%s' is not supported. Supported operators are %s, %s and %s.", + name, CONSTANT_OPERATORS, UNARY_OPERATORS, BINARY_OPERATORS)); + } + + int expectedNumberOfChildren = + CONSTANT_OPERATORS.contains(this.name) + ? 0 + : UNARY_OPERATORS.contains(this.name) + ? 1 + : BINARY_OPERATORS.contains(this.name) ? 
2 : -1; + if (children.size() != expectedNumberOfChildren) { + throw new IllegalArgumentException( + String.format( + "Invalid Predicate: operator '%s' requires %s children, but found %d.", + this.name, expectedNumberOfChildren, children.size())); + } + } + + private static final Set CONSTANT_OPERATORS = + Stream.of("ALWAYS_TRUE", "ALWAYS_FALSE").collect(Collectors.toSet()); + + private static final Set UNARY_OPERATORS = + Stream.of("NOT", "IS_NULL", "IS_NOT_NULL").collect(Collectors.toSet()); + private static final Set BINARY_OPERATORS = - Stream.of("<", "<=", ">", ">=", "=", "AND", "OR", "IS NOT DISTINCT FROM", "STARTS_WITH") + Stream.of( + "<", + "<=", + ">", + ">=", + "=", + "<>", + "AND", + "OR", + "IS NOT DISTINCT FROM", + "STARTS_WITH", + "LIKE") .collect(Collectors.toSet()); /** Operators that support collation-based string comparison. */ diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 0251cb28d5e..c7acb31eb3f 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -16,9 +16,9 @@ package io.delta.kernel.internal; import static io.delta.kernel.internal.DeltaErrors.wrapEngineException; +import static io.delta.kernel.internal.skipping.PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema; +import static io.delta.kernel.internal.skipping.PartitionUtils.rewritePartitionPredicateOnScanFileSchema; import static io.delta.kernel.internal.skipping.StatsSchemaHelper.getStatsSchema; -import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnCheckpointFileSchema; -import static io.delta.kernel.internal.util.PartitionUtils.rewritePartitionPredicateOnScanFileSchema; import static java.util.function.Function.identity; import static java.util.stream.Collectors.toMap; @@ -39,6 +39,7 @@ import 
io.delta.kernel.internal.rowtracking.RowTracking; import io.delta.kernel.internal.skipping.DataSkippingPredicate; import io.delta.kernel.internal.skipping.DataSkippingUtils; +import io.delta.kernel.internal.skipping.PartitionUtils; import io.delta.kernel.internal.util.*; import io.delta.kernel.metrics.ScanReport; import io.delta.kernel.metrics.SnapshotReport; @@ -366,7 +367,8 @@ private CloseableIterator applyDataSkipping( // pruning it after is much simpler StructType prunedStatsSchema = DataSkippingUtils.pruneStatsSchema( - getStatsSchema(metadata.getDataSchema()), dataSkippingFilter.getReferencedCols()); + getStatsSchema(metadata.getDataSchema(), dataSkippingFilter.getReferencedCollations()), + dataSkippingFilter.getReferencedCols()); // Skipping happens in two steps: // 1. The predicate produces false for any file whose stats prove we can safely skip it. A diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java index 1e20a508636..38ef34fb7c2 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/AddFile.java @@ -15,8 +15,8 @@ */ package io.delta.kernel.internal.actions; +import static io.delta.kernel.internal.skipping.PartitionUtils.serializePartitionMap; import static io.delta.kernel.internal.util.InternalUtils.relativizePath; -import static io.delta.kernel.internal.util.PartitionUtils.serializePartitionMap; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import static io.delta.kernel.internal.util.VectorUtils.toJavaMap; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/GenerateIcebergCompatActionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/GenerateIcebergCompatActionUtils.java index 01cd87291ec..437e5439607 100644 --- 
a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/GenerateIcebergCompatActionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/actions/GenerateIcebergCompatActionUtils.java @@ -15,8 +15,8 @@ */ package io.delta.kernel.internal.actions; +import static io.delta.kernel.internal.skipping.PartitionUtils.serializePartitionMap; import static io.delta.kernel.internal.util.InternalUtils.relativizePath; -import static io.delta.kernel.internal.util.PartitionUtils.serializePartitionMap; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import static io.delta.kernel.internal.util.Preconditions.checkState; import static java.util.Objects.requireNonNull; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java index b5af6d60d64..e4013508e42 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java @@ -18,6 +18,7 @@ import io.delta.kernel.expressions.Column; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.CollationIdentifier; import java.util.*; /** A {@link Predicate} with a set of columns referenced by the expression. */ @@ -26,6 +27,8 @@ public class DataSkippingPredicate extends Predicate { /** Set of {@link Column}s referenced by the predicate or any of its child expressions */ private final Set referencedCols; + private final Set collationIdentifiers; + /** * @param name the predicate name * @param children list of expressions that are input to this predicate. 
@@ -35,6 +38,17 @@ public class DataSkippingPredicate extends Predicate { DataSkippingPredicate(String name, List children, Set referencedCols) { super(name, children); this.referencedCols = Collections.unmodifiableSet(referencedCols); + this.collationIdentifiers = Collections.unmodifiableSet(new HashSet<>()); + } + + DataSkippingPredicate( + String name, + List children, + CollationIdentifier collationIdentifier, + Set referencedCols) { + super(name, children, collationIdentifier); + this.referencedCols = Collections.unmodifiableSet(referencedCols); + this.collationIdentifiers = Collections.singleton(collationIdentifier); } /** @@ -46,18 +60,45 @@ public class DataSkippingPredicate extends Predicate { * @param right right input to this predicate */ DataSkippingPredicate(String name, DataSkippingPredicate left, DataSkippingPredicate right) { - this( - name, - Arrays.asList(left, right), - new HashSet() { - { - addAll(left.getReferencedCols()); - addAll(right.getReferencedCols()); - } - }); + super(name, Arrays.asList(left, right)); + this.referencedCols = + Collections.unmodifiableSet( + new HashSet() { + { + addAll(left.getReferencedCols()); + addAll(right.getReferencedCols()); + } + }); + this.collationIdentifiers = + Collections.unmodifiableSet( + new HashSet() { + { + addAll(left.getReferencedCollations()); + addAll(right.getReferencedCollations()); + } + }); } public Set getReferencedCols() { return referencedCols; } + + public Set getReferencedCollations() { + return collationIdentifiers; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof DataSkippingPredicate)) return false; + if (!super.equals(o)) return false; + DataSkippingPredicate that = (DataSkippingPredicate) o; + return Objects.equals(referencedCols, that.referencedCols) + && Objects.equals(collationIdentifiers, that.collationIdentifiers); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), referencedCols, 
collationIdentifiers); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java index e127dc57c5b..dd49d094f7a 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java @@ -27,6 +27,7 @@ import io.delta.kernel.engine.Engine; import io.delta.kernel.expressions.*; import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.types.CollationIdentifier; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; import java.util.*; @@ -263,10 +264,15 @@ private static Optional constructDataSkippingFilter( if (left instanceof Column && right instanceof Literal) { Column leftCol = (Column) left; Literal rightLit = (Literal) right; - if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) + if (schemaHelper.isSkippingEligibleMinMaxColumn( + leftCol, dataFilters.getCollationIdentifier().isPresent()) && schemaHelper.isSkippingEligibleLiteral(rightLit)) { return constructComparatorDataSkippingFilters( - dataFilters.getName(), leftCol, rightLit, schemaHelper); + dataFilters.getName(), + leftCol, + rightLit, + schemaHelper, + dataFilters.getCollationIdentifier()); } } else if (right instanceof Column && left instanceof Literal) { return constructDataSkippingFilter(reverseComparatorFilter(dataFilters), schemaHelper); @@ -284,7 +290,11 @@ private static Optional constructDataSkippingFilter( /** Construct the skipping predicate for a given comparator */ private static Optional constructComparatorDataSkippingFilters( - String comparator, Column leftCol, Literal rightLit, StatsSchemaHelper schemaHelper) { + String comparator, + Column leftCol, + Literal rightLit, + StatsSchemaHelper schemaHelper, + Optional collationIdentifier) { switch (comparator.toUpperCase(Locale.ROOT)) { 
@@ -295,35 +305,54 @@ private static Optional constructComparatorDataSkippingFi new DataSkippingPredicate( "AND", constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit), + "<=", + schemaHelper.getMinColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier), constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit))); + ">=", + schemaHelper.getMaxColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier))); // Match any file whose min is less than the requested upper bound. case "<": return Optional.of( constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftCol), rightLit)); + "<", + schemaHelper.getMinColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier)); // Match any file whose min is less than or equal to the requested upper bound case "<=": return Optional.of( constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit)); + "<=", + schemaHelper.getMinColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier)); // Match any file whose max is larger than the requested lower bound. case ">": return Optional.of( constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftCol), rightLit)); + ">", + schemaHelper.getMaxColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier)); // Match any file whose max is larger than or equal to the requested lower bound. 
case ">=": return Optional.of( constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit)); + ">=", + schemaHelper.getMaxColumn(leftCol, collationIdentifier), + rightLit, + collationIdentifier)); case "IS NOT DISTINCT FROM": - return constructDataSkippingFilter(rewriteEqualNullSafe(leftCol, rightLit), schemaHelper); + return constructDataSkippingFilter( + rewriteEqualNullSafe(leftCol, rightLit, collationIdentifier), schemaHelper); default: throw new IllegalArgumentException( String.format("Unsupported comparator expression %s", comparator)); @@ -336,11 +365,22 @@ private static Optional constructComparatorDataSkippingFi * Literal}. */ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( - String exprName, Tuple2> colExpr, Literal lit) { + String exprName, + Tuple2> colExpr, + Literal lit, + Optional collationIdentifier) { Column column = colExpr._1; Expression adjColExpr = colExpr._2.isPresent() ? colExpr._2.get() : column; - return new DataSkippingPredicate( - exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column)); + if (collationIdentifier.isPresent()) { + return new DataSkippingPredicate( + exprName, + Arrays.asList(adjColExpr, lit), + collationIdentifier.get(), + Collections.singleton(column)); + } else { + return new DataSkippingPredicate( + exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column)); + } } private static final Map REVERSE_COMPARATORS = @@ -356,15 +396,17 @@ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( }; private static Predicate reverseComparatorFilter(Predicate predicate) { - return new Predicate( + return createPredicate( REVERSE_COMPARATORS.get(predicate.getName().toUpperCase(Locale.ROOT)), getRight(predicate), - getLeft(predicate)); + getLeft(predicate), + predicate.getCollationIdentifier()); } /** Construct the skipping predicate for a NOT expression child if possible */ private static Optional 
constructNotDataSkippingFilters( Predicate childPredicate, StatsSchemaHelper schemaHelper) { + Optional collationIdentifier = childPredicate.getCollationIdentifier(); switch (childPredicate.getName().toUpperCase(Locale.ROOT)) { // Use deMorgan's law to push the NOT past the AND. This is safe even with SQL // tri-valued logic (see below), and is desirable because we cannot generally push @@ -423,29 +465,36 @@ private static Optional constructNotDataSkippingFilters( new DataSkippingPredicate( "OR", constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftColumn), rightLiteral), + "<", + schemaHelper.getMinColumn(leftColumn, collationIdentifier), + rightLiteral, + collationIdentifier), constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftColumn), rightLiteral))); + ">", + schemaHelper.getMaxColumn(leftColumn, collationIdentifier), + rightLiteral, + collationIdentifier))); }); case "<": return constructDataSkippingFilter( - new Predicate(">=", childPredicate.getChildren()), schemaHelper); + createPredicate(">=", childPredicate.getChildren(), collationIdentifier), schemaHelper); case "<=": return constructDataSkippingFilter( - new Predicate(">", childPredicate.getChildren()), schemaHelper); + createPredicate(">", childPredicate.getChildren(), collationIdentifier), schemaHelper); case ">": return constructDataSkippingFilter( - new Predicate("<=", childPredicate.getChildren()), schemaHelper); + createPredicate("<=", childPredicate.getChildren(), collationIdentifier), schemaHelper); case ">=": return constructDataSkippingFilter( - new Predicate("<", childPredicate.getChildren()), schemaHelper); + createPredicate("<", childPredicate.getChildren(), collationIdentifier), schemaHelper); case "IS NOT DISTINCT FROM": return constructDataSkippingFiltersForNotEqual( childPredicate, schemaHelper, (leftColumn, rightLiteral) -> constructDataSkippingFilter( - new Predicate("NOT", rewriteEqualNullSafe(leftColumn, rightLiteral)), + new Predicate( 
+ "NOT", rewriteEqualNullSafe(leftColumn, rightLiteral, collationIdentifier)), schemaHelper)); case "NOT": // Remove redundant pairs of NOT @@ -525,12 +574,15 @@ private static String[] appendArray(String[] arr, String appendElem) { * Rewrite `EqualNullSafe(a, NotNullLiteral)` as `And(IsNotNull(a), EqualTo(a, NotNullLiteral))` * and rewrite `EqualNullSafe(a, null)` as `IsNull(a)` */ - private static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit) { + private static Predicate rewriteEqualNullSafe( + Column leftCol, Literal rightLit, Optional collationIdentifier) { if (rightLit.getValue() == null) { return new Predicate("IS_NULL", leftCol); } return new Predicate( - "AND", new Predicate("IS_NOT_NULL", leftCol), new Predicate("=", leftCol, rightLit)); + "AND", + new Predicate("IS_NOT_NULL", leftCol), + createPredicate("=", leftCol, rightLit, collationIdentifier)); } /** Helper method for building DataSkippingPredicate for NOT =/IS NOT DISTINCT FROM */ @@ -546,13 +598,20 @@ private static Optional constructDataSkippingFiltersForNo Expression rightChild = getRight(equalPredicate); if (rightChild instanceof Column && leftChild instanceof Literal) { return constructDataSkippingFilter( - new Predicate("NOT", new Predicate(equalPredicate.getName(), rightChild, leftChild)), + new Predicate( + "NOT", + createPredicate( + equalPredicate.getName(), + rightChild, + leftChild, + equalPredicate.getCollationIdentifier())), schemaHelper); } if (leftChild instanceof Column && rightChild instanceof Literal) { Column leftCol = (Column) leftChild; Literal rightLit = (Literal) rightChild; - if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) + if (schemaHelper.isSkippingEligibleMinMaxColumn( + leftCol, equalPredicate.getCollationIdentifier().isPresent()) && schemaHelper.isSkippingEligibleLiteral(rightLit)) { return buildDataSkippingPredicateFunc.apply(leftCol, rightLit); } diff --git 
a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionPredicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionPredicate.java new file mode 100644 index 00000000000..f1271518bf2 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionPredicate.java @@ -0,0 +1,22 @@ +package io.delta.kernel.internal.skipping; + +import io.delta.kernel.expressions.Expression; +import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.CollationIdentifier; +import java.util.List; + +/** + * A specialized {@link Predicate} used for partition pruning. This class does not add any new + * functionality beyond {@link Predicate}, but serves as a marker to indicate that the predicate is + * specifically intended for partition pruning operations. + */ +public class PartitionPredicate extends Predicate { + PartitionPredicate(String name, List children) { + super(name, children); + } + + PartitionPredicate( + String name, List children, CollationIdentifier collationIdentifier) { + super(name, children, collationIdentifier); + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionUtils.java similarity index 96% rename from kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java rename to kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionUtils.java index 607bdbfdafa..5136caa2214 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/PartitionUtils.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package io.delta.kernel.internal.util; +package io.delta.kernel.internal.skipping; import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; @@ -31,6 +31,10 @@ import io.delta.kernel.internal.InternalScanFileUtils; import io.delta.kernel.internal.annotation.VisibleForTesting; import io.delta.kernel.internal.fs.Path; +import io.delta.kernel.internal.util.ColumnMapping; +import io.delta.kernel.internal.util.InternalUtils; +import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.internal.util.VectorUtils; import io.delta.kernel.types.*; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; @@ -270,11 +274,12 @@ public static Tuple2 splitMetadataAndDataPredicates( */ public static Predicate rewritePartitionPredicateOnCheckpointFileSchema( Predicate predicate, Map partitionColNameToField) { - return new Predicate( + return createPartitionPredicate( predicate.getName(), predicate.getChildren().stream() .map(child -> rewriteColRefOnPartitionValuesParsed(child, partitionColNameToField)) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + predicate.getCollationIdentifier()); } private static Expression rewriteColRefOnPartitionValuesParsed( @@ -319,11 +324,12 @@ private static Expression rewriteColRefOnPartitionValuesParsed( */ public static Predicate rewritePartitionPredicateOnScanFileSchema( Predicate predicate, Map partitionColMetadata) { - return new Predicate( + return createPartitionPredicate( predicate.getName(), predicate.getChildren().stream() .map(child -> rewritePartitionColumnRef(child, partitionColMetadata)) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + predicate.getCollationIdentifier()); } private static Expression rewritePartitionColumnRef( @@ -616,4 +622,13 @@ private static String escapePartitionValue(String value) { } return escaped.toString(); } + + private static PartitionPredicate createPartitionPredicate( 
+ String name, List children, Optional collationIdentifier) { + if (collationIdentifier.isPresent()) { + return new PartitionPredicate(name, children, collationIdentifier.get()); + } else { + return new PartitionPredicate(name, children); + } + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java index b97595e4a6b..fc44c72a530 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java @@ -49,22 +49,27 @@ public class StatsSchemaHelper { public static final String MAX = "maxValues"; public static final String NULL_COUNT = "nullCount"; public static final String TIGHT_BOUNDS = "tightBounds"; + public static final String STATS_WITH_COLLATION = "statsWithCollation"; /** * Returns true if the given literal is skipping-eligible. Delta tracks min/max stats for a * limited set of data types and only literals of those types are skipping eligible. */ public static boolean isSkippingEligibleLiteral(Literal literal) { - return isSkippingEligibleDataType(literal.getDataType()); + return isSkippingEligibleDataType(literal.getDataType(), false); } /** Returns true if the given data type is eligible for MIN/MAX data skipping. 
*/ - public static boolean isSkippingEligibleDataType(DataType dataType) { - return SKIPPING_ELIGIBLE_TYPE_NAMES.contains(dataType.toString()) - || - // DecimalType is eligible but since its string includes scale + precision it needs to - // be matched separately - dataType instanceof DecimalType; + public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCollatedSkipping) { + if (isCollatedSkipping) { + return dataType instanceof StringType; + } else { + return SKIPPING_ELIGIBLE_TYPE_NAMES.contains(dataType.toString()) + || + // DecimalType is eligible but since its string includes scale + precision it needs to + // be matched separately + dataType instanceof DecimalType; + } } /** @@ -101,10 +106,11 @@ public static boolean isSkippingEligibleDataType(DataType dataType) { * | |-- tightBounds: boolean (nullable = true) * */ - public static StructType getStatsSchema(StructType dataSchema) { + public static StructType getStatsSchema( + StructType dataSchema, Set collationIdentifiers) { StructType statsSchema = new StructType().add(NUM_RECORDS, LongType.LONG, true); - StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema); + StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema, false); if (minMaxStatsSchema.length() > 0) { statsSchema = statsSchema.add(MIN, minMaxStatsSchema, true).add(MAX, minMaxStatsSchema, true); } @@ -116,6 +122,11 @@ public static StructType getStatsSchema(StructType dataSchema) { statsSchema = statsSchema.add(TIGHT_BOUNDS, BooleanType.BOOLEAN, true); + StructType collatedMinMaxStatsSchema = getCollatedStatsSchema(dataSchema, collationIdentifiers); + if (collatedMinMaxStatsSchema.length() > 0) { + statsSchema = statsSchema.add(STATS_WITH_COLLATION, collatedMinMaxStatsSchema, true); + } + return statsSchema; } @@ -155,13 +166,15 @@ public StatsSchemaHelper(StructType dataSchema) { * @param column the logical column name. * @return a tuple of the MIN column and an optional adjustment expression. 
*/ - public Tuple2> getMinColumn(Column column) { + public Tuple2> getMinColumn( + Column column, Optional collationIdentifier) { checkArgument( - isSkippingEligibleMinMaxColumn(column), - "%s is not a valid min column for data schema %s", + isSkippingEligibleMinMaxColumn(column, collationIdentifier.isPresent()), + "%s is not a valid min column%s for data schema %s", column, + collationIdentifier.isPresent() ? (" for collation " + collationIdentifier) : "", dataSchema); - return new Tuple2<>(getStatsColumn(column, MIN), Optional.empty()); + return new Tuple2<>(getStatsColumn(column, MIN, collationIdentifier), Optional.empty()); } /** @@ -172,14 +185,16 @@ public Tuple2> getMinColumn(Column column) { * @param column the logical column name. * @return a tuple of the MAX column and an optional adjustment expression. */ - public Tuple2> getMaxColumn(Column column) { + public Tuple2> getMaxColumn( + Column column, Optional collationIdentifier) { checkArgument( - isSkippingEligibleMinMaxColumn(column), - "%s is not a valid min column for data schema %s", + isSkippingEligibleMinMaxColumn(column, collationIdentifier.isPresent()), + "%s is not a valid min column%s for data schema %s", column, + collationIdentifier.isPresent() ? 
(" for collation " + collationIdentifier) : "", dataSchema); DataType dataType = logicalToDataType.get(column); - Column maxColumn = getStatsColumn(column, MAX); + Column maxColumn = getStatsColumn(column, MAX, collationIdentifier); // If this is a column of type Timestamp or TimestampNTZ // compensate for the truncation from microseconds to milliseconds @@ -206,7 +221,7 @@ public Column getNullCountColumn(Column column) { "%s is not a valid null_count column for data schema %s", column, dataSchema); - return getStatsColumn(column, NULL_COUNT); + return getStatsColumn(column, NULL_COUNT, Optional.empty()); } /** Returns the NUM_RECORDS column in the statistic schema */ @@ -218,9 +233,9 @@ public Column getNumRecordsColumn() { * Returns true if the given column is skipping-eligible using min/max statistics. This means the * column exists, is a leaf column, and is of a skipping-eligible data-type. */ - public boolean isSkippingEligibleMinMaxColumn(Column column) { + public boolean isSkippingEligibleMinMaxColumn(Column column, boolean isCollatedSkipping) { return logicalToDataType.containsKey(column) - && isSkippingEligibleDataType(logicalToDataType.get(column)); + && isSkippingEligibleDataType(logicalToDataType.get(column), isCollatedSkipping); } /** @@ -256,16 +271,17 @@ public boolean isSkippingEligibleNullCountColumn(Column column) { * 1) replace logical names with physical names 2) set nullable=true 3) only keep stats eligible * fields (i.e. 
don't include fields with isSkippingEligibleDataType=false) */ - private static StructType getMinMaxStatsSchema(StructType dataSchema) { + private static StructType getMinMaxStatsSchema( + StructType dataSchema, boolean isCollatedSkipping) { List fields = new ArrayList<>(); for (StructField field : dataSchema.fields()) { - if (isSkippingEligibleDataType(field.getDataType())) { + if (isSkippingEligibleDataType(field.getDataType(), isCollatedSkipping)) { fields.add(new StructField(getPhysicalName(field), field.getDataType(), true)); } else if (field.getDataType() instanceof StructType) { fields.add( new StructField( getPhysicalName(field), - getMinMaxStatsSchema((StructType) field.getDataType()), + getMinMaxStatsSchema((StructType) field.getDataType(), isCollatedSkipping), true)); } } @@ -293,6 +309,24 @@ private static StructType getNullCountSchema(StructType dataSchema) { return new StructType(fields); } + private static StructType getCollatedStatsSchema( + StructType dataSchema, Set collationIdentifiers) { + StructType statsWithCollation = new StructType(); + StructType collatedMinMaxStatsSchema = getMinMaxStatsSchema(dataSchema, true); + for (CollationIdentifier collationIdentifier : collationIdentifiers) { + if (collatedMinMaxStatsSchema.length() > 0) { + statsWithCollation = + statsWithCollation.add( + collationIdentifier.toString(), + new StructType() + .add(MIN, collatedMinMaxStatsSchema, true) + .add(MAX, collatedMinMaxStatsSchema, true), + true); + } + } + return statsWithCollation; + } + ////////////////////////////////////////////////////////////////////////////////// // Private class helpers ////////////////////////////////////////////////////////////////////////////////// @@ -301,13 +335,23 @@ private static StructType getNullCountSchema(StructType dataSchema) { * Given a logical column and a stats type returns the corresponding column in the statistics * schema */ - private Column getStatsColumn(Column column, String statType) { + private Column 
getStatsColumn( + Column column, String statType, Optional collationIdentifier) { checkArgument( logicalToPhysicalColumn.containsKey(column), "%s is not a valid leaf column for data schema: %s", column, dataSchema); - return getChildColumn(logicalToPhysicalColumn.get(column), statType); + Column physicalColumn = logicalToPhysicalColumn.get(column); + // Use binary stats if collation is not specified or if it is the default Spark collation. + if (collationIdentifier.isPresent() + && !collationIdentifier.get().equals(CollationIdentifier.SPARK_UTF8_BINARY)) { + return getChildColumn( + physicalColumn, + Arrays.asList(statType, collationIdentifier.get().toString(), STATS_WITH_COLLATION)); + } else { + return getChildColumn(physicalColumn, statType); + } } /** @@ -344,6 +388,14 @@ private static Column getChildColumn(Column column, String parentName) { return new Column(prependArray(column.getNames(), parentName)); } + /** Returns the provided column as a child column nested under {@code parentPath} */ + private static Column getChildColumn(Column column, List parentPath) { + for (String name : parentPath) { + column = getChildColumn(column, name); + } + return column; + } + /** * Given an array {@code names} and a string element {@code preElem} return a new array with * {@code preElem} inserted at the beginning diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java index 2e82fffea4c..3a1dfd51cfd 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java @@ -19,7 +19,10 @@ import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.CollationIdentifier; +import java.util.Arrays; import java.util.List; +import java.util.Optional; public class ExpressionUtils { /**
Return an expression cast as a predicate, throw an error if it is not a predicate */ @@ -51,4 +54,21 @@ public static Expression getUnaryChild(Expression expression) { children.size() == 1, "%s: expected one inputs, but got %s", expression, children.size()); return children.get(0); } + + public static Predicate createPredicate( + String name, List children, Optional collationIdentifier) { + if (collationIdentifier.isPresent()) { + return new Predicate(name, children, collationIdentifier.get()); + } else { + return new Predicate(name, children); + } + } + + public static Predicate createPredicate( + String name, + Expression left, + Expression right, + Optional collationIdentifier) { + return createPredicate(name, Arrays.asList(left, right), collationIdentifier); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java index 8ac5de10d13..3a6ffc1b469 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java @@ -281,7 +281,7 @@ public static List casePreservingEligibleClusterColumns( List nonSkippingEligibleColumns = physicalColumnsWithTypes.stream() - .filter(tuple -> !StatsSchemaHelper.isSkippingEligibleDataType(tuple._2)) + .filter(tuple -> !StatsSchemaHelper.isSkippingEligibleDataType(tuple._2, false)) .map(tuple -> tuple._1.toString() + " : " + tuple._2) .collect(Collectors.toList()); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/PartitionUtils.java index 1877100dc19..c889f5daa0f 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/utils/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/utils/PartitionUtils.java @@ -52,7 +52,7 @@ public static boolean partitionExists( final Set snapshotPartColNames = 
new HashSet<>(snapshot.getPartitionColumnNames()); - io.delta.kernel.internal.util.PartitionUtils.validatePredicateOnlyOnPartitionColumns( + io.delta.kernel.internal.skipping.PartitionUtils.validatePredicateOnlyOnPartitionColumns( partitionPredicate, snapshotPartColNames); final Scan scan = snapshot.getScanBuilder().withFilter(partitionPredicate).build(); diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala new file mode 100644 index 00000000000..a279687bee9 --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala @@ -0,0 +1,515 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.internal.skipping + +import java.util.Optional + +import scala.collection.JavaConverters._ + +import io.delta.kernel.expressions.{Column, Expression, Predicate} +import io.delta.kernel.internal.skipping.DataSkippingUtils.constructDataSkippingFilter +import io.delta.kernel.internal.skipping.StatsSchemaHelper.{MAX, MIN, STATS_WITH_COLLATION} +import io.delta.kernel.internal.util.ExpressionUtils.createPredicate +import io.delta.kernel.test.TestUtils +import io.delta.kernel.types._ +import io.delta.kernel.types.IntegerType.INTEGER + +import org.scalatest.funsuite.AnyFunSuite + +class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils { + + def dataSkippingPredicate( + operator: String, + children: Seq[Expression], + referencedColumns: Set[Column]): DataSkippingPredicate = { + new DataSkippingPredicate(operator, children.asJava, referencedColumns.asJava) + } + + def dataSkippingPredicate( + operator: String, + left: DataSkippingPredicate, + right: DataSkippingPredicate): DataSkippingPredicate = { + new DataSkippingPredicate(operator, left, right) + } + + def dataSkippingPredicateWithCollation( + operator: String, + children: Seq[Expression], + collation: CollationIdentifier, + referencedColumns: Set[Column]): DataSkippingPredicate = { + new DataSkippingPredicate(operator, children.asJava, collation, referencedColumns.asJava) + } + + private def collatedStatsCol( + collation: CollationIdentifier, + statName: String, + fieldName: String): Column = { + new Column(Array(STATS_WITH_COLLATION, collation.toString, statName, fieldName)) + } + + /* For struct type checks for equality based on field names & data type only */ + def compareDataTypeUnordered(type1: DataType, type2: DataType): Boolean = (type1, type2) match { + case (schema1: StructType, schema2: StructType) => + val fields1 = schema1.fields().asScala.sortBy(_.getName) + val fields2 = schema2.fields().asScala.sortBy(_.getName) + if (fields1.length != fields2.length) { + false + } 
else { + fields1.zip(fields2).forall { case (field1: StructField, field2: StructField) => + field1.getName == field2.getName && + compareDataTypeUnordered(field1.getDataType, field2.getDataType) + } + } + case _ => + type1 == type2 + } + + def checkPruneStatsSchema( + inputSchema: StructType, + referencedCols: Set[Column], + expectedSchema: StructType): Unit = { + val prunedSchema = DataSkippingUtils.pruneStatsSchema(inputSchema, referencedCols.asJava) + assert( + compareDataTypeUnordered(expectedSchema, prunedSchema), + s"expected=$expectedSchema\nfound=$prunedSchema") + } + + test("pruneStatsSchema - multiple basic cases one level of nesting") { + val nestedField = new StructField( + "nested", + new StructType() + .add("col1", INTEGER) + .add("col2", INTEGER), + true) + val testSchema = new StructType() + .add(nestedField) + .add("top_level_col", INTEGER) + // no columns pruned + checkPruneStatsSchema( + testSchema, + Set(col("top_level_col"), nestedCol("nested.col1"), nestedCol("nested.col2")), + testSchema) + // top level column pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("nested.col1"), nestedCol("nested.col2")), + new StructType().add(nestedField)) + // nested column only one field pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("top_level_col"), nestedCol("nested.col1")), + new StructType() + .add("nested", new StructType().add("col1", INTEGER)) + .add("top_level_col", INTEGER)) + // nested column completely pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("top_level_col")), + new StructType().add("top_level_col", INTEGER)) + // prune all columns + checkPruneStatsSchema( + testSchema, + Set(), + new StructType()) + } + + test("pruneStatsSchema - 3 levels of nesting") { + /* + |--level1: struct + | |--level2: struct + | |--level3: struct + | |--level_4_col: int + | |--level_3_col: int + | |--level_2_col: int + */ + val testSchema = new StructType() + .add( + "level1", + new StructType() + .add( + "level2", + new 
StructType() + .add( + "level3", + new StructType().add("level_4_col", INTEGER)) + .add("level_3_col", INTEGER)) + .add("level_2_col", INTEGER)) + // prune only 4th level col + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level2.level_3_col"), nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add("level2", new StructType().add("level_3_col", INTEGER)) + .add("level_2_col", INTEGER))) + // prune only 3rd level column + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level2.level3.level_4_col"), nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add( + "level2", + new StructType() + .add( + "level3", + new StructType().add("level_4_col", INTEGER))) + .add("level_2_col", INTEGER))) + // prune 4th and 3rd level column + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add("level_2_col", INTEGER))) + // prune all columns + checkPruneStatsSchema( + testSchema, + Set(), + new StructType()) + } + + // TODO: add tests for remaining operators + test("check constructDataSkippingFilter") { + val testCases = Seq( + // (schema, predicate, expectedDataSkippingPredicateOpt) + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + None), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate("<", col("a"), literal("x"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + "<", + Seq(nestedCol(s"$MIN.a"), literal("x")), + Set(nestedCol(s"$MIN.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate("<", literal("x"), col("a"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + ">", + Seq(nestedCol(s"$MAX.a"), literal("x")), + 
Set(nestedCol(s"$MAX.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate(">", col("a"), literal("x"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + ">", + Seq(nestedCol(s"$MAX.a"), literal("x")), + Set(nestedCol(s"$MAX.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER), + createPredicate("=", col("a"), literal(10), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a"))), + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(10)), + Set(nestedCol(s"$MAX.a")))))), + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate("<", col("a"), literal(10), Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(10)), + Set(nestedCol(s"$MAX.a"))))), + // NOT over AND: NOT(a < 5 AND a > 10) => (max.a >= 5) OR (min.a <= 10) + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("a"), literal(10), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "OR", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a")))))), + // NOT over OR: NOT(a < 5 OR a > 10) => (max.a >= 5) AND (min.a <= 10) + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "OR", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("a"), literal(10), Optional.empty[CollationIdentifier]), + 
Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a")))))), + // NOT over OR with one ineligible leg: NOT(a < b OR a < 5) => NOT(a < b) AND NOT(a < 5) + // The first leg is ineligible; AND with single leg should return that leg only + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "OR", + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))))), + // NOT over AND with one ineligible leg: NOT(a < 5 AND a < b) + // => NOT(a < 5) OR NOT(a < b); since OR needs both legs, expect None + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + None), + // Double NOT elimination: NOT(NOT(a < 5)) => a < 5 => min.a < 5 + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + new Predicate( + "NOT", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]))), + Some(dataSkippingPredicate( + "<", + Seq(nestedCol(s"$MIN.a"), literal(5)), + Set(nestedCol(s"$MIN.a"))))), + // Cross-column case: NOT(a < 5 OR b > 7) => (max.a >= 5) AND (min.b <= 7) + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( 
+ "OR", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("b"), literal(7), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.b"), literal(7)), + Set(nestedCol(s"$MIN.b"))))))) + + testCases.foreach { case (schema, predicate, expectedDataSkippingPredicateOpt) => + val dataSkippingPredicateOpt = optionalToScala(constructDataSkippingFilter(predicate, schema)) + (dataSkippingPredicateOpt, expectedDataSkippingPredicateOpt) match { + case (Some(dataSkippingPredicate), Some(expectedDataSkippingPredicate)) => + assert(dataSkippingPredicate == expectedDataSkippingPredicate) + case (None, None) => // pass + case _ => + fail(s"Expected $expectedDataSkippingPredicateOpt, found $dataSkippingPredicateOpt") + } + } + } + + test("check constructDataSkippingFilter with collations") { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + + val testCases = Seq( + // (schema, predicate, expectedDataSkippingPredicateOpt) + // Ineligible: both sides are columns + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), col("b"), Optional.of(utf8Lcase)), + None), + // Eligible: a < "m" with collation -> min(a, collation) < "m" + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), { + val minA = collatedStatsCol(utf8Lcase, MIN, "a") + Some(dataSkippingPredicateWithCollation( + "<", + Seq(minA, literal("m")), + utf8Lcase, + Set(minA))) + }), + // Reversed comparator: "m" < a -> max(a, collation) > "m" + ( + new StructType() + .add("a", StringType.STRING), + 
createPredicate("<", literal("m"), col("a"), Optional.of(utf8Lcase)), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // Direct ">": a > "m" -> max(a, collation) > "m" + ( + new StructType() + .add("a", StringType.STRING), + createPredicate(">", col("a"), literal("m"), Optional.of(utf8Lcase)), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // Equality + ( + new StructType() + .add("a", StringType.STRING), + createPredicate("=", col("a"), literal("abc"), Optional.of(unicode)), { + val minA = collatedStatsCol(unicode, MIN, "a") + val maxA = collatedStatsCol(unicode, MAX, "a") + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicateWithCollation("<=", Seq(minA, literal("abc")), unicode, Set(minA)), + dataSkippingPredicateWithCollation( + ">=", + Seq(maxA, literal("abc")), + unicode, + Set(maxA)))) + }), + // NOT over comparator: NOT(a < "m") -> max(a, collation) >= "m" + ( + new StructType() + .add("a", StringType.STRING), + new Predicate( + "NOT", + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase))), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">=", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // NOT over AND + // NOT(a < "m" AND a > "t") => (max.a >= "m") OR (min.a <= "t") + ( + new StructType() + .add("a", StringType.STRING), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal("m"), Optional.of(unicode)), + createPredicate(">", col("a"), literal("t"), Optional.of(utf8Lcase)), + Optional.empty[CollationIdentifier])), { + val unicodeMaxA = collatedStatsCol(unicode, MAX, "a") + val utf8LcaseMinA = collatedStatsCol(utf8Lcase, MIN, "a") + Some(dataSkippingPredicate( + "OR", + 
dataSkippingPredicateWithCollation( + ">=", + Seq(unicodeMaxA, literal("m")), + unicode, + Set(unicodeMaxA)), + dataSkippingPredicateWithCollation( + "<=", + Seq(utf8LcaseMinA, literal("t")), + utf8Lcase, + Set(utf8LcaseMinA)))) + }), + // AND(a < "m" COLLATE UTF8_LCASE, b < 1) + ( + new StructType() + .add("a", StringType.STRING) + .add("b", IntegerType.INTEGER), + createPredicate( + "AND", + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), + createPredicate("<", col("b"), literal(1), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier]), { + val minA = collatedStatsCol(utf8Lcase, MIN, "a") + val minB = nestedCol(s"$MIN.b") + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicateWithCollation("<", Seq(minA, literal("m")), utf8Lcase, Set(minA)), + dataSkippingPredicate("<", Seq(minB, literal(1)), Set(minB)))) + }), + // Ineligible: non-string column with collation + ( + new StructType() + .add("a", IntegerType.INTEGER), + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), + None)) + + testCases.foreach { case (schema, predicate, expectedDataSkippingPredicateOpt) => + val dataSkippingPredicateOpt = optionalToScala(constructDataSkippingFilter(predicate, schema)) + (dataSkippingPredicateOpt, expectedDataSkippingPredicateOpt) match { + case (Some(dataSkippingPredicate), Some(expectedDataSkippingPredicate)) => + assert(dataSkippingPredicate == expectedDataSkippingPredicate) + case (None, None) => // pass + case _ => + fail(s"Expected $expectedDataSkippingPredicateOpt, found $dataSkippingPredicateOpt") + } + } + } +} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/PartitionUtilsSuite.scala similarity index 99% rename from kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala rename to 
kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/PartitionUtilsSuite.scala index 430d0bdd178..e182f7df855 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/PartitionUtilsSuite.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.delta.kernel.internal.util +package io.delta.kernel.internal.skipping import java.util @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import io.delta.kernel.expressions._ import io.delta.kernel.expressions.Literal._ -import io.delta.kernel.internal.util.PartitionUtils._ +import io.delta.kernel.internal.skipping.PartitionUtils._ import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala new file mode 100644 index 00000000000..4232ff54fda --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala @@ -0,0 +1,432 @@ +/* + * Copyright (2025) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.internal.skipping + +import scala.collection.JavaConverters.setAsJavaSetConverter + +import io.delta.kernel.types.{ArrayType, BinaryType, BooleanType, ByteType, CollationIdentifier, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} + +import org.scalatest.funsuite.AnyFunSuite + +class StatsSchemaHelperSuite extends AnyFunSuite { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + + test("check getStatsSchema for supported data types") { + val testCases = Seq( + ( + new StructType().add("a", IntegerType.INTEGER), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("a", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("b", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("b", StringType.STRING, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("b", StringType.STRING, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("b", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("c", ByteType.BYTE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("c", LongType.LONG, true), true) + 
.add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("d", ShortType.SHORT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("d", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("e", LongType.LONG), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("f", FloatType.FLOAT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("f", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("g", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("g", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("h", DateType.DATE), + new StructType() + 
.add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("h", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("i", TimestampType.TIMESTAMP), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("i", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("j", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("k", new DecimalType(20, 5)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("k", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = StatsSchemaHelper.getStatsSchema( 
+ dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with mix of supported and unsupported data types") { + val testCases = Seq( + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", BinaryType.BINARY) + .add("c", new ArrayType(LongType.LONG, true)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType() + .add( + "s", + new StructType() + .add("s1", StringType.STRING) + .add("s2", BooleanType.BOOLEAN)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("s1", LongType.LONG, true) + .add("s2", LongType.LONG, true), + true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Un-nested array/map alongside a supported type + ( + new StructType() + .add("arr", new ArrayType(IntegerType.INTEGER, true)) + .add("mp", new MapType(StringType.STRING, LongType.LONG, true)) + .add("z", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("z", 
DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("z", DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true) + .add("z", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Nested array/map inside a struct; empty struct preserved in min/max + ( + new StructType() + .add( + "s", + new StructType() + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(IntegerType.INTEGER, StringType.STRING, true))) + .add("k", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add("k", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = StatsSchemaHelper.getStatsSchema( + dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with collations - un-nested mix") { + val dataSchema = new StructType() + .add("a", StringType.STRING) + .add("b", IntegerType.INTEGER) + .add("c", BinaryType.BINARY) + + val collations = Set(utf8Lcase) + + val expectedCollatedMinMax = new StructType().add("a", StringType.STRING, true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + 
.add( + StatsSchemaHelper.MIN, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedMinMax, true) + .add(StatsSchemaHelper.MAX, expectedCollatedMinMax, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with collations - nested mix and multiple collations") { + val dataSchema = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING) + .add("y", IntegerType.INTEGER) + .add("z", new StructType().add("p", StringType.STRING).add("q", DoubleType.DOUBLE))) + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(StringType.STRING, StringType.STRING, true)) + + val collations = Set(utf8Lcase, CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedCollatedNested = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("z", new StructType().add("p", StringType.STRING, true), true), + true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), 
+ true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("x", LongType.LONG, true) + .add("y", LongType.LONG, true) + .add( + "z", + new StructType() + .add("p", LongType.LONG, true) + .add("q", LongType.LONG, true), + true), + true) + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true) + .add( + CollationIdentifier.SPARK_UTF8_BINARY.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with collations - no eligible string columns") { + val dataSchema = new StructType() + .add("a", IntegerType.INTEGER) + .add("b", new ArrayType(StringType.STRING, true)) + .add("c", new MapType(StringType.STRING, LongType.LONG, true)) + + val collations = Set(utf8Lcase, unicode, CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + 
StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } +} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala deleted file mode 100644 index 27d2895f7f2..00000000000 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.delta.kernel.internal.util - -import scala.collection.JavaConverters._ - -import io.delta.kernel.expressions.Column -import io.delta.kernel.internal.skipping.DataSkippingUtils -import io.delta.kernel.types.{DataType, StructField, StructType} -import io.delta.kernel.types.IntegerType.INTEGER - -import org.scalatest.funsuite.AnyFunSuite - -class DataSkippingUtilsSuite extends AnyFunSuite { - - def col(name: String): Column = new Column(name) - - def nestedCol(name: String): Column = { - new Column(name.split("\\.")) - } - - /* For struct type checks for equality based on field names & data type only */ - def compareDataTypeUnordered(type1: DataType, type2: DataType): Boolean = (type1, type2) match { - case (schema1: StructType, schema2: StructType) => - val fields1 = schema1.fields().asScala.sortBy(_.getName) - val fields2 = schema2.fields().asScala.sortBy(_.getName) - if (fields1.length != fields2.length) { - false - } else { - fields1.zip(fields2).forall { case (field1: StructField, field2: StructField) => - field1.getName == field2.getName && - compareDataTypeUnordered(field1.getDataType, field2.getDataType) - } - } - case _ => - type1 == type2 - } - - def checkPruneStatsSchema( - inputSchema: StructType, - referencedCols: Set[Column], - expectedSchema: StructType): Unit = { - val prunedSchema = DataSkippingUtils.pruneStatsSchema(inputSchema, referencedCols.asJava) - assert( - compareDataTypeUnordered(expectedSchema, prunedSchema), - s"expected=$expectedSchema\nfound=$prunedSchema") - } - - test("pruneStatsSchema - multiple basic cases one level of nesting") { - val nestedField = new StructField( - "nested", - new StructType() - .add("col1", INTEGER) - .add("col2", INTEGER), - true) - val testSchema = new StructType() - .add(nestedField) - .add("top_level_col", INTEGER) - // no columns pruned - checkPruneStatsSchema( - testSchema, - Set(col("top_level_col"), nestedCol("nested.col1"), nestedCol("nested.col2")), - testSchema) - // top level column 
pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("nested.col1"), nestedCol("nested.col2")), - new StructType().add(nestedField)) - // nested column only one field pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("top_level_col"), nestedCol("nested.col1")), - new StructType() - .add("nested", new StructType().add("col1", INTEGER)) - .add("top_level_col", INTEGER)) - // nested column completely pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("top_level_col")), - new StructType().add("top_level_col", INTEGER)) - // prune all columns - checkPruneStatsSchema( - testSchema, - Set(), - new StructType()) - } - - test("pruneStatsSchema - 3 levels of nesting") { - /* - |--level1: struct - | |--level2: struct - | |--level3: struct - | |--level_4_col: int - | |--level_3_col: int - | |--level_2_col: int - */ - val testSchema = new StructType() - .add( - "level1", - new StructType() - .add( - "level2", - new StructType() - .add( - "level3", - new StructType().add("level_4_col", INTEGER)) - .add("level_3_col", INTEGER)) - .add("level_2_col", INTEGER)) - // prune only 4th level col - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level2.level_3_col"), nestedCol("level1.level_2_col")), - new StructType() - .add( - "level1", - new StructType() - .add("level2", new StructType().add("level_3_col", INTEGER)) - .add("level_2_col", INTEGER))) - // prune only 3rd level column - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level2.level3.level_4_col"), nestedCol("level1.level_2_col")), - new StructType() - .add( - "level1", - new StructType() - .add( - "level2", - new StructType() - .add( - "level3", - new StructType().add("level_4_col", INTEGER))) - .add("level_2_col", INTEGER))) - // prune 4th and 3rd level column - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level_2_col")), - new StructType() - .add( - "level1", - new StructType() - .add("level_2_col", INTEGER))) - // prune all columns - 
checkPruneStatsSchema( - testSchema, - Set(), - new StructType()) - } -} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala new file mode 100644 index 00000000000..b9448fcbc3c --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala @@ -0,0 +1,56 @@ +/* + * Copyright (2025) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.test + +import java.util.Optional + +import scala.collection.JavaConverters.asJavaIterableConverter + +import io.delta.kernel.expressions.{Column, Expression, Literal} +import io.delta.kernel.internal.skipping.DataSkippingPredicate + +/** Utility functions for tests. 
*/ +trait TestUtils { + + def col(name: String): Column = new Column(name) + + def nestedCol(name: String): Column = { + new Column(name.split("\\.")) + } + + def literal(value: Any): Literal = { + value match { + case v: String => Literal.ofString(v) + case v: Int => Literal.ofInt(v) + case v: Long => Literal.ofLong(v) + case v: Float => Literal.ofFloat(v) + case v: Double => Literal.ofDouble(v) + case v: Boolean => Literal.ofBoolean(v) + case _ => throw new IllegalArgumentException(s"Unsupported literal type: ${value}") + } + } + + protected def optionToJava[T](option: Option[T]): Optional[T] = { + option match { + case Some(value) => Optional.of(value) + case None => Optional.empty() + } + } + + protected def optionalToScala[T](optional: Optional[T]): Option[T] = { + if (optional.isPresent) Some(optional.get()) else None + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index f82729a3663..1fd4abfefe7 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -30,6 +30,9 @@ import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; import io.delta.kernel.engine.ExpressionHandler; import io.delta.kernel.expressions.*; +import io.delta.kernel.internal.skipping.DataSkippingPredicate; +import io.delta.kernel.internal.skipping.PartitionPredicate; +import io.delta.kernel.internal.util.Tuple2; import io.delta.kernel.types.*; import java.util.*; import java.util.stream.Collectors; @@ -77,10 +80,19 @@ public void close() { private static class ExpressionTransformResult { public final Expression expression; // transformed expression public final DataType outputType; // 
output type of the expression + public final boolean hasNonUTF8BinaryCollationIgnored; ExpressionTransformResult(Expression expression, DataType outputType) { this.expression = expression; this.outputType = outputType; + this.hasNonUTF8BinaryCollationIgnored = false; + } + + ExpressionTransformResult( + Expression expression, DataType outputType, boolean hasNonUTF8BinaryCollationIgnored) { + this.expression = expression; + this.outputType = outputType; + this.hasNonUTF8BinaryCollationIgnored = hasNonUTF8BinaryCollationIgnored; } } @@ -99,23 +111,54 @@ private static class ExpressionTransformResult { */ private static class ExpressionTransformer extends ExpressionVisitor { private StructType inputDataSchema; + private boolean shouldIgnoreNonUTF8BinaryComparisons; ExpressionTransformer(StructType inputDataSchema) { this.inputDataSchema = requireNonNull(inputDataSchema, "inputDataSchema is null"); + shouldIgnoreNonUTF8BinaryComparisons = false; } @Override ExpressionTransformResult visitAnd(And and) { - Predicate left = validateIsPredicate(and, visit(and.getLeft())); - Predicate right = validateIsPredicate(and, visit(and.getRight())); - return new ExpressionTransformResult(new And(left, right), BooleanType.BOOLEAN); + ExpressionTransformResult leftResult = visit(and.getLeft()); + ExpressionTransformResult rightResult = visit(and.getRight()); + Predicate left = validateIsPredicate(and, leftResult); + Predicate right = validateIsPredicate(and, rightResult); + + if (left != AlwaysTrue.ALWAYS_TRUE && right != AlwaysTrue.ALWAYS_TRUE) { + return new ExpressionTransformResult( + new And(left, right), + BooleanType.BOOLEAN, + leftResult.hasNonUTF8BinaryCollationIgnored + || rightResult.hasNonUTF8BinaryCollationIgnored); + } else if (left == AlwaysTrue.ALWAYS_TRUE && right == AlwaysTrue.ALWAYS_TRUE) { + return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); + } else if (left == AlwaysTrue.ALWAYS_TRUE) { + return new ExpressionTransformResult( 
+ right, BooleanType.BOOLEAN, rightResult.hasNonUTF8BinaryCollationIgnored); + } else { + return new ExpressionTransformResult( + left, BooleanType.BOOLEAN, leftResult.hasNonUTF8BinaryCollationIgnored); + } } @Override ExpressionTransformResult visitOr(Or or) { - Predicate left = validateIsPredicate(or, visit(or.getLeft())); - Predicate right = validateIsPredicate(or, visit(or.getRight())); - return new ExpressionTransformResult(new Or(left, right), BooleanType.BOOLEAN); + ExpressionTransformResult leftResult = visit(or.getLeft()); + ExpressionTransformResult rightResult = visit(or.getRight()); + Predicate left = validateIsPredicate(or, leftResult); + Predicate right = validateIsPredicate(or, rightResult); + + boolean hasNonUTF8BinaryCollation = + leftResult.hasNonUTF8BinaryCollationIgnored + || rightResult.hasNonUTF8BinaryCollationIgnored; + if (left == AlwaysTrue.ALWAYS_TRUE || right == AlwaysTrue.ALWAYS_TRUE) { + return new ExpressionTransformResult( + AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, hasNonUTF8BinaryCollation); + } else { + return new ExpressionTransformResult( + new Or(left, right), BooleanType.BOOLEAN, hasNonUTF8BinaryCollation); + } } @Override @@ -139,8 +182,8 @@ ExpressionTransformResult visitComparator(Predicate predicate) { case "<": case "<=": case "IS NOT DISTINCT FROM": - return new ExpressionTransformResult( - transformBinaryComparator(predicate), BooleanType.BOOLEAN); + Tuple2 result = transformBinaryComparator(predicate); + return new ExpressionTransformResult(result._1, BooleanType.BOOLEAN, result._2); default: // We should never reach this based on the ExpressionVisitor throw new IllegalStateException( @@ -213,23 +256,41 @@ ExpressionTransformResult visitElementAt(ScalarExpression elementAt) { @Override ExpressionTransformResult visitNot(Predicate predicate) { - Predicate child = validateIsPredicate(predicate, visit(predicate.getChildren().get(0))); - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), 
BooleanType.BOOLEAN); + ExpressionTransformResult childResult = visit(predicate.getChildren().get(0)); + Predicate child = validateIsPredicate(predicate, childResult); + + if (childResult.hasNonUTF8BinaryCollationIgnored) { + return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); + } else { + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); + } } @Override ExpressionTransformResult visitIsNotNull(Predicate predicate) { - Expression child = visit(predicate.getChildren().get(0)).expression; - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); + ExpressionTransformResult childResult = visit(predicate.getChildren().get(0)); + Expression child = childResult.expression; + + if (childResult.hasNonUTF8BinaryCollationIgnored) { + return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); + } else { + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); + } } @Override ExpressionTransformResult visitIsNull(Predicate predicate) { - Expression child = visit(getUnaryChild(predicate)).expression; - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); + ExpressionTransformResult childResult = visit(getUnaryChild(predicate)); + Expression child = childResult.expression; + + if (childResult.hasNonUTF8BinaryCollationIgnored) { + return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); + } else { + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); + } } @Override @@ -346,6 +407,83 @@ ExpressionTransformResult visitStartsWith(Predicate startsWith) { return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN); } + @Override + ExpressionTransformResult visitDataSkippingPredicate( + DataSkippingPredicate 
dataSkippingPredicate) { + if (dataSkippingPredicate.getReferencedCollations().isEmpty()) { + return visitScalarExpression(dataSkippingPredicate); + } + + Predicate resolvedPredicate = dataSkippingPredicate; + boolean hasNonUTF8BinaryCollation = false; + String name = dataSkippingPredicate.getName().toUpperCase(Locale.ROOT); + switch (name) { + case "AND": + validateChildrenAreDataSkippingPredicates(dataSkippingPredicate); + ExpressionTransformResult leftResult = + visitDataSkippingPredicate((DataSkippingPredicate) getLeft(dataSkippingPredicate)); + ExpressionTransformResult rightResult = + visitDataSkippingPredicate((DataSkippingPredicate) getRight(dataSkippingPredicate)); + if (leftResult.hasNonUTF8BinaryCollationIgnored + && rightResult.hasNonUTF8BinaryCollationIgnored) { + resolvedPredicate = AlwaysTrue.ALWAYS_TRUE; + } else if (!leftResult.hasNonUTF8BinaryCollationIgnored + && !rightResult.hasNonUTF8BinaryCollationIgnored) { + resolvedPredicate = dataSkippingPredicate; + } else if (leftResult.hasNonUTF8BinaryCollationIgnored) { + resolvedPredicate = validateIsPredicate(rightResult.expression, rightResult); + } else { + resolvedPredicate = validateIsPredicate(leftResult.expression, leftResult); + } + hasNonUTF8BinaryCollation = + leftResult.hasNonUTF8BinaryCollationIgnored + || rightResult.hasNonUTF8BinaryCollationIgnored; + break; + + case "OR": + validateChildrenAreDataSkippingPredicates(dataSkippingPredicate); + leftResult = + visitDataSkippingPredicate((DataSkippingPredicate) getLeft(dataSkippingPredicate)); + rightResult = + visitDataSkippingPredicate((DataSkippingPredicate) getRight(dataSkippingPredicate)); + + if (leftResult.hasNonUTF8BinaryCollationIgnored + || rightResult.hasNonUTF8BinaryCollationIgnored) { + resolvedPredicate = AlwaysTrue.ALWAYS_TRUE; + hasNonUTF8BinaryCollation = true; + } else { + resolvedPredicate = dataSkippingPredicate; + hasNonUTF8BinaryCollation = false; + } + break; + + case "=": + case "<": + case "<=": + case ">": + case 
">=": + case "IS NOT DISTINCT FROM": + hasNonUTF8BinaryCollation = + dataSkippingPredicate + .getCollationIdentifier() + .map(id -> !id.isSparkUTF8BinaryCollation()) + .orElse(false); + resolvedPredicate = + hasNonUTF8BinaryCollation ? AlwaysTrue.ALWAYS_TRUE : dataSkippingPredicate; + break; + } + + ExpressionTransformResult result = visitScalarExpression(resolvedPredicate); + return new ExpressionTransformResult( + result.expression, result.outputType, hasNonUTF8BinaryCollation); + } + + @Override + ExpressionTransformResult visitPartitionPredicate(PartitionPredicate partitionPredicate) { + shouldIgnoreNonUTF8BinaryComparisons = true; + return visitScalarExpression(partitionPredicate); + } + private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { checkArgument( @@ -357,23 +495,45 @@ private Predicate validateIsPredicate( return (Predicate) result.expression; } - private Expression transformBinaryComparator(Predicate predicate) { + private void validateChildrenAreDataSkippingPredicates( + DataSkippingPredicate dataSkippingPredicate) { + for (Expression child : dataSkippingPredicate.getChildren()) { + checkArgument( + child instanceof DataSkippingPredicate, + "%s: expected children to be DataSkippingPredicate but got %s", + dataSkippingPredicate, + child); + } + } + + private Tuple2 transformBinaryComparator(Predicate predicate) { ExpressionTransformResult leftResult = visit(getLeft(predicate)); ExpressionTransformResult rightResult = visit(getRight(predicate)); Expression left = leftResult.expression; Expression right = rightResult.expression; if (predicate.getCollationIdentifier().isPresent()) { - CollationIdentifier collationIdentifier = predicate.getCollationIdentifier().get(); - checkIsUTF8BinaryCollation(predicate, collationIdentifier); - for (DataType dataType : Arrays.asList(leftResult.outputType, rightResult.outputType)) { checkIsStringType( dataType, predicate, format("Predicate %s expects STRING type inputs", 
predicate.getName())); } - return new Predicate(predicate.getName(), left, right, collationIdentifier); + + CollationIdentifier collationIdentifier = predicate.getCollationIdentifier().get(); + if (!shouldIgnoreNonUTF8BinaryComparisons) { + checkIsUTF8BinaryCollation(predicate, collationIdentifier); + return new Tuple2( + new Predicate(predicate.getName(), left, right, collationIdentifier), false); + } else { + if (collationIdentifier.isSparkUTF8BinaryCollation()) { + return new Tuple2( + new Predicate(predicate.getName(), left, right, collationIdentifier), false); + } else { + // Ignore non-UTF8_BINARY collation in partition predicates + return new Tuple2(AlwaysTrue.ALWAYS_TRUE, true); + } + } } if (!leftResult.outputType.equivalent(rightResult.outputType)) { @@ -390,7 +550,13 @@ private Expression transformBinaryComparator(Predicate predicate) { throw unsupportedExpressionException(predicate, msg); } } - return new Predicate(predicate.getName(), left, right); + + if (leftResult.hasNonUTF8BinaryCollationIgnored + || rightResult.hasNonUTF8BinaryCollationIgnored) { + return new Tuple2(AlwaysTrue.ALWAYS_TRUE, true); + } else { + return new Tuple2(new Predicate(predicate.getName(), left, right), false); + } } } @@ -726,6 +892,16 @@ ColumnVector visitStartsWith(Predicate startsWith) { startsWith.getChildren().stream().map(this::visit).collect(toList())); } + @Override + ColumnVector visitDataSkippingPredicate(DataSkippingPredicate dataSkippingPredicate) { + throw new UnsupportedOperationException("DataSkippingPredicate expression is not expected."); + } + + @Override + ColumnVector visitPartitionPredicate(PartitionPredicate partitionPredicate) { + throw new UnsupportedOperationException("PartitionPredicate expression is not expected."); + } + /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index e3e784694de..4ef82a3fc1a 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -21,6 +21,8 @@ import static java.util.stream.Collectors.joining; import io.delta.kernel.expressions.*; +import io.delta.kernel.internal.skipping.DataSkippingPredicate; +import io.delta.kernel.internal.skipping.PartitionPredicate; import io.delta.kernel.types.CollationIdentifier; import java.util.List; import java.util.Locale; @@ -72,9 +74,17 @@ abstract class ExpressionVisitor { abstract R visitStartsWith(Predicate predicate); + abstract R visitDataSkippingPredicate(DataSkippingPredicate predicate); + + abstract R visitPartitionPredicate(PartitionPredicate predicate); + final R visit(Expression expression) { if (expression instanceof PartitionValueExpression) { return visitPartitionValue((PartitionValueExpression) expression); + } else if (expression instanceof PartitionPredicate) { + return visitPartitionPredicate((PartitionPredicate) expression); + } else if (expression instanceof DataSkippingPredicate) { + return visitDataSkippingPredicate((DataSkippingPredicate) expression); } else if (expression instanceof ScalarExpression) { return visitScalarExpression((ScalarExpression) expression); } else if (expression instanceof Literal) { @@ -89,7 +99,7 @@ final R visit(Expression expression) { String.format("Expression %s is not supported.", expression)); } - private R visitScalarExpression(ScalarExpression expression) { + R visitScalarExpression(ScalarExpression expression) { List children = expression.getChildren(); String name = expression.getName().toUpperCase(Locale.ENGLISH); 
Optional collationIdentifier = Optional.empty(); diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index e0503f2d958..5846561d702 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -1587,6 +1587,349 @@ class ScanSuite extends AnyFunSuite with TestUtils } } + test("data skipping - basic collated predicate on data column") { + withTempDir { tempDir => + Seq("a", "b").toDF("c1").repartition(1).write.format("delta").save(tempDir.getCanonicalPath) + Seq("n", "o").toDF( + "c1").repartition(1).write.format("delta").mode("append").save(tempDir.getCanonicalPath) + Seq("x", "y").toDF( + "c1").repartition(1).write.format("delta").mode("append").save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val filterToFileNumber = Map( + new Predicate( + ">", + col("c1"), + ofString("m"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 2, + new Predicate( + "<", + col("c1"), + ofString("m"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate("=", ofString("x"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + // Non-UTF8 binary collations should not be used for data skipping; expect all files + new Predicate( + "<", + col("c1"), + ofString("m"), + CollationIdentifier.fromString("SPARK.UTF8_LCASE")) -> totalFiles, + new Predicate( + "<", + ofString("m"), + col("c1"), + CollationIdentifier.fromString("ICU.ENGLISH.74.5")) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + + test("data skipping - complex collated predicate on data column") { + withTempDir { tempDir => + Seq("a", "b").toDF("c1").repartition(1).write.format("delta").save(tempDir.getCanonicalPath) + 
Seq("n", "o").toDF( + "c1").repartition(1).write.format("delta").mode("append").save(tempDir.getCanonicalPath) + Seq("x", "y").toDF( + "c1").repartition(1).write.format("delta").mode("append").save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val UTF8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.5") + + val filterToFileNumber = Map( + new And( + new Predicate(">", col("c1"), ofString("m"), unicode), + new Predicate("<", col("c1"), ofString("z"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("c1"), unicode), + new Predicate("=", ofString("b"), col("c1"), UTF8Lcase)) -> totalFiles, + new And( + new Predicate("=", ofString("x"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("b"), col("c1"), UTF8Lcase)) -> 1, + new And( + new And( + new Predicate(">=", ofString("a"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("b"), col("c1"), UTF8Lcase)), + new And( + new Predicate("<=", ofString("a"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("b"), col("c1"), UTF8Lcase))) -> 1, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + col("c1"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + ofString("a"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + + test("data skipping - complex collated predicate on two data columns") { + withTempDir { tempDir => + Seq(("a", "b"), ("b", "a")).toDF("c1", "c2").repartition(1) + .write.format("delta").save(tempDir.getCanonicalPath) + Seq(("n", "a"), ("o", "b")).toDF("c1", "c2").repartition(1) + 
.write.format("delta").mode("append").save(tempDir.getCanonicalPath) + Seq(("x", "a"), ("y", "b")).toDF("c1", "c2").repartition(1) + .write.format("delta").mode("append").save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val UTF8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.5") + + val filterToFileNumber = Map( + new And( + new Predicate(">", col("c1"), ofString("m"), unicode), + new Predicate("<", col("c1"), ofString("z"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("c2"), unicode), + new Predicate("=", ofString("y"), col("c1"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("c2"), unicode), + new Predicate( + "=", + ofString("x"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles, + new And( + new Predicate("=", ofString("x"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("c2"), UTF8Lcase)) -> 1, + new And( + new And( + new Predicate("=", ofString("x"), col("c1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("c2"), UTF8Lcase)), + new And( + new Predicate("=", ofString("a"), col("c2"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("c2"), UTF8Lcase))) -> 1, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + col("c1"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + ofString("a"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + + test("partition pruning - basic predicate with SPARK.UTF8_BINARY on partition column") { + withTempDir { tempDir => + Seq( + ("a", "x"), + ("b", "y"), + ("c", "z"), + ("d", "x"), + 
("e", "y"), + ("f", "z")).toDF("a", "p") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val filterToFileNumber = Map( + new Predicate( + ">", + col("p"), + ofString("m"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + "<", + ofString("m"), + col("p"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate("=", ofString("x"), col("p"), CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate("=", col("p"), ofString("x"), CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate("<", col("p"), ofString("m"), CollationIdentifier.SPARK_UTF8_BINARY) -> 0) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + + test("partition pruning - basic collated predicate on partition column") { + withTempDir { tempDir => + Seq( + ("a", "x"), + ("b", "y"), + ("c", "z"), + ("d", "x"), + ("e", "y"), + ("f", "z")).toDF("a", "p") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val collations = Seq( + CollationIdentifier.fromString("SPARK.UTF8_LCASE"), + CollationIdentifier.fromString("ICU.ENGLISH.74.5"), + CollationIdentifier.fromString("ICU.FRENCH")) + + val filters = Seq( + new Predicate("<", col("p"), ofString("m")), + new Predicate("<", ofString("m"), col("p")), + new Predicate(">", col("p"), ofString("m")), + new Predicate("=", ofString("x"), col("p")), + // This predicate is evaluated as part of partition pruning + new Predicate("=", ofString("x"), ofString("y"))) + + for { + collation <- collations + filter <- filters + } { + val filterWithCollation = new Predicate(filter.getName, filter.getChildren, collation) + checkSkipping(tempDir.getCanonicalPath, 
Map(filterWithCollation -> totalFiles)) + } + } + } + + test("partition pruning - complex collated predicate on partition column") { + withTempDir { tempDir => + Seq( + ("a", "x"), + ("b", "y"), + ("c", "z"), + ("d", "x"), + ("e", "y"), + ("f", "z")).toDF("a", "p") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val UTF8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.5") + + val filterToFileNumber = Map( + new And( + new Predicate(">", col("p"), ofString("m"), unicode), + new Predicate("<", col("p"), ofString("z"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("p"), unicode), + new Predicate("=", ofString("y"), col("p"), UTF8Lcase)) -> totalFiles, + new And( + new Predicate("=", ofString("x"), col("p"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("y"), col("p"), UTF8Lcase)) -> 1, + new Predicate( + ">", + new Predicate("=", ofString("x"), col("p"), unicode), + AlwaysTrue.ALWAYS_TRUE) -> totalFiles, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + col("p"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + ofString("a"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + + test("partition pruning - complex collated predicate on two partition columns") { + withTempDir { tempDir => + Seq( + ("a", "x", "b"), + ("b", "y", "a"), + ("c", "z", "b"), + ("d", "y", "a"), + ("e", "x", "a"), + ("f", "y", "b"), + ("g", "z", "a"), + ("h", "x", "a")).toDF("a", "p1", "p2") + .write + .format("delta") + .partitionBy("p1", "p2") + .save(tempDir.getCanonicalPath) + + val snapshot = 
latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + val UTF8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.74.5") + + val filterToFileNumber = Map( + new And( + new Predicate(">", col("p1"), ofString("m"), unicode), + new Predicate("<", col("p1"), ofString("z"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("p2"), unicode), + new Predicate("=", ofString("y"), col("p1"), UTF8Lcase)) -> totalFiles, + new Or( + new Predicate("=", ofString("x"), col("p2"), unicode), + new Predicate( + "=", + ofString("x"), + col("p1"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles, + new And( + new Predicate("=", ofString("x"), col("p1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("p2"), UTF8Lcase)) -> 2, + new And( + new Predicate("=", ofString("a"), col("p2"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("p2"), UTF8Lcase)) -> 3, + new And( + new And( + new Predicate("=", ofString("x"), col("p1"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("p2"), UTF8Lcase)), + new And( + new Predicate("=", ofString("a"), col("p2"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("=", ofString("a"), col("p2"), UTF8Lcase))) -> 1, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + col("p1"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles, + new Predicate( + "NOT", + Seq[Expression](new Predicate( + "<", + ofString("a"), + ofString("m"), + UTF8Lcase)).asJava) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + Seq( "spark-variant-checkpoint", "spark-variant-stable-feature-checkpoint", From 3acc168dab2c729beba986f01af296c836ce48ca Mon Sep 17 00:00:00 2001 From: ilicmarkodb Date: Wed, 15 Oct 2025 17:43:20 +0200 Subject: [PATCH 2/2] temp --- 
.../UnsupportedPredicateWithCollation.java | 4 + .../io/delta/kernel/internal/DeltaErrors.java | 1 + .../io/delta/kernel/internal/ScanImpl.java | 26 +- .../DefaultExpressionEvaluator.java | 222 ++---------------- .../expressions/DefaultExpressionUtils.java | 7 +- .../expressions/ExpressionVisitor.java | 12 +- 6 files changed, 48 insertions(+), 224 deletions(-) create mode 100644 kernel/kernel-api/src/main/java/io/delta/kernel/exceptions/UnsupportedPredicateWithCollation.java diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/exceptions/UnsupportedPredicateWithCollation.java b/kernel/kernel-api/src/main/java/io/delta/kernel/exceptions/UnsupportedPredicateWithCollation.java new file mode 100644 index 00000000000..d6cd250e7eb --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/exceptions/UnsupportedPredicateWithCollation.java @@ -0,0 +1,4 @@ +package io.delta.kernel.exceptions; + +public class UnsupportedPredicateWithCollation extends KernelException { +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java index 04330918ce5..1304b95f628 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/DeltaErrors.java @@ -22,6 +22,7 @@ import io.delta.kernel.expressions.Column; import io.delta.kernel.internal.actions.DomainMetadata; import io.delta.kernel.internal.tablefeatures.TableFeature; +import io.delta.kernel.types.CollationIdentifier; import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructType; import io.delta.kernel.types.TypeChange; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index c7acb31eb3f..613b9ebb10d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ 
b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -25,6 +25,8 @@ import io.delta.kernel.Scan; import io.delta.kernel.data.*; import io.delta.kernel.engine.Engine; +import io.delta.kernel.exceptions.KernelEngineException; +import io.delta.kernel.exceptions.UnsupportedPredicateWithCollation; import io.delta.kernel.expressions.*; import io.delta.kernel.internal.actions.Metadata; import io.delta.kernel.internal.actions.Protocol; @@ -43,6 +45,7 @@ import io.delta.kernel.internal.util.*; import io.delta.kernel.metrics.ScanReport; import io.delta.kernel.metrics.SnapshotReport; +import io.delta.kernel.types.CollationIdentifier; import io.delta.kernel.types.MetadataColumnSpec; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; @@ -382,15 +385,20 @@ private CloseableIterator applyDataSkipping( "COALESCE", Arrays.asList(dataSkippingFilter, Literal.ofBoolean(true))), AlwaysTrue.ALWAYS_TRUE); - PredicateEvaluator predicateEvaluator = - wrapEngineException( - () -> - engine - .getExpressionHandler() - .getPredicateEvaluator(prunedStatsSchema, filterToEval), - "Get the predicate evaluator for data skipping with schema=%s and filter=%s", - prunedStatsSchema, - filterToEval); + PredicateEvaluator predicateEvaluator; + try { + predicateEvaluator = + wrapEngineException( + () -> + engine + .getExpressionHandler() + .getPredicateEvaluator(prunedStatsSchema, filterToEval), + "Get the predicate evaluator for data skipping with schema=%s and filter=%s", + prunedStatsSchema, + filterToEval); + } catch (UnsupportedPredicateWithCollation e) { + return scanFileIter; + } return scanFileIter.map( filteredScanFileBatch -> { diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index 1fd4abfefe7..f82729a3663 100644 --- 
a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -30,9 +30,6 @@ import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; import io.delta.kernel.engine.ExpressionHandler; import io.delta.kernel.expressions.*; -import io.delta.kernel.internal.skipping.DataSkippingPredicate; -import io.delta.kernel.internal.skipping.PartitionPredicate; -import io.delta.kernel.internal.util.Tuple2; import io.delta.kernel.types.*; import java.util.*; import java.util.stream.Collectors; @@ -80,19 +77,10 @@ public void close() { private static class ExpressionTransformResult { public final Expression expression; // transformed expression public final DataType outputType; // output type of the expression - public final boolean hasNonUTF8BinaryCollationIgnored; ExpressionTransformResult(Expression expression, DataType outputType) { this.expression = expression; this.outputType = outputType; - this.hasNonUTF8BinaryCollationIgnored = false; - } - - ExpressionTransformResult( - Expression expression, DataType outputType, boolean hasNonUTF8BinaryCollationIgnored) { - this.expression = expression; - this.outputType = outputType; - this.hasNonUTF8BinaryCollationIgnored = hasNonUTF8BinaryCollationIgnored; } } @@ -111,54 +99,23 @@ private static class ExpressionTransformResult { */ private static class ExpressionTransformer extends ExpressionVisitor { private StructType inputDataSchema; - private boolean shouldIgnoreNonUTF8BinaryComparisons; ExpressionTransformer(StructType inputDataSchema) { this.inputDataSchema = requireNonNull(inputDataSchema, "inputDataSchema is null"); - shouldIgnoreNonUTF8BinaryComparisons = false; } @Override ExpressionTransformResult visitAnd(And and) { - ExpressionTransformResult leftResult = visit(and.getLeft()); - ExpressionTransformResult rightResult = 
visit(and.getRight()); - Predicate left = validateIsPredicate(and, leftResult); - Predicate right = validateIsPredicate(and, rightResult); - - if (left != AlwaysTrue.ALWAYS_TRUE && right != AlwaysTrue.ALWAYS_TRUE) { - return new ExpressionTransformResult( - new And(left, right), - BooleanType.BOOLEAN, - leftResult.hasNonUTF8BinaryCollationIgnored - || rightResult.hasNonUTF8BinaryCollationIgnored); - } else if (left == AlwaysTrue.ALWAYS_TRUE && right == AlwaysTrue.ALWAYS_TRUE) { - return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); - } else if (left == AlwaysTrue.ALWAYS_TRUE) { - return new ExpressionTransformResult( - right, BooleanType.BOOLEAN, rightResult.hasNonUTF8BinaryCollationIgnored); - } else { - return new ExpressionTransformResult( - left, BooleanType.BOOLEAN, leftResult.hasNonUTF8BinaryCollationIgnored); - } + Predicate left = validateIsPredicate(and, visit(and.getLeft())); + Predicate right = validateIsPredicate(and, visit(and.getRight())); + return new ExpressionTransformResult(new And(left, right), BooleanType.BOOLEAN); } @Override ExpressionTransformResult visitOr(Or or) { - ExpressionTransformResult leftResult = visit(or.getLeft()); - ExpressionTransformResult rightResult = visit(or.getRight()); - Predicate left = validateIsPredicate(or, leftResult); - Predicate right = validateIsPredicate(or, rightResult); - - boolean hasNonUTF8BinaryCollation = - leftResult.hasNonUTF8BinaryCollationIgnored - || rightResult.hasNonUTF8BinaryCollationIgnored; - if (left == AlwaysTrue.ALWAYS_TRUE || right == AlwaysTrue.ALWAYS_TRUE) { - return new ExpressionTransformResult( - AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, hasNonUTF8BinaryCollation); - } else { - return new ExpressionTransformResult( - new Or(left, right), BooleanType.BOOLEAN, hasNonUTF8BinaryCollation); - } + Predicate left = validateIsPredicate(or, visit(or.getLeft())); + Predicate right = validateIsPredicate(or, visit(or.getRight())); + return new 
ExpressionTransformResult(new Or(left, right), BooleanType.BOOLEAN); } @Override @@ -182,8 +139,8 @@ ExpressionTransformResult visitComparator(Predicate predicate) { case "<": case "<=": case "IS NOT DISTINCT FROM": - Tuple2 result = transformBinaryComparator(predicate); - return new ExpressionTransformResult(result._1, BooleanType.BOOLEAN, result._2); + return new ExpressionTransformResult( + transformBinaryComparator(predicate), BooleanType.BOOLEAN); default: // We should never reach this based on the ExpressionVisitor throw new IllegalStateException( @@ -256,41 +213,23 @@ ExpressionTransformResult visitElementAt(ScalarExpression elementAt) { @Override ExpressionTransformResult visitNot(Predicate predicate) { - ExpressionTransformResult childResult = visit(predicate.getChildren().get(0)); - Predicate child = validateIsPredicate(predicate, childResult); - - if (childResult.hasNonUTF8BinaryCollationIgnored) { - return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); - } else { - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); - } + Predicate child = validateIsPredicate(predicate, visit(predicate.getChildren().get(0))); + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); } @Override ExpressionTransformResult visitIsNotNull(Predicate predicate) { - ExpressionTransformResult childResult = visit(predicate.getChildren().get(0)); - Expression child = childResult.expression; - - if (childResult.hasNonUTF8BinaryCollationIgnored) { - return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); - } else { - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); - } + Expression child = visit(predicate.getChildren().get(0)).expression; + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); } @Override 
ExpressionTransformResult visitIsNull(Predicate predicate) { - ExpressionTransformResult childResult = visit(getUnaryChild(predicate)); - Expression child = childResult.expression; - - if (childResult.hasNonUTF8BinaryCollationIgnored) { - return new ExpressionTransformResult(AlwaysTrue.ALWAYS_TRUE, BooleanType.BOOLEAN, true); - } else { - return new ExpressionTransformResult( - new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); - } + Expression child = visit(getUnaryChild(predicate)).expression; + return new ExpressionTransformResult( + new Predicate(predicate.getName(), child), BooleanType.BOOLEAN); } @Override @@ -407,83 +346,6 @@ ExpressionTransformResult visitStartsWith(Predicate startsWith) { return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN); } - @Override - ExpressionTransformResult visitDataSkippingPredicate( - DataSkippingPredicate dataSkippingPredicate) { - if (dataSkippingPredicate.getReferencedCollations().isEmpty()) { - return visitScalarExpression(dataSkippingPredicate); - } - - Predicate resolvedPredicate = dataSkippingPredicate; - boolean hasNonUTF8BinaryCollation = false; - String name = dataSkippingPredicate.getName().toUpperCase(Locale.ROOT); - switch (name) { - case "AND": - validateChildrenAreDataSkippingPredicates(dataSkippingPredicate); - ExpressionTransformResult leftResult = - visitDataSkippingPredicate((DataSkippingPredicate) getLeft(dataSkippingPredicate)); - ExpressionTransformResult rightResult = - visitDataSkippingPredicate((DataSkippingPredicate) getRight(dataSkippingPredicate)); - if (leftResult.hasNonUTF8BinaryCollationIgnored - && rightResult.hasNonUTF8BinaryCollationIgnored) { - resolvedPredicate = AlwaysTrue.ALWAYS_TRUE; - } else if (!leftResult.hasNonUTF8BinaryCollationIgnored - && !rightResult.hasNonUTF8BinaryCollationIgnored) { - resolvedPredicate = dataSkippingPredicate; - } else if (leftResult.hasNonUTF8BinaryCollationIgnored) { - resolvedPredicate = 
validateIsPredicate(rightResult.expression, rightResult); - } else { - resolvedPredicate = validateIsPredicate(leftResult.expression, leftResult); - } - hasNonUTF8BinaryCollation = - leftResult.hasNonUTF8BinaryCollationIgnored - || rightResult.hasNonUTF8BinaryCollationIgnored; - break; - - case "OR": - validateChildrenAreDataSkippingPredicates(dataSkippingPredicate); - leftResult = - visitDataSkippingPredicate((DataSkippingPredicate) getLeft(dataSkippingPredicate)); - rightResult = - visitDataSkippingPredicate((DataSkippingPredicate) getRight(dataSkippingPredicate)); - - if (leftResult.hasNonUTF8BinaryCollationIgnored - || rightResult.hasNonUTF8BinaryCollationIgnored) { - resolvedPredicate = AlwaysTrue.ALWAYS_TRUE; - hasNonUTF8BinaryCollation = true; - } else { - resolvedPredicate = dataSkippingPredicate; - hasNonUTF8BinaryCollation = false; - } - break; - - case "=": - case "<": - case "<=": - case ">": - case ">=": - case "IS NOT DISTINCT FROM": - hasNonUTF8BinaryCollation = - dataSkippingPredicate - .getCollationIdentifier() - .map(id -> !id.isSparkUTF8BinaryCollation()) - .orElse(false); - resolvedPredicate = - hasNonUTF8BinaryCollation ? 
AlwaysTrue.ALWAYS_TRUE : dataSkippingPredicate; - break; - } - - ExpressionTransformResult result = visitScalarExpression(resolvedPredicate); - return new ExpressionTransformResult( - result.expression, result.outputType, hasNonUTF8BinaryCollation); - } - - @Override - ExpressionTransformResult visitPartitionPredicate(PartitionPredicate partitionPredicate) { - shouldIgnoreNonUTF8BinaryComparisons = true; - return visitScalarExpression(partitionPredicate); - } - private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { checkArgument( @@ -495,45 +357,23 @@ private Predicate validateIsPredicate( return (Predicate) result.expression; } - private void validateChildrenAreDataSkippingPredicates( - DataSkippingPredicate dataSkippingPredicate) { - for (Expression child : dataSkippingPredicate.getChildren()) { - checkArgument( - child instanceof DataSkippingPredicate, - "%s: expected children to be DataSkippingPredicate but got %s", - dataSkippingPredicate, - child); - } - } - - private Tuple2 transformBinaryComparator(Predicate predicate) { + private Expression transformBinaryComparator(Predicate predicate) { ExpressionTransformResult leftResult = visit(getLeft(predicate)); ExpressionTransformResult rightResult = visit(getRight(predicate)); Expression left = leftResult.expression; Expression right = rightResult.expression; if (predicate.getCollationIdentifier().isPresent()) { + CollationIdentifier collationIdentifier = predicate.getCollationIdentifier().get(); + checkIsUTF8BinaryCollation(predicate, collationIdentifier); + for (DataType dataType : Arrays.asList(leftResult.outputType, rightResult.outputType)) { checkIsStringType( dataType, predicate, format("Predicate %s expects STRING type inputs", predicate.getName())); } - - CollationIdentifier collationIdentifier = predicate.getCollationIdentifier().get(); - if (!shouldIgnoreNonUTF8BinaryComparisons) { - checkIsUTF8BinaryCollation(predicate, collationIdentifier); - return new 
Tuple2( - new Predicate(predicate.getName(), left, right, collationIdentifier), false); - } else { - if (collationIdentifier.isSparkUTF8BinaryCollation()) { - return new Tuple2( - new Predicate(predicate.getName(), left, right, collationIdentifier), false); - } else { - // Ignore non-UTF8_BINARY collation in partition predicates - return new Tuple2(AlwaysTrue.ALWAYS_TRUE, true); - } - } + return new Predicate(predicate.getName(), left, right, collationIdentifier); } if (!leftResult.outputType.equivalent(rightResult.outputType)) { @@ -550,13 +390,7 @@ private Tuple2 transformBinaryComparator(Predicate predicat throw unsupportedExpressionException(predicate, msg); } } - - if (leftResult.hasNonUTF8BinaryCollationIgnored - || rightResult.hasNonUTF8BinaryCollationIgnored) { - return new Tuple2(AlwaysTrue.ALWAYS_TRUE, true); - } else { - return new Tuple2(new Predicate(predicate.getName(), left, right), false); - } + return new Predicate(predicate.getName(), left, right); } } @@ -892,16 +726,6 @@ ColumnVector visitStartsWith(Predicate startsWith) { startsWith.getChildren().stream().map(this::visit).collect(toList())); } - @Override - ColumnVector visitDataSkippingPredicate(DataSkippingPredicate dataSkippingPredicate) { - throw new UnsupportedOperationException("DataSkippingPredicate expression is not expected."); - } - - @Override - ColumnVector visitPartitionPredicate(PartitionPredicate partitionPredicate) { - throw new UnsupportedOperationException("PartitionPredicate expression is not expected."); - } - /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index 7558450c61d..11f41df47c5 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -22,6 +22,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.exceptions.UnsupportedPredicateWithCollation; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Literal; import io.delta.kernel.expressions.Predicate; @@ -532,11 +533,7 @@ static Predicate createPredicate( static void checkIsUTF8BinaryCollation( Predicate predicate, CollationIdentifier collationIdentifier) { if (!collationIdentifier.isSparkUTF8BinaryCollation()) { - String msg = - format( - "Unsupported collation: \"%s\". 
Default Engine supports just" + " \"%s\" collation.", - collationIdentifier, CollationIdentifier.SPARK_UTF8_BINARY); - throw unsupportedExpressionException(predicate, msg); + throw new UnsupportedPredicateWithCollation(); } } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index 4ef82a3fc1a..e3e784694de 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -21,8 +21,6 @@ import static java.util.stream.Collectors.joining; import io.delta.kernel.expressions.*; -import io.delta.kernel.internal.skipping.DataSkippingPredicate; -import io.delta.kernel.internal.skipping.PartitionPredicate; import io.delta.kernel.types.CollationIdentifier; import java.util.List; import java.util.Locale; @@ -74,17 +72,9 @@ abstract class ExpressionVisitor { abstract R visitStartsWith(Predicate predicate); - abstract R visitDataSkippingPredicate(DataSkippingPredicate predicate); - - abstract R visitPartitionPredicate(PartitionPredicate predicate); - final R visit(Expression expression) { if (expression instanceof PartitionValueExpression) { return visitPartitionValue((PartitionValueExpression) expression); - } else if (expression instanceof PartitionPredicate) { - return visitPartitionPredicate((PartitionPredicate) expression); - } else if (expression instanceof DataSkippingPredicate) { - return visitDataSkippingPredicate((DataSkippingPredicate) expression); } else if (expression instanceof ScalarExpression) { return visitScalarExpression((ScalarExpression) expression); } else if (expression instanceof Literal) { @@ -99,7 +89,7 @@ final R visit(Expression expression) { String.format("Expression %s is not supported.", expression)); } - R 
visitScalarExpression(ScalarExpression expression) { + private R visitScalarExpression(ScalarExpression expression) { List children = expression.getChildren(); String name = expression.getName().toUpperCase(Locale.ENGLISH); Optional collationIdentifier = Optional.empty();