diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 0251cb28d5e..075f30e3975 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -366,7 +366,7 @@ private CloseableIterator applyDataSkipping( // pruning it after is much simpler StructType prunedStatsSchema = DataSkippingUtils.pruneStatsSchema( - getStatsSchema(metadata.getDataSchema()), dataSkippingFilter.getReferencedCols()); + getStatsSchema(metadata.getDataSchema(), dataSkippingFilter.getReferencedCollations()), dataSkippingFilter.getReferencedCols()); // Skipping happens in two steps: // 1. The predicate produces false for any file whose stats prove we can safely skip it. A diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java index b5af6d60d64..a689810aebb 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java @@ -18,6 +18,8 @@ import io.delta.kernel.expressions.Column; import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.CollationIdentifier; + import java.util.*; /** A {@link Predicate} with a set of columns referenced by the expression. 
*/ @@ -26,6 +28,9 @@ public class DataSkippingPredicate extends Predicate { /** Set of {@link Column}s referenced by the predicate or any of its child expressions */ private final Set referencedCols; + /** Set of {@link CollationIdentifier}s referenced by this predicate or any of its child expressions */ + private final Set collationIdentifiers; + /** * @param name the predicate name * @param children list of expressions that are input to this predicate. @@ -35,6 +40,24 @@ public class DataSkippingPredicate extends Predicate { DataSkippingPredicate(String name, List children, Set referencedCols) { super(name, children); this.referencedCols = Collections.unmodifiableSet(referencedCols); + this.collationIdentifiers = Collections.unmodifiableSet(new HashSet<>()); + } + + /** + * @param name the predicate name + * @param children list of expressions that are input to this predicate. + * @param collationIdentifier collation identifier used for this predicate + * @param referencedCols set of columns referenced by this predicate or any of its child + * expressions + */ + DataSkippingPredicate( + String name, + List children, + CollationIdentifier collationIdentifier, + Set referencedCols) { + super(name, children, collationIdentifier); + this.referencedCols = Collections.unmodifiableSet(referencedCols); + this.collationIdentifiers = Collections.singleton(collationIdentifier); } /** @@ -46,18 +69,32 @@ public class DataSkippingPredicate extends Predicate { * @param right right input to this predicate */ DataSkippingPredicate(String name, DataSkippingPredicate left, DataSkippingPredicate right) { - this( - name, - Arrays.asList(left, right), - new HashSet() { - { - addAll(left.getReferencedCols()); - addAll(right.getReferencedCols()); - } - }); + super(name, Arrays.asList(left, right)); + this.referencedCols = + Collections.unmodifiableSet( + new HashSet() { + { + addAll(left.getReferencedCols()); + addAll(right.getReferencedCols()); + } + }); + this.collationIdentifiers = 
+ Collections.unmodifiableSet( + new HashSet() { + { + addAll(left.getReferencedCollations()); + addAll(right.getReferencedCollations()); + } + }); } + /** @return set of columns referenced by this predicate or any of its child expressions */ public Set getReferencedCols() { return referencedCols; } + + /** @return set of collation identifiers referenced by this predicate or any of its child expressions */ + public Set getReferencedCollations() { + return collationIdentifiers; + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java index e127dc57c5b..509c2464726 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java @@ -27,6 +27,7 @@ import io.delta.kernel.engine.Engine; import io.delta.kernel.expressions.*; import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.types.CollationIdentifier; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; import java.util.*; @@ -259,14 +260,15 @@ private static Optional constructDataSkippingFilter( case "IS NOT DISTINCT FROM": Expression left = getLeft(dataFilters); Expression right = getRight(dataFilters); + Optional collationIdentifier = dataFilters.getCollationIdentifier(); if (left instanceof Column && right instanceof Literal) { Column leftCol = (Column) left; Literal rightLit = (Literal) right; - if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) + if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol, collationIdentifier.isPresent()) && schemaHelper.isSkippingEligibleLiteral(rightLit)) { return constructComparatorDataSkippingFilters( - dataFilters.getName(), leftCol, rightLit, schemaHelper); + dataFilters.getName(), leftCol, rightLit, schemaHelper, collationIdentifier); } } else if (right 
instanceof Column && left instanceof Literal) { return constructDataSkippingFilter(reverseComparatorFilter(dataFilters), schemaHelper); @@ -284,7 +286,7 @@ private static Optional constructDataSkippingFilter( /** Construct the skipping predicate for a given comparator */ private static Optional constructComparatorDataSkippingFilters( - String comparator, Column leftCol, Literal rightLit, StatsSchemaHelper schemaHelper) { + String comparator, Column leftCol, Literal rightLit, StatsSchemaHelper schemaHelper, Optional collationIdentifier) { switch (comparator.toUpperCase(Locale.ROOT)) { @@ -295,35 +297,35 @@ private static Optional constructComparatorDataSkippingFi new DataSkippingPredicate( "AND", constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit), + "<=", schemaHelper.getMinColumn(leftCol, collationIdentifier), rightLit, collationIdentifier), constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit))); + ">=", schemaHelper.getMaxColumn(leftCol, collationIdentifier), rightLit, collationIdentifier))); // Match any file whose min is less than the requested upper bound. case "<": return Optional.of( constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftCol), rightLit)); + "<", schemaHelper.getMinColumn(leftCol, collationIdentifier), rightLit, collationIdentifier)); // Match any file whose min is less than or equal to the requested upper bound case "<=": return Optional.of( constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit)); + "<=", schemaHelper.getMinColumn(leftCol, collationIdentifier), rightLit, collationIdentifier)); // Match any file whose max is larger than the requested lower bound. 
case ">": return Optional.of( constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftCol), rightLit)); + ">", schemaHelper.getMaxColumn(leftCol, collationIdentifier), rightLit, collationIdentifier)); // Match any file whose max is larger than or equal to the requested lower bound. case ">=": return Optional.of( constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit)); + ">=", schemaHelper.getMaxColumn(leftCol, collationIdentifier), rightLit, collationIdentifier)); case "IS NOT DISTINCT FROM": - return constructDataSkippingFilter(rewriteEqualNullSafe(leftCol, rightLit), schemaHelper); + return constructDataSkippingFilter(rewriteEqualNullSafe(leftCol, rightLit, collationIdentifier), schemaHelper); default: throw new IllegalArgumentException( String.format("Unsupported comparator expression %s", comparator)); @@ -336,11 +338,19 @@ private static Optional constructComparatorDataSkippingFi * Literal}. */ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( - String exprName, Tuple2> colExpr, Literal lit) { + String exprName, Tuple2> colExpr, Literal lit, Optional collationIdentifier) { Column column = colExpr._1; Expression adjColExpr = colExpr._2.isPresent() ? 
colExpr._2.get() : column; - return new DataSkippingPredicate( - exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column)); + if (collationIdentifier.isPresent()) { + return new DataSkippingPredicate( + exprName, + Arrays.asList(adjColExpr, lit), + collationIdentifier.get(), + Collections.singleton(column)); + } else { + return new DataSkippingPredicate( + exprName, Arrays.asList(adjColExpr, lit), Collections.singleton(column)); + } } private static final Map REVERSE_COMPARATORS = @@ -356,15 +366,17 @@ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( }; private static Predicate reverseComparatorFilter(Predicate predicate) { - return new Predicate( + return createPredicate( REVERSE_COMPARATORS.get(predicate.getName().toUpperCase(Locale.ROOT)), getRight(predicate), - getLeft(predicate)); + getLeft(predicate), + predicate.getCollationIdentifier()); } /** Construct the skipping predicate for a NOT expression child if possible */ private static Optional constructNotDataSkippingFilters( Predicate childPredicate, StatsSchemaHelper schemaHelper) { + Optional collationIdentifier = childPredicate.getCollationIdentifier(); switch (childPredicate.getName().toUpperCase(Locale.ROOT)) { // Use deMorgan's law to push the NOT past the AND. 
This is safe even with SQL // tri-valued logic (see below), and is desirable because we cannot generally push @@ -423,29 +435,29 @@ private static Optional constructNotDataSkippingFilters( new DataSkippingPredicate( "OR", constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftColumn), rightLiteral), + "<", schemaHelper.getMinColumn(leftColumn, collationIdentifier), rightLiteral, collationIdentifier), constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftColumn), rightLiteral))); + ">", schemaHelper.getMaxColumn(leftColumn, collationIdentifier), rightLiteral, collationIdentifier))); }); case "<": return constructDataSkippingFilter( - new Predicate(">=", childPredicate.getChildren()), schemaHelper); + createPredicate(">=", childPredicate.getChildren(), collationIdentifier), schemaHelper); case "<=": return constructDataSkippingFilter( - new Predicate(">", childPredicate.getChildren()), schemaHelper); + createPredicate(">", childPredicate.getChildren(), collationIdentifier), schemaHelper); case ">": return constructDataSkippingFilter( - new Predicate("<=", childPredicate.getChildren()), schemaHelper); + createPredicate("<=", childPredicate.getChildren(), collationIdentifier), schemaHelper); case ">=": return constructDataSkippingFilter( - new Predicate("<", childPredicate.getChildren()), schemaHelper); + createPredicate("<", childPredicate.getChildren(), collationIdentifier), schemaHelper); case "IS NOT DISTINCT FROM": return constructDataSkippingFiltersForNotEqual( childPredicate, schemaHelper, (leftColumn, rightLiteral) -> constructDataSkippingFilter( - new Predicate("NOT", rewriteEqualNullSafe(leftColumn, rightLiteral)), + new Predicate("NOT", rewriteEqualNullSafe(leftColumn, rightLiteral, collationIdentifier)), schemaHelper)); case "NOT": // Remove redundant pairs of NOT @@ -525,12 +537,12 @@ private static String[] appendArray(String[] arr, String appendElem) { * Rewrite `EqualNullSafe(a, NotNullLiteral)` as 
`And(IsNotNull(a), EqualTo(a, NotNullLiteral))` * and rewrite `EqualNullSafe(a, null)` as `IsNull(a)` */ - private static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit) { + private static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit, Optional collationIdentifier) { if (rightLit.getValue() == null) { return new Predicate("IS_NULL", leftCol); } return new Predicate( - "AND", new Predicate("IS_NOT_NULL", leftCol), new Predicate("=", leftCol, rightLit)); + "AND", new Predicate("IS_NOT_NULL", leftCol), createPredicate("=", leftCol, rightLit, collationIdentifier)); } /** Helper method for building DataSkippingPredicate for NOT =/IS NOT DISTINCT FROM */ @@ -544,15 +556,16 @@ private static Optional constructDataSkippingFiltersForNo "Expects predicate to be = or IS NOT DISTINCT FROM"); Expression leftChild = getLeft(equalPredicate); Expression rightChild = getRight(equalPredicate); + Optional collationIdentifier = equalPredicate.getCollationIdentifier(); if (rightChild instanceof Column && leftChild instanceof Literal) { return constructDataSkippingFilter( - new Predicate("NOT", new Predicate(equalPredicate.getName(), rightChild, leftChild)), + new Predicate("NOT", createPredicate(equalPredicate.getName(), rightChild, leftChild, collationIdentifier)), schemaHelper); } if (leftChild instanceof Column && rightChild instanceof Literal) { Column leftCol = (Column) leftChild; Literal rightLit = (Literal) rightChild; - if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) + if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol, collationIdentifier.isPresent()) && schemaHelper.isSkippingEligibleLiteral(rightLit)) { return buildDataSkippingPredicateFunc.apply(leftCol, rightLit); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java index b97595e4a6b..1c75ca83b7f 100644 --- 
a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java @@ -24,6 +24,7 @@ import io.delta.kernel.expressions.ScalarExpression; import io.delta.kernel.internal.util.Tuple2; import io.delta.kernel.types.*; + import java.util.*; import java.util.stream.Collectors; @@ -49,22 +50,27 @@ public class StatsSchemaHelper { public static final String MAX = "maxValues"; public static final String NULL_COUNT = "nullCount"; public static final String TIGHT_BOUNDS = "tightBounds"; + public static final String STATS_WITH_COLLATION = "statsWithCollation"; /** * Returns true if the given literal is skipping-eligible. Delta tracks min/max stats for a * limited set of data types and only literals of those types are skipping eligible. */ public static boolean isSkippingEligibleLiteral(Literal literal) { - return isSkippingEligibleDataType(literal.getDataType()); + return isSkippingEligibleDataType(literal.getDataType(), false); } /** Returns true if the given data type is eligible for MIN/MAX data skipping. 
*/ - public static boolean isSkippingEligibleDataType(DataType dataType) { - return SKIPPING_ELIGIBLE_TYPE_NAMES.contains(dataType.toString()) - || - // DecimalType is eligible but since its string includes scale + precision it needs to - // be matched separately - dataType instanceof DecimalType; + public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCollatedSkipping) { + if (isCollatedSkipping) { + return dataType instanceof StringType; + } else { + return SKIPPING_ELIGIBLE_TYPE_NAMES.contains(dataType.toString()) + || + // DecimalType is eligible but since its string includes scale + precision it needs to + // be matched separately + dataType instanceof DecimalType; + } } /** @@ -101,10 +107,10 @@ public static boolean isSkippingEligibleDataType(DataType dataType) { * | |-- tightBounds: boolean (nullable = true) * */ - public static StructType getStatsSchema(StructType dataSchema) { + public static StructType getStatsSchema(StructType dataSchema, Set collationIdentifiers) { StructType statsSchema = new StructType().add(NUM_RECORDS, LongType.LONG, true); - StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema); + StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema, false); if (minMaxStatsSchema.length() > 0) { statsSchema = statsSchema.add(MIN, minMaxStatsSchema, true).add(MAX, minMaxStatsSchema, true); } @@ -116,6 +122,11 @@ public static StructType getStatsSchema(StructType dataSchema) { statsSchema = statsSchema.add(TIGHT_BOUNDS, BooleanType.BOOLEAN, true); + StructType collatedMinMaxStatsSchema = getCollatedStatsSchema(dataSchema, collationIdentifiers); + if (collatedMinMaxStatsSchema.length() > 0) { + statsSchema = statsSchema.add(STATS_WITH_COLLATION, collatedMinMaxStatsSchema, true); + } + return statsSchema; } @@ -153,15 +164,18 @@ public StatsSchemaHelper(StructType dataSchema) { * that stores the MIN values for the provided logical column. * * @param column the logical column name. 
+ * @param collationIdentifier optional collation identifier if the min column is from a collated stats column. * @return a tuple of the MIN column and an optional adjustment expression. */ - public Tuple2> getMinColumn(Column column) { + public Tuple2> getMinColumn( + Column column, Optional collationIdentifier) { checkArgument( - isSkippingEligibleMinMaxColumn(column), - "%s is not a valid min column for data schema %s", + isSkippingEligibleMinMaxColumn(column, collationIdentifier.isPresent()), + "%s is not a valid min column%s for data schema %s", column, + collationIdentifier.isPresent() ? (" for collation " + collationIdentifier) : "", dataSchema); - return new Tuple2<>(getStatsColumn(column, MIN), Optional.empty()); + return new Tuple2<>(getStatsColumn(column, MIN, collationIdentifier), Optional.empty()); } /** @@ -170,16 +184,19 @@ public Tuple2> getMinColumn(Column column) { * that stores the MAX values for the provided logical column. * * @param column the logical column name. + * @param collationIdentifier optional collation identifier if getting a collated stats column. * @return a tuple of the MAX column and an optional adjustment expression. */ - public Tuple2> getMaxColumn(Column column) { + public Tuple2> getMaxColumn( + Column column, Optional collationIdentifier) { checkArgument( - isSkippingEligibleMinMaxColumn(column), - "%s is not a valid min column for data schema %s", + isSkippingEligibleMinMaxColumn(column, collationIdentifier.isPresent()), + "%s is not a valid max column%s for data schema %s", column, + collationIdentifier.isPresent() ?
(" for collation " + collationIdentifier) : "", dataSchema); DataType dataType = logicalToDataType.get(column); - Column maxColumn = getStatsColumn(column, MAX); + Column maxColumn = getStatsColumn(column, MAX, collationIdentifier); // If this is a column of type Timestamp or TimestampNTZ // compensate for the truncation from microseconds to milliseconds @@ -206,7 +223,7 @@ public Column getNullCountColumn(Column column) { "%s is not a valid null_count column for data schema %s", column, dataSchema); - return getStatsColumn(column, NULL_COUNT); + return getStatsColumn(column, NULL_COUNT, Optional.empty()); } /** Returns the NUM_RECORDS column in the statistic schema */ @@ -218,9 +235,9 @@ public Column getNumRecordsColumn() { * Returns true if the given column is skipping-eligible using min/max statistics. This means the * column exists, is a leaf column, and is of a skipping-eligible data-type. */ - public boolean isSkippingEligibleMinMaxColumn(Column column) { + public boolean isSkippingEligibleMinMaxColumn(Column column, boolean isCollatedSkipping) { return logicalToDataType.containsKey(column) - && isSkippingEligibleDataType(logicalToDataType.get(column)); + && isSkippingEligibleDataType(logicalToDataType.get(column), isCollatedSkipping); } /** @@ -256,22 +273,45 @@ public boolean isSkippingEligibleNullCountColumn(Column column) { * 1) replace logical names with physical names 2) set nullable=true 3) only keep stats eligible * fields (i.e. 
don't include fields with isSkippingEligibleDataType=false) */ - private static StructType getMinMaxStatsSchema(StructType dataSchema) { + private static StructType getMinMaxStatsSchema(StructType dataSchema, boolean isCollatedSkipping) { List fields = new ArrayList<>(); for (StructField field : dataSchema.fields()) { - if (isSkippingEligibleDataType(field.getDataType())) { + if (isSkippingEligibleDataType(field.getDataType(), isCollatedSkipping)) { fields.add(new StructField(getPhysicalName(field), field.getDataType(), true)); } else if (field.getDataType() instanceof StructType) { fields.add( new StructField( getPhysicalName(field), - getMinMaxStatsSchema((StructType) field.getDataType()), + getMinMaxStatsSchema((StructType) field.getDataType(), isCollatedSkipping), true)); } } return new StructType(fields); } + /** + * Given a data schema returns the expected schema for a STATS_WITH_COLLATION statistics column. + * This means 1) replace logical names with physical names 2) set nullable=true 3) only keep + * string columns + */ + private static StructType getCollatedStatsSchema( + StructType dataSchema, Set collationIdentifiers) { + StructType statsWithCollation = new StructType(); + StructType collatedMinMaxStatsSchema = getMinMaxStatsSchema(dataSchema, true); + for (CollationIdentifier collationIdentifier : collationIdentifiers) { + if (collatedMinMaxStatsSchema.length() > 0) { + statsWithCollation = + statsWithCollation.add( + collationIdentifier.toString(), + new StructType() + .add(MIN, collatedMinMaxStatsSchema, true) + .add(MAX, collatedMinMaxStatsSchema, true), + true); + } + } + return statsWithCollation; + } + /** * Given a data schema returns the expected schema for a null_count statistics column. 
This means * 1) replace logical names with physical names 2) set nullable=true 3) use LongType for all @@ -301,13 +341,23 @@ private static StructType getNullCountSchema(StructType dataSchema) { * Given a logical column and a stats type returns the corresponding column in the statistics * schema */ - private Column getStatsColumn(Column column, String statType) { + private Column getStatsColumn( + Column column, String statType, Optional collationIdentifier) { checkArgument( logicalToPhysicalColumn.containsKey(column), "%s is not a valid leaf column for data schema: %s", column, dataSchema); - return getChildColumn(logicalToPhysicalColumn.get(column), statType); + Column physicalColumn = logicalToPhysicalColumn.get(column); + // Use binary stats if collation is not specified or if it is the default Spark collation. + if (collationIdentifier.isPresent() + && !collationIdentifier.get().equals(CollationIdentifier.SPARK_UTF8_BINARY)) { + return getChildColumn( + physicalColumn, + Arrays.asList(statType, collationIdentifier.get().toString(), STATS_WITH_COLLATION)); + } else { + return getChildColumn(physicalColumn, statType); + } } /** @@ -344,6 +394,14 @@ private static Column getChildColumn(Column column, String parentName) { return new Column(prependArray(column.getNames(), parentName)); } + /** Returns the provided column as a child column nested under {@code ancestorPath} */ + private static Column getChildColumn(Column column, List ancestorPath) { + for (String name : ancestorPath) { + column = getChildColumn(column, name); + } + return column; + } + /** * Given an array {@code names} and a string element {@code preElem} return a new array with * {@code preElem} inserted at the beginning diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java index 2e82fffea4c..f94b9b2024d 100644 ---
a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/ExpressionUtils.java @@ -19,7 +19,10 @@ import io.delta.kernel.expressions.Expression; import io.delta.kernel.expressions.Predicate; +import io.delta.kernel.types.CollationIdentifier; +import java.util.Arrays; import java.util.List; +import java.util.Optional; public class ExpressionUtils { /** Return an expression cast as a predicate, throw an error if it is not a predicate */ @@ -51,4 +54,23 @@ public static Expression getUnaryChild(Expression expression) { children.size() == 1, "%s: expected one inputs, but got %s", expression, children.size()); return children.get(0); } + + /** Utility method to create a predicate with an optional collation identifier */ + public static Predicate createPredicate( + String name, List children, Optional collationIdentifier) { + if (collationIdentifier.isPresent()) { + return new Predicate(name, children, collationIdentifier.get()); + } else { + return new Predicate(name, children); + } + } + + /** Utility method to create a binary predicate with an optional collation identifier */ + public static Predicate createPredicate( + String name, + Expression left, + Expression right, + Optional collationIdentifier) { + return createPredicate(name, Arrays.asList(left, right), collationIdentifier); + } } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java index 607bdbfdafa..46c0114c5f5 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/PartitionUtils.java @@ -18,6 +18,7 @@ import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; import static 
io.delta.kernel.internal.DeltaErrors.wrapEngineException; +import static io.delta.kernel.internal.util.ExpressionUtils.createPredicate; import static io.delta.kernel.internal.util.InternalUtils.toLowerCaseSet; import static io.delta.kernel.internal.util.Preconditions.checkArgument; import static io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames; @@ -270,11 +271,12 @@ public static Tuple2 splitMetadataAndDataPredicates( */ public static Predicate rewritePartitionPredicateOnCheckpointFileSchema( Predicate predicate, Map partitionColNameToField) { - return new Predicate( + return createPredicate( predicate.getName(), predicate.getChildren().stream() .map(child -> rewriteColRefOnPartitionValuesParsed(child, partitionColNameToField)) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + predicate.getCollationIdentifier()); } private static Expression rewriteColRefOnPartitionValuesParsed( @@ -319,11 +321,12 @@ private static Expression rewriteColRefOnPartitionValuesParsed( */ public static Predicate rewritePartitionPredicateOnScanFileSchema( Predicate predicate, Map partitionColMetadata) { - return new Predicate( + return createPredicate( predicate.getName(), predicate.getChildren().stream() .map(child -> rewritePartitionColumnRef(child, partitionColMetadata)) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + predicate.getCollationIdentifier()); } private static Expression rewritePartitionColumnRef( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java index 66fc5665b83..d3b47e3a041 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/SchemaUtils.java @@ -281,7 +281,7 @@ public static List casePreservingEligibleClusterColumns( List nonSkippingEligibleColumns = 
physicalColumnsWithTypes.stream() - .filter(tuple -> !StatsSchemaHelper.isSkippingEligibleDataType(tuple._2)) + .filter(tuple -> !StatsSchemaHelper.isSkippingEligibleDataType(tuple._2, false)) .map(tuple -> tuple._1.toString() + " : " + tuple._2) .collect(Collectors.toList()); diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala new file mode 100644 index 00000000000..737bfdc2090 --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala @@ -0,0 +1,513 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.delta.kernel.internal.skipping + +import io.delta.kernel.expressions.{Column, Expression, Predicate} +import io.delta.kernel.internal.skipping.DataSkippingUtils.constructDataSkippingFilter +import io.delta.kernel.internal.skipping.StatsSchemaHelper.{MAX, MIN, STATS_WITH_COLLATION} +import io.delta.kernel.internal.util.ExpressionUtils.createPredicate +import io.delta.kernel.test.TestUtils +import io.delta.kernel.types.IntegerType.INTEGER +import io.delta.kernel.types._ +import org.scalatest.funsuite.AnyFunSuite + +import java.util.Optional +import scala.collection.JavaConverters._ + +class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils { + + def dataSkippingPredicate( + operator: String, + children: Seq[Expression], + referencedColumns: Set[Column]): DataSkippingPredicate = { + new DataSkippingPredicate(operator, children.asJava, referencedColumns.asJava) + } + + def dataSkippingPredicate( + operator: String, + left: DataSkippingPredicate, + right: DataSkippingPredicate): DataSkippingPredicate = { + new DataSkippingPredicate(operator, left, right) + } + + def dataSkippingPredicateWithCollation( + operator: String, + children: Seq[Expression], + collation: CollationIdentifier, + referencedColumns: Set[Column]): DataSkippingPredicate = { + new DataSkippingPredicate(operator, children.asJava, collation, referencedColumns.asJava) + } + + private def collatedStatsCol( + collation: CollationIdentifier, + statName: String, + fieldName: String): Column = { + new Column(Array(STATS_WITH_COLLATION, collation.toString, statName, fieldName)) + } + + /* For struct type checks for equality based on field names & data type only */ + def compareDataTypeUnordered(type1: DataType, type2: DataType): Boolean = (type1, type2) match { + case (schema1: StructType, schema2: StructType) => + val fields1 = schema1.fields().asScala.sortBy(_.getName) + val fields2 = schema2.fields().asScala.sortBy(_.getName) + if (fields1.length != fields2.length) { + false + } else 
{ + fields1.zip(fields2).forall { case (field1: StructField, field2: StructField) => + field1.getName == field2.getName && + compareDataTypeUnordered(field1.getDataType, field2.getDataType) + } + } + case _ => + type1 == type2 + } + + def checkPruneStatsSchema( + inputSchema: StructType, + referencedCols: Set[Column], + expectedSchema: StructType): Unit = { + val prunedSchema = DataSkippingUtils.pruneStatsSchema(inputSchema, referencedCols.asJava) + assert( + compareDataTypeUnordered(expectedSchema, prunedSchema), + s"expected=$expectedSchema\nfound=$prunedSchema") + } + + test("pruneStatsSchema - multiple basic cases one level of nesting") { + val nestedField = new StructField( + "nested", + new StructType() + .add("col1", INTEGER) + .add("col2", INTEGER), + true) + val testSchema = new StructType() + .add(nestedField) + .add("top_level_col", INTEGER) + // no columns pruned + checkPruneStatsSchema( + testSchema, + Set(col("top_level_col"), nestedCol("nested.col1"), nestedCol("nested.col2")), + testSchema) + // top level column pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("nested.col1"), nestedCol("nested.col2")), + new StructType().add(nestedField)) + // nested column only one field pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("top_level_col"), nestedCol("nested.col1")), + new StructType() + .add("nested", new StructType().add("col1", INTEGER)) + .add("top_level_col", INTEGER)) + // nested column completely pruned + checkPruneStatsSchema( + testSchema, + Set(nestedCol("top_level_col")), + new StructType().add("top_level_col", INTEGER)) + // prune all columns + checkPruneStatsSchema( + testSchema, + Set(), + new StructType()) + } + + test("pruneStatsSchema - 3 levels of nesting") { + /* + |--level1: struct + | |--level2: struct + | |--level3: struct + | |--level_4_col: int + | |--level_3_col: int + | |--level_2_col: int + */ + val testSchema = new StructType() + .add( + "level1", + new StructType() + .add( + "level2", + new 
StructType() + .add( + "level3", + new StructType().add("level_4_col", INTEGER)) + .add("level_3_col", INTEGER)) + .add("level_2_col", INTEGER)) + // prune only 4th level col + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level2.level_3_col"), nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add("level2", new StructType().add("level_3_col", INTEGER)) + .add("level_2_col", INTEGER))) + // prune only 3rd level column + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level2.level3.level_4_col"), nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add( + "level2", + new StructType() + .add( + "level3", + new StructType().add("level_4_col", INTEGER))) + .add("level_2_col", INTEGER))) + // prune 4th and 3rd level column + checkPruneStatsSchema( + testSchema, + Set(nestedCol("level1.level_2_col")), + new StructType() + .add( + "level1", + new StructType() + .add("level_2_col", INTEGER))) + // prune all columns + checkPruneStatsSchema( + testSchema, + Set(), + new StructType()) + } + + // TODO: add tests for remaining operators + test("check constructDataSkippingFilter") { + val testCases = Seq( + // (schema, predicate, expectedDataSkippingPredicateOpt) + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + None), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate("<", col("a"), literal("x"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + "<", + Seq(nestedCol(s"$MIN.a"), literal("x")), + Set(nestedCol(s"$MIN.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate("<", literal("x"), col("a"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + ">", + Seq(nestedCol(s"$MAX.a"), literal("x")), + 
Set(nestedCol(s"$MAX.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", StringType.STRING), + createPredicate(">", col("a"), literal("x"), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + ">", + Seq(nestedCol(s"$MAX.a"), literal("x")), + Set(nestedCol(s"$MAX.a"))))), + ( + new StructType() + .add("a", IntegerType.INTEGER), + createPredicate("=", col("a"), literal(10), Optional.empty[CollationIdentifier]), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a"))), + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(10)), + Set(nestedCol(s"$MAX.a")))))), + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate("<", col("a"), literal(10), Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(10)), + Set(nestedCol(s"$MAX.a"))))), + // NOT over AND: NOT(a < 5 AND a > 10) => (max.a >= 5) OR (min.a <= 10) + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("a"), literal(10), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "OR", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a")))))), + // NOT over OR: NOT(a < 5 OR a > 10) => (max.a >= 5) AND (min.a <= 10) + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "OR", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("a"), literal(10), Optional.empty[CollationIdentifier]), + 
Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.a"), literal(10)), + Set(nestedCol(s"$MIN.a")))))), + // NOT over OR with one ineligible leg: NOT(a < b OR a < 5) => NOT(a < b) AND NOT(a < 5) + // The first leg is ineligible; AND with single leg should return that leg only + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "OR", + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))))), + // NOT over AND with one ineligible leg: NOT(a < 5 AND a < b) + // => NOT(a < 5) OR NOT(a < b); since OR needs both legs, expect None + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate("<", col("a"), col("b"), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + None), + // Double NOT elimination: NOT(NOT(a < 5)) => a < 5 => min.a < 5 + ( + new StructType() + .add("a", IntegerType.INTEGER), + new Predicate( + "NOT", + new Predicate( + "NOT", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]))), + Some(dataSkippingPredicate( + "<", + Seq(nestedCol(s"$MIN.a"), literal(5)), + Set(nestedCol(s"$MIN.a"))))), + // Cross-column case: NOT(a < 5 OR b > 7) => (max.a >= 5) AND (min.b <= 7) + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", IntegerType.INTEGER), + new Predicate( + "NOT", + createPredicate( 
+ "OR", + createPredicate("<", col("a"), literal(5), Optional.empty[CollationIdentifier]), + createPredicate(">", col("b"), literal(7), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier])), + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicate( + ">=", + Seq(nestedCol(s"$MAX.a"), literal(5)), + Set(nestedCol(s"$MAX.a"))), + dataSkippingPredicate( + "<=", + Seq(nestedCol(s"$MIN.b"), literal(7)), + Set(nestedCol(s"$MIN.b"))))))) + + testCases.foreach { case (schema, predicate, expectedDataSkippingPredicateOpt) => + val dataSkippingPredicateOpt = optionalToScala(constructDataSkippingFilter(predicate, schema)) + (dataSkippingPredicateOpt, expectedDataSkippingPredicateOpt) match { + case (Some(dataSkippingPredicate), Some(expectedDataSkippingPredicate)) => + assert(dataSkippingPredicate == expectedDataSkippingPredicate) + case (None, None) => // pass + case _ => + fail(s"Expected $expectedDataSkippingPredicateOpt, found $dataSkippingPredicateOpt") + } + } + } + + test("check constructDataSkippingFilter with collations") { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + + val testCases = Seq( + // (schema, predicate, expectedDataSkippingPredicateOpt) + // Ineligible: both sides are columns + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), col("b"), Optional.of(utf8Lcase)), + None), + // Eligible: a < "m" with collation -> min(a, collation) < "m" + ( + new StructType() + .add("a", StringType.STRING) + .add("b", StringType.STRING), + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), { + val minA = collatedStatsCol(utf8Lcase, MIN, "a") + Some(dataSkippingPredicateWithCollation( + "<", + Seq(minA, literal("m")), + utf8Lcase, + Set(minA))) + }), + // Reversed comparator: "m" < a -> max(a, collation) > "m" + ( + new StructType() + .add("a", StringType.STRING), + 
createPredicate("<", literal("m"), col("a"), Optional.of(utf8Lcase)), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // Direct ">": a > "m" -> max(a, collation) > "m" + ( + new StructType() + .add("a", StringType.STRING), + createPredicate(">", col("a"), literal("m"), Optional.of(utf8Lcase)), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // Equality + ( + new StructType() + .add("a", StringType.STRING), + createPredicate("=", col("a"), literal("abc"), Optional.of(unicode)), { + val minA = collatedStatsCol(unicode, MIN, "a") + val maxA = collatedStatsCol(unicode, MAX, "a") + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicateWithCollation("<=", Seq(minA, literal("abc")), unicode, Set(minA)), + dataSkippingPredicateWithCollation( + ">=", + Seq(maxA, literal("abc")), + unicode, + Set(maxA)))) + }), + // NOT over comparator: NOT(a < "m") -> max(a, collation) >= "m" + ( + new StructType() + .add("a", StringType.STRING), + new Predicate( + "NOT", + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase))), { + val maxA = collatedStatsCol(utf8Lcase, MAX, "a") + Some(dataSkippingPredicateWithCollation( + ">=", + Seq(maxA, literal("m")), + utf8Lcase, + Set(maxA))) + }), + // NOT over AND + // NOT(a < "m" AND a > "t") => (max.a >= "m") OR (min.a <= "t") + ( + new StructType() + .add("a", StringType.STRING), + new Predicate( + "NOT", + createPredicate( + "AND", + createPredicate("<", col("a"), literal("m"), Optional.of(unicode)), + createPredicate(">", col("a"), literal("t"), Optional.of(utf8Lcase)), + Optional.empty[CollationIdentifier])), { + val unicodeMaxA = collatedStatsCol(unicode, MAX, "a") + val utf8LcaseMinA = collatedStatsCol(utf8Lcase, MIN, "a") + Some(dataSkippingPredicate( + "OR", + 
dataSkippingPredicateWithCollation( + ">=", + Seq(unicodeMaxA, literal("m")), + unicode, + Set(unicodeMaxA)), + dataSkippingPredicateWithCollation( + "<=", + Seq(utf8LcaseMinA, literal("t")), + utf8Lcase, + Set(utf8LcaseMinA)))) + }), + // AND(a < "m" COLLATE UTF8_LCASE, b < 1) + ( + new StructType() + .add("a", StringType.STRING) + .add("b", IntegerType.INTEGER), + createPredicate( + "AND", + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), + createPredicate("<", col("b"), literal(1), Optional.empty[CollationIdentifier]), + Optional.empty[CollationIdentifier]), { + val minA = collatedStatsCol(utf8Lcase, MIN, "a") + val minB = nestedCol(s"$MIN.b") + Some(dataSkippingPredicate( + "AND", + dataSkippingPredicateWithCollation("<", Seq(minA, literal("m")), utf8Lcase, Set(minA)), + dataSkippingPredicate("<", Seq(minB, literal(1)), Set(minB)))) + }), + // Ineligible: non-string column with collation + ( + new StructType() + .add("a", IntegerType.INTEGER), + createPredicate("<", col("a"), literal("m"), Optional.of(utf8Lcase)), + None)) + + testCases.foreach { case (schema, predicate, expectedDataSkippingPredicateOpt) => + val dataSkippingPredicateOpt = optionalToScala(constructDataSkippingFilter(predicate, schema)) + (dataSkippingPredicateOpt, expectedDataSkippingPredicateOpt) match { + case (Some(dataSkippingPredicate), Some(expectedDataSkippingPredicate)) => + assert(dataSkippingPredicate == expectedDataSkippingPredicate) + case (None, None) => // pass + case _ => + fail(s"Expected $expectedDataSkippingPredicateOpt, found $dataSkippingPredicateOpt") + } + } + } +} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala new file mode 100644 index 00000000000..49fc0976fba --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala @@ -0,0 +1,432 @@ 
+/* + * Copyright (2025) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.internal.skipping + +import scala.collection.JavaConverters.setAsJavaSetConverter + +import io.delta.kernel.types.{ArrayType, BinaryType, BooleanType, ByteType, CollationIdentifier, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} + +import org.scalatest.funsuite.AnyFunSuite + +class StatsSchemaHelperSuite extends AnyFunSuite { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + + test("check getStatsSchema for supported data types") { + val testCases = Seq( + ( + new StructType().add("a", IntegerType.INTEGER), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("a", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("b", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("b", StringType.STRING, true), true) + 
.add(StatsSchemaHelper.MAX, new StructType().add("b", StringType.STRING, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("b", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("c", ByteType.BYTE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("c", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("d", ShortType.SHORT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("d", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("e", LongType.LONG), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("f", FloatType.FLOAT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("f", LongType.LONG, true), 
true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("g", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("g", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("h", DateType.DATE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("h", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("i", TimestampType.TIMESTAMP), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("i", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("j", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, 
true)), + ( + new StructType().add("k", new DecimalType(20, 5)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("k", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = StatsSchemaHelper.getStatsSchema( + dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with mix of supported and unsupported data types") { + val testCases = Seq( + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", BinaryType.BINARY) + .add("c", new ArrayType(LongType.LONG, true)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType() + .add( + "s", + new StructType() + .add("s1", StringType.STRING) + .add("s2", BooleanType.BOOLEAN)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + 
StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("s1", LongType.LONG, true) + .add("s2", LongType.LONG, true), + true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Un-nested array/map alongside a supported type + ( + new StructType() + .add("arr", new ArrayType(IntegerType.INTEGER, true)) + .add("mp", new MapType(StringType.STRING, LongType.LONG, true)) + .add("z", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("z", DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("z", DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true) + .add("z", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Nested array/map inside a struct; empty struct preserved in min/max + ( + new StructType() + .add( + "s", + new StructType() + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(IntegerType.INTEGER, StringType.STRING, true))) + .add("k", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add("k", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = 
StatsSchemaHelper.getStatsSchema( + dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with collations - un-nested mix") { + val dataSchema = new StructType() + .add("a", StringType.STRING) + .add("b", IntegerType.INTEGER) + .add("c", BinaryType.BINARY) + + val collations = Set(utf8Lcase) + + val expectedCollatedMinMax = new StructType().add("a", StringType.STRING, true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedMinMax, true) + .add(StatsSchemaHelper.MAX, expectedCollatedMinMax, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with collations - nested mix and multiple collations") { + val dataSchema = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING) + .add("y", IntegerType.INTEGER) + .add("z", new StructType().add("p", StringType.STRING).add("q", DoubleType.DOUBLE))) + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(StringType.STRING, StringType.STRING, true)) + + val collations = Set(utf8Lcase, 
CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedCollatedNested = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("z", new StructType().add("p", StringType.STRING, true), true), + true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("x", LongType.LONG, true) + .add("y", LongType.LONG, true) + .add( + "z", + new StructType() + .add("p", LongType.LONG, true) + .add("q", LongType.LONG, true), + true), + true) + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true) + .add( + CollationIdentifier.SPARK_UTF8_BINARY.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with 
collations - no eligible string columns") { + val dataSchema = new StructType() + .add("a", IntegerType.INTEGER) + .add("b", new ArrayType(StringType.STRING, true)) + .add("c", new MapType(StringType.STRING, LongType.LONG, true)) + + val collations = Set(utf8Lcase, unicode, CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } +} \ No newline at end of file diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala deleted file mode 100644 index 27d2895f7f2..00000000000 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/DataSkippingUtilsSuite.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (2023) The Delta Lake Project Authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.delta.kernel.internal.util - -import scala.collection.JavaConverters._ - -import io.delta.kernel.expressions.Column -import io.delta.kernel.internal.skipping.DataSkippingUtils -import io.delta.kernel.types.{DataType, StructField, StructType} -import io.delta.kernel.types.IntegerType.INTEGER - -import org.scalatest.funsuite.AnyFunSuite - -class DataSkippingUtilsSuite extends AnyFunSuite { - - def col(name: String): Column = new Column(name) - - def nestedCol(name: String): Column = { - new Column(name.split("\\.")) - } - - /* For struct type checks for equality based on field names & data type only */ - def compareDataTypeUnordered(type1: DataType, type2: DataType): Boolean = (type1, type2) match { - case (schema1: StructType, schema2: StructType) => - val fields1 = schema1.fields().asScala.sortBy(_.getName) - val fields2 = schema2.fields().asScala.sortBy(_.getName) - if (fields1.length != fields2.length) { - false - } else { - fields1.zip(fields2).forall { case (field1: StructField, field2: StructField) => - field1.getName == field2.getName && - compareDataTypeUnordered(field1.getDataType, field2.getDataType) - } - } - case _ => - type1 == type2 - } - - def checkPruneStatsSchema( - inputSchema: StructType, - referencedCols: Set[Column], - expectedSchema: StructType): Unit = { - val prunedSchema = DataSkippingUtils.pruneStatsSchema(inputSchema, referencedCols.asJava) - assert( - compareDataTypeUnordered(expectedSchema, prunedSchema), - s"expected=$expectedSchema\nfound=$prunedSchema") - } - - test("pruneStatsSchema - multiple basic cases one level of nesting") { - val nestedField = new StructField( - "nested", - new StructType() - .add("col1", INTEGER) - .add("col2", INTEGER), - true) - val testSchema = new StructType() - .add(nestedField) - .add("top_level_col", INTEGER) - // no columns pruned - checkPruneStatsSchema( - testSchema, - 
Set(col("top_level_col"), nestedCol("nested.col1"), nestedCol("nested.col2")), - testSchema) - // top level column pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("nested.col1"), nestedCol("nested.col2")), - new StructType().add(nestedField)) - // nested column only one field pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("top_level_col"), nestedCol("nested.col1")), - new StructType() - .add("nested", new StructType().add("col1", INTEGER)) - .add("top_level_col", INTEGER)) - // nested column completely pruned - checkPruneStatsSchema( - testSchema, - Set(nestedCol("top_level_col")), - new StructType().add("top_level_col", INTEGER)) - // prune all columns - checkPruneStatsSchema( - testSchema, - Set(), - new StructType()) - } - - test("pruneStatsSchema - 3 levels of nesting") { - /* - |--level1: struct - | |--level2: struct - | |--level3: struct - | |--level_4_col: int - | |--level_3_col: int - | |--level_2_col: int - */ - val testSchema = new StructType() - .add( - "level1", - new StructType() - .add( - "level2", - new StructType() - .add( - "level3", - new StructType().add("level_4_col", INTEGER)) - .add("level_3_col", INTEGER)) - .add("level_2_col", INTEGER)) - // prune only 4th level col - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level2.level_3_col"), nestedCol("level1.level_2_col")), - new StructType() - .add( - "level1", - new StructType() - .add("level2", new StructType().add("level_3_col", INTEGER)) - .add("level_2_col", INTEGER))) - // prune only 3rd level column - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level2.level3.level_4_col"), nestedCol("level1.level_2_col")), - new StructType() - .add( - "level1", - new StructType() - .add( - "level2", - new StructType() - .add( - "level3", - new StructType().add("level_4_col", INTEGER))) - .add("level_2_col", INTEGER))) - // prune 4th and 3rd level column - checkPruneStatsSchema( - testSchema, - Set(nestedCol("level1.level_2_col")), - new 
StructType() - .add( - "level1", - new StructType() - .add("level_2_col", INTEGER))) - // prune all columns - checkPruneStatsSchema( - testSchema, - Set(), - new StructType()) - } -} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala index 430d0bdd178..e106d5a8c8b 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/PartitionUtilsSuite.scala @@ -27,6 +27,9 @@ import io.delta.kernel.types._ import org.scalatest.funsuite.AnyFunSuite class PartitionUtilsSuite extends AnyFunSuite { + private val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + private val unicode = CollationIdentifier.fromString("ICU.UNICODE") + // Table schema // Data columns: data1: int, data2: string, date3: struct(data31: boolean, data32: long) // Partition columns: part1: int, part2: date, part3: string @@ -58,27 +61,76 @@ class PartitionUtilsSuite extends AnyFunSuite { // single predicate on a data column predicate("=", col("data1"), ofInt(12)) -> ("ALWAYS_TRUE()", "(column(`data1`) = 12)"), + // single predicate with default collation on a data column + predicate("=", col("data2"), ofString("12"), CollationIdentifier.SPARK_UTF8_BINARY) -> + ("ALWAYS_TRUE()", "(column(`data2`) = 12 COLLATE SPARK.UTF8_BINARY)"), + // single predicate with non-default collation on a data column + predicate("=", col("data2"), ofString("12"), utf8Lcase) -> + ("ALWAYS_TRUE()", "(column(`data2`) = 12 COLLATE SPARK.UTF8_LCASE)"), + predicate("=", col("data2"), ofString("12"), unicode) -> + ("ALWAYS_TRUE()", "(column(`data2`) = 12 COLLATE ICU.UNICODE)"), // multiple predicates on data columns joined with AND predicate( "AND", predicate("=", col("data1"), ofInt(12)), predicate(">=", col("data2"), ofString("sss"))) -> ("ALWAYS_TRUE()", 
"((column(`data1`) = 12) AND (column(`data2`) >= sss))"), + // multiple predicates with collation on data columns joined with AND + predicate( + "AND", + predicate("=", col("data2"), ofString("12")), + predicate(">=", col("data2"), ofString("sss"), utf8Lcase)) -> + ( + "ALWAYS_TRUE()", + "((column(`data2`) = 12) AND (column(`data2`) >= sss COLLATE SPARK.UTF8_LCASE))"), + // multiple predicates with collation on data columns joined with AND + predicate( + "AND", + predicate("=", col("data2"), ofString("12"), utf8Lcase), + predicate(">=", col("data2"), ofString("sss"), unicode)) -> + ( + "ALWAYS_TRUE()", + """((column(`data2`) = 12 COLLATE SPARK.UTF8_LCASE) AND + |(column(`data2`) >= sss COLLATE ICU.UNICODE))""".stripMargin.replaceAll("\n", " ")), // multiple predicates on data columns joined with OR predicate( "OR", predicate("<=", col("data2"), ofString("sss")), predicate("=", col("data3", "data31"), ofBoolean(true))) -> ("ALWAYS_TRUE()", "((column(`data2`) <= sss) OR (column(`data3`.`data31`) = true))"), + predicate( + "OR", + predicate("<=", col("data2"), ofString("sss"), utf8Lcase), + predicate("=", col("data3", "data31"), ofBoolean(true))) -> + ( + "ALWAYS_TRUE()", + "((column(`data2`) <= sss COLLATE SPARK.UTF8_LCASE) OR (column(`data3`.`data31`) = true))"), // single predicate on a partition column predicate("=", col("part1"), ofInt(12)) -> ("(column(`part1`) = 12)", "ALWAYS_TRUE()"), + // single predicate with default collation on partition column + predicate("=", col("part3"), ofString("12"), CollationIdentifier.SPARK_UTF8_BINARY) -> + ("(column(`part3`) = 12 COLLATE SPARK.UTF8_BINARY)", "ALWAYS_TRUE()"), + // single predicate with non-default collation on partition column + predicate("=", col("part3"), ofString("12"), utf8Lcase) -> + ("(column(`part3`) = 12 COLLATE SPARK.UTF8_LCASE)", "ALWAYS_TRUE()"), + predicate("=", col("part3"), ofString("12"), unicode) -> + ("(column(`part3`) = 12 COLLATE ICU.UNICODE)", "ALWAYS_TRUE()"), // multiple predicates on 
partition columns joined with AND predicate( "AND", predicate("=", col("part1"), ofInt(12)), predicate(">=", col("part3"), ofString("sss"))) -> ("((column(`part1`) = 12) AND (column(`part3`) >= sss))", "ALWAYS_TRUE()"), + // multiple predicates with collation on partition columns joined with AND + predicate( + "AND", + predicate("=", col("part3"), ofString("sss"), utf8Lcase), + predicate(">=", col("part3"), ofString("sss"), CollationIdentifier.SPARK_UTF8_BINARY)) -> + ( + """((column(`part3`) = sss COLLATE SPARK.UTF8_LCASE) AND (column(`part3`) + |>= sss COLLATE SPARK.UTF8_BINARY))""".stripMargin.replaceAll("\n", " "), + "ALWAYS_TRUE()"), // multiple predicates on partition columns joined with OR predicate( "OR", @@ -93,6 +145,15 @@ class PartitionUtilsSuite extends AnyFunSuite { predicate(">=", col("part3"), ofString("sss"))) -> ("(column(`part3`) >= sss)", "(column(`data1`) = 12)"), + // predicates with collation (each on data and partition column) joined with AND + predicate( + "AND", + predicate("=", col("data2"), ofString("12"), utf8Lcase), + predicate(">=", col("part3"), ofString("sss"), unicode)) -> + ( + "(column(`part3`) >= sss COLLATE ICU.UNICODE)", + "(column(`data2`) = 12 COLLATE SPARK.UTF8_LCASE)"), + // predicates (each on data and partition column) joined with OR predicate( "OR", @@ -100,6 +161,16 @@ class PartitionUtilsSuite extends AnyFunSuite { predicate(">=", col("part3"), ofString("sss"))) -> ("ALWAYS_TRUE()", "((column(`data1`) = 12) OR (column(`part3`) >= sss))"), + // predicates with collation (each on data and partition column) joined with OR + predicate( + "OR", + predicate("=", col("data2"), ofString("12"), unicode), + predicate(">=", col("part3"), ofString("sss"), unicode)) -> + ( + "ALWAYS_TRUE()", + """((column(`data2`) = 12 COLLATE ICU.UNICODE) OR (column(`part3`) + |>= sss COLLATE ICU.UNICODE))""".stripMargin.replaceAll("\n", " ")), + // predicates (multiple on data and partition columns) joined with AND predicate( "AND", @@ -155,9 
+226,22 @@ class PartitionUtilsSuite extends AnyFunSuite { "(column(`part3`) >= sss)", "(column(`data1`) = column(`part1`))"), + // predicates with collation (data and partitions compared in the same expression) + predicate( + "AND", + predicate("=", col("data2"), col("part3"), utf8Lcase), + predicate(">=", col("part3"), ofString("sss"), unicode)) -> + ( + "(column(`part3`) >= sss COLLATE ICU.UNICODE)", + "(column(`data2`) = column(`part3`) COLLATE SPARK.UTF8_LCASE)"), + // predicate only on data column but reverse order of literal and column predicate("=", ofInt(12), col("data1")) -> - ("ALWAYS_TRUE()", "(12 = column(`data1`))")) + ("ALWAYS_TRUE()", "(12 = column(`data1`))"), + + // predicate with collation only on data column but reverse order of literal and column + predicate("=", ofString("12"), col("data2"), utf8Lcase) -> + ("ALWAYS_TRUE()", "(12 = column(`data2`) COLLATE SPARK.UTF8_LCASE)")) partitionTestCases.foreach { case (predicate, (partitionPredicate, dataPredicate)) => @@ -179,6 +263,14 @@ class PartitionUtilsSuite extends AnyFunSuite { // exp predicate for checkpoint reader pushdown "(column(`add`.`partitionValues_parsed`.`part2`) = 12)"), + // single predicate with collation on a partition column + predicate("=", col("part3"), ofString("sss"), utf8Lcase) -> + ( + // exp predicate for partition pruning + "(ELEMENT_AT(column(`add`.`partitionValues`), part3) = sss COLLATE SPARK.UTF8_LCASE)", + + // exp predicate for checkpoint reader pushdown + "(column(`add`.`partitionValues_parsed`.`part3`) = sss COLLATE SPARK.UTF8_LCASE)"), // multiple predicates on partition columns joined with AND predicate( "AND", @@ -194,6 +286,21 @@ class PartitionUtilsSuite extends AnyFunSuite { """((column(`add`.`partitionValues_parsed`.`part1`) = 12) AND |(column(`add`.`partitionValues_parsed`.`part3`) >= sss))""" .stripMargin.replaceAll("\n", " ")), + // multiple predicates with collation on partition columns joined with AND + predicate( + "AND", + predicate("=", 
col("part3"), ofString("sss"), utf8Lcase), + predicate(">=", col("part3"), ofString("sss"), CollationIdentifier.SPARK_UTF8_BINARY)) -> + ( + // exp predicate for partition pruning + """((ELEMENT_AT(column(`add`.`partitionValues`), part3) = sss COLLATE SPARK.UTF8_LCASE) AND + |(ELEMENT_AT(column(`add`.`partitionValues`), part3) >= sss COLLATE SPARK.UTF8_BINARY))""" + .stripMargin.replaceAll("\n", " "), + + // exp predicate for checkpoint reader pushdown + """((column(`add`.`partitionValues_parsed`.`part3`) = sss COLLATE SPARK.UTF8_LCASE) AND + |(column(`add`.`partitionValues_parsed`.`part3`) >= sss COLLATE SPARK.UTF8_BINARY))""" + .stripMargin.replaceAll("\n", " ")), // multiple predicates on partition columns joined with OR predicate( "OR", @@ -310,4 +417,12 @@ class PartitionUtilsSuite extends AnyFunSuite { private def predicate(name: String, children: Expression*): Predicate = { new Predicate(name, children.asJava) } + + private def predicate( + name: String, + left: Expression, + right: Expression, + collationIdentifier: CollationIdentifier) = { + new Predicate(name, left, right, collationIdentifier) + } } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala new file mode 100644 index 00000000000..3b8316bbe7a --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/test/TestUtils.scala @@ -0,0 +1,37 @@ +package io.delta.kernel.test + +import io.delta.kernel.expressions.{Column, Literal} + +import java.util.Optional + +/** Utility functions for tests. 
*/ +trait TestUtils { + def col(name: String): Column = new Column(name) + + def nestedCol(name: String): Column = { + new Column(name.split("\\.")) + } + + def literal(value: Any): Literal = { + value match { + case v: String => Literal.ofString(v) + case v: Int => Literal.ofInt(v) + case v: Long => Literal.ofLong(v) + case v: Float => Literal.ofFloat(v) + case v: Double => Literal.ofDouble(v) + case v: Boolean => Literal.ofBoolean(v) + case _ => throw new IllegalArgumentException(s"Unsupported literal type: ${value}") + } + } + + protected def optionToJava[T](option: Option[T]): Optional[T] = { + option match { + case Some(value) => Optional.of(value) + case None => Optional.empty() + } + } + + protected def optionalToScala[T](optional: Optional[T]): Option[T] = { + if (optional.isPresent) Some(optional.get()) else None + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java index 6fd66fbe114..f2758b7ef5c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionUtils.java @@ -31,7 +31,6 @@ import java.nio.charset.StandardCharsets; import java.util.Comparator; import java.util.List; -import java.util.Optional; import java.util.function.Function; import java.util.function.IntPredicate; import java.util.stream.Collectors; @@ -515,16 +514,6 @@ static void checkIsLiteral(Expression expr, Expression parentExpr, String errorM } } - /** Creates a {@link Predicate} with name, children and optional collation. 
*/ - static Predicate createPredicate( - String name, List children, Optional collationIdentifier) { - if (collationIdentifier.isPresent()) { - return new Predicate(name, children, collationIdentifier.get()); - } else { - return new Predicate(name, children); - } - } - /** * Checks if the collation is `UTF8_BINARY`, since this is the only collation the default engine * can evaluate. diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index fcf94b44359..468dd70225d 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -15,9 +15,9 @@ */ package io.delta.kernel.defaults.internal.expressions; -import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.createPredicate; import static io.delta.kernel.expressions.AlwaysFalse.ALWAYS_FALSE; import static io.delta.kernel.expressions.AlwaysTrue.ALWAYS_TRUE; +import static io.delta.kernel.internal.util.ExpressionUtils.createPredicate; import static java.util.stream.Collectors.joining; import io.delta.kernel.expressions.*; diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java index bfb752ce2bd..868469aa9ae 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/StartsWithExpressionEvaluator.java @@ -16,6 +16,7 @@ package io.delta.kernel.defaults.internal.expressions; import static 
io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*; +import static io.delta.kernel.internal.util.ExpressionUtils.createPredicate; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.expressions.Expression; diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index e0503f2d958..405df4c8306 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -20,12 +20,10 @@ import java.sql.Date import java.time.{Instant, OffsetDateTime} import java.time.temporal.ChronoUnit import java.util.Optional - import scala.collection.JavaConverters._ - import io.delta.golden.GoldenTableUtils.goldenTablePath import io.delta.kernel.{Scan, Snapshot, Table} -import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row} +import io.delta.kernel.data.{ColumnVector, ColumnarBatch, FilteredColumnarBatch, Row} import io.delta.kernel.defaults.engine.{DefaultEngine, DefaultJsonHandler, DefaultParquetHandler} import io.delta.kernel.defaults.engine.hadoopio.HadoopFileIO import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch @@ -33,6 +31,7 @@ import io.delta.kernel.defaults.internal.data.vector.{DefaultGenericVector, Defa import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestUtils, WriteUtils} import io.delta.kernel.engine.{Engine, JsonHandler, ParquetHandler} import io.delta.kernel.engine.FileReadResult +import io.delta.kernel.exceptions.KernelEngineException import io.delta.kernel.expressions._ import io.delta.kernel.expressions.Literal._ import io.delta.kernel.internal.{InternalScanFileUtils, ScanImpl, TableConfig} @@ -42,15 +41,15 @@ import io.delta.kernel.types.IntegerType.INTEGER import io.delta.kernel.types.StringType.STRING import io.delta.kernel.utils.{CloseableIterator, 
FileStatus} import io.delta.kernel.utils.CloseableIterable.emptyIterable - import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog} - import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.{Row => SparkRow} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.types.{IntegerType => SparkIntegerType, StructField => SparkStructField, StructType => SparkStructType} import org.scalatest.funsuite.AnyFunSuite +import java.util + class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with SQLHelper with WriteUtils { @@ -1587,6 +1586,477 @@ class ScanSuite extends AnyFunSuite with TestUtils } } + test("partition pruning - predicates with SPARK.UTF8_BINARY on partition column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "b"), ("c", "d"), ("e", "f")).toDF("p", "d") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + // Create a checkpoint for the table + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val filterToFileNumber = Map( + new Predicate( + "<", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + ofString("d"), + col("p"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + ">", + col("p"), + ofString("e"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new And( + new Predicate( + ">=", + col("p"), + ofString("b"), + CollationIdentifier.SPARK_UTF8_BINARY), + new 
Predicate( + "<=", + col("p"), + ofString("e"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate( + "=", + col("p"), + ofString("x"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + col("p"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("p"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles, + new Predicate( + "STARTS_WITH", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new In( + col("p"), + java.util.Arrays.asList(ofString("a"), ofString("x")), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new In( + col("p"), + java.util.Arrays.asList(ofString("x"), ofString("y")), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("data skipping - predicates with SPARK.UTF8_BINARY on data column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + // Create three files with values on non-partitioned STRING columns (c1, c2) + // Files: ("a","x"), ("c","y"), ("e","z") + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + // Create a checkpoint for the table + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val filterToFileNumber = Map( + new Predicate( + "<", + col("c1"), + ofString("a"), + 
CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + ofString("d"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate( + "<=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + "<", + ofString("e"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new And( + new Predicate( + ">=", + col("c1"), + ofString("b"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("e"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate( + "=", + col("c1"), + ofString("x"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + ">=", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles, + new And( + new Predicate( + "<=", + ofString("b"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">=", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + ">=", + ofString("c"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new Or( + new Predicate( + "=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "=", + ofString("z"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate( + "<", + ofString("d"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new And( + new Predicate( + "<", + ofString("e"), + 
col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 0, + ) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("data skipping - collated predicates not or partially convertible to skipping filter") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + val filterToFileNumber = Map( + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("a"), + utf8Lcase) -> totalFiles, + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("z"), + unicode) -> totalFiles, + new In( + col("c1"), + java.util.Arrays.asList(ofString("a"), ofString("z")), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new In( + col("c2"), + java.util.Arrays.asList(ofString("x"), ofString("zz")), + utf8Lcase) -> totalFiles, + new And( + new Predicate( + "<", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY), + new In( + col("c2"), + java.util.Arrays.asList(ofString("x"), ofString("zz")), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + 
new Or( + new Predicate("STARTS_WITH", col("c1"), ofString("a"), utf8Lcase), + new In(col("c2"), util.Arrays.asList(ofString("x"), ofString("y")), unicode)) -> totalFiles, + new And( + new Predicate("STARTS_WITH", col("c1"), ofString("a"), CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("STARTS_WITH", col("c2"), ofString("x"), unicode)) -> totalFiles + ) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("data skipping - evaluation fails with non default collation on data column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + val failingPredicates = Seq( + new Predicate("<", col("c1"), ofString("a"), utf8Lcase), + new Predicate("=", ofString("d"), col("c1"), unicode), + new And( + new Predicate(">=", col("c1"), ofString("b"), utf8Lcase), + new Predicate("<=", col("c1"), ofString("e"), unicode)), + new Or( + new Predicate("<", col("c1"), ofString("b"), utf8Lcase), + new Predicate(">", col("c1"), ofString("a"), CollationIdentifier.SPARK_UTF8_BINARY)), + new And( + new Predicate(">=", col("c1"), ofString("a"), utf8Lcase), + new Predicate("<=", col("c1"), ofString("z"), unicode)), + new Predicate("=", col("c1"), ofString("a"), utf8Lcase) + ) + + failingPredicates.foreach { predicate => + val ex = 
intercept[KernelEngineException] { + collectScanFileRows(snapshot.getScanBuilder().withFilter(predicate).build()) + } + assert(ex.getMessage.contains("Unsupported collation")) + assert(ex.getMessage.contains(CollationIdentifier.SPARK_UTF8_BINARY.toString)) + assert(ex.getCause.isInstanceOf[UnsupportedOperationException]) + } + } + } + } + + test("partition and data skipping - combined pruning on partition and data columns") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x", "u")).toDF("p", "c1", "c2") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + Seq(("c", "y", "v")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z", "w")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val filterToFileNumber: Map[Predicate, Int] = Map( + // Both partition (p >= 'b') and data (c1 <= 'y') prune => 1 match ('c','y','v') + new And( + new Predicate( + "<=", + col("c1"), + ofString("y"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">=", + col("p"), + ofString("b"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + "<=", + col("p"), + ofString("c"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 0 + ) + + 
checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("partition pruning - predicates with non default collation on partition column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "b"), ("c", "d"), ("e", "f")).toDF("p", "d") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + + if (createCheckpoint) { + // Create a checkpoint for the table + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + // Non-default collations are not supported by the default engine for predicate evaluation. + // Assert that attempting to evaluate such predicates during partition pruning throws. + val failingPredicates = Seq( + new Predicate("<", col("p"), ofString("a"), utf8Lcase), + new Predicate("=", ofString("d"), col("p"), unicode), + new And( + new Predicate(">=", col("p"), ofString("b"), utf8Lcase), + new Predicate("<=", col("p"), ofString("e"), unicode)), + new Or( + new Predicate("<", col("p"), ofString("b"), utf8Lcase), + new Predicate(">", col("p"), ofString("a"), CollationIdentifier.SPARK_UTF8_BINARY)), + new Predicate("STARTS_WITH", col("p"), ofString("a"), utf8Lcase), + new In( + col("p"), + java.util.Arrays.asList(ofString("a"), ofString("c")), + utf8Lcase), + new Or( + new In(col("p"), java.util.Arrays.asList(ofString("x"), ofString("y")), unicode), + new Predicate("=", col("p"), ofString("z")))) + + failingPredicates.foreach { predicate => + val ex = intercept[KernelEngineException] { + collectScanFileRows(snapshot.getScanBuilder().withFilter(predicate).build()) + } + assert(ex.getMessage.contains("Unsupported collation")) + 
assert(ex.getMessage.contains(CollationIdentifier.SPARK_UTF8_BINARY.toString)) + assert(ex.getCause.isInstanceOf[UnsupportedOperationException]) + } + } + } + } + Seq( "spark-variant-checkpoint", "spark-variant-stable-feature-checkpoint", diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala index 2247f97bf89..4cd88a5a62f 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/internal/expressions/ExpressionSuiteBase.scala @@ -19,10 +19,10 @@ import scala.collection.JavaConverters._ import io.delta.kernel.data.{ColumnarBatch, ColumnVector} import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch -import io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.createPredicate import io.delta.kernel.defaults.utils.{DefaultVectorTestUtils, TestUtils} import io.delta.kernel.defaults.utils.DefaultKernelTestUtils.getValueAsObject import io.delta.kernel.expressions._ +import io.delta.kernel.internal.util.ExpressionUtils.createPredicate import io.delta.kernel.types._ trait ExpressionSuiteBase extends TestUtils with DefaultVectorTestUtils {