diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java index 131eea0f5de..7316d9f8132 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java @@ -197,7 +197,7 @@ static CloseableIterator transformLogicalData( } ColumnarBatch data = filteredBatch.getData(); - if (!data.getSchema().equals(tableSchema)) { + if (!data.getSchema().equivalentIgnoreCollations(tableSchema)) { throw dataSchemaMismatch(tablePath, tableSchema, data.getSchema()); } for (String partitionColName : partitionColNames) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java index 0251cb28d5e..45921ad15f7 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/ScanImpl.java @@ -366,7 +366,8 @@ private CloseableIterator applyDataSkipping( // pruning it after is much simpler StructType prunedStatsSchema = DataSkippingUtils.pruneStatsSchema( - getStatsSchema(metadata.getDataSchema()), dataSkippingFilter.getReferencedCols()); + getStatsSchema(metadata.getDataSchema(), dataSkippingFilter.getReferencedCollations()), + dataSkippingFilter.getReferencedCols()); // Skipping happens in two steps: // 1. The predicate produces false for any file whose stats prove we can safely skip it. A diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java index 96d2b67d791..894d825dacc 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingPredicate.java @@ -27,12 +27,6 @@ public class DataSkippingPredicate extends Predicate { /** Set of {@link Column}s referenced by the predicate or any of its child expressions */ private final Set referencedCols; - /** - * Set of {@link CollationIdentifier}s referenced by this predicate or any of its child - * expressions - */ - private final Set collationIdentifiers; - /** * @param name the predicate name * @param children list of expressions that are input to this predicate. @@ -42,7 +36,6 @@ public class DataSkippingPredicate extends Predicate { DataSkippingPredicate(String name, List children, Set referencedCols) { super(name, children); this.referencedCols = Collections.unmodifiableSet(referencedCols); - this.collationIdentifiers = Collections.unmodifiableSet(new HashSet<>()); } /** @@ -59,7 +52,6 @@ public class DataSkippingPredicate extends Predicate { Set referencedCols) { super(name, children, collationIdentifier); this.referencedCols = Collections.unmodifiableSet(referencedCols); - this.collationIdentifiers = Collections.singleton(collationIdentifier); } /** @@ -73,8 +65,6 @@ public class DataSkippingPredicate extends Predicate { DataSkippingPredicate(String name, DataSkippingPredicate left, DataSkippingPredicate right) { super(name, Arrays.asList(left, right)); this.referencedCols = immutableUnion(left.referencedCols, right.referencedCols); - this.collationIdentifiers = - immutableUnion(left.collationIdentifiers, right.collationIdentifiers); } /** @return set of columns referenced by this predicate or any of its child expressions */ @@ -87,7 +77,26 @@ public Set getReferencedCols() { * expressions */ public Set getReferencedCollations() { - return collationIdentifiers; + Set referencedCollations = new HashSet<>(); + + if (this.getCollationIdentifier().isPresent()) { + referencedCollations.add(this.getCollationIdentifier().get()); + } + + for (Expression child : children) { + if (child instanceof Predicate) { + if (child instanceof DataSkippingPredicate) { + referencedCollations.addAll(((DataSkippingPredicate) child).getReferencedCollations()); + } else { + throw new IllegalStateException( + String.format( + "Expected child Predicate of DataSkippingPredicate to also be a" + + " DataSkippingPredicate, but found %s", + child.getClass().getName())); + } + } + } + return Collections.unmodifiableSet(referencedCollations); } /** @return an unmodifiable set containing all elements from both sets. */ diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java index d038d312668..5f23424a841 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/StatsSchemaHelper.java @@ -91,6 +91,7 @@ public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCo * |-- a: struct (nullable = true) * | |-- b: struct (nullable = true) * | | |-- c: long (nullable = true) + * | | |-- d: string (nullable = true) * * *

Collected Statistics: @@ -102,18 +103,32 @@ public static boolean isSkippingEligibleDataType(DataType dataType, boolean isCo * | | |-- a: struct (nullable = false) * | | | |-- b: struct (nullable = false) * | | | | |-- c: long (nullable = true) + * | | | | |-- d: string (nullable = true) * | |-- maxValues: struct (nullable = false) * | | |-- a: struct (nullable = false) * | | | |-- b: struct (nullable = false) * | | | | |-- c: long (nullable = true) + * | | | | |-- d: string (nullable = true) * | |-- nullCount: struct (nullable = false) * | | |-- a: struct (nullable = false) * | | | |-- b: struct (nullable = false) * | | | | |-- c: long (nullable = true) + * | | | | |-- d: string (nullable = true) + * | |-- statsWithCollation: struct (nullable = true) + * | | |-- collationName: struct (nullable = true) + * | | | |-- min: struct (nullable = false) + * | | | | |-- a: struct (nullable = false) + * | | | | | |-- b: struct (nullable = false) + * | | | | | | |-- d: string (nullable = true) + * | | | |-- max: struct (nullable = false) + * | | | | |-- a: struct (nullable = false) + * | | | | | |-- b: struct (nullable = false) + * | | | | | | |-- d: string (nullable = true) * | |-- tightBounds: boolean (nullable = true) * */ - public static StructType getStatsSchema(StructType dataSchema) { + public static StructType getStatsSchema( + StructType dataSchema, Set collationIdentifiers) { StructType statsSchema = new StructType().add(NUM_RECORDS, LongType.LONG, true); StructType minMaxStatsSchema = getMinMaxStatsSchema(dataSchema); @@ -128,6 +143,11 @@ public static StructType getStatsSchema(StructType dataSchema) { statsSchema = statsSchema.add(TIGHT_BOUNDS, BooleanType.BOOLEAN, true); + StructType collatedMinMaxStatsSchema = getCollatedStatsSchema(dataSchema, collationIdentifiers); + if (collatedMinMaxStatsSchema.length() > 0) { + statsSchema = statsSchema.add(STATS_WITH_COLLATION, collatedMinMaxStatsSchema, true); + } + return statsSchema; } @@ -272,24 +292,60 @@ public boolean isSkippingEligibleNullCountColumn(Column column) { /** * Given a data schema returns the expected schema for a min or max statistics column. This means * 1) replace logical names with physical names 2) set nullable=true 3) only keep stats eligible - * fields (i.e. don't include fields with isSkippingEligibleDataType=false) + * fields (i.e. don't include fields with isSkippingEligibleDataType=false). Collation-aware + * statistics are not included. */ private static StructType getMinMaxStatsSchema(StructType dataSchema) { + return getMinMaxStatsSchema(dataSchema, /* isCollatedSkipping */ false); + } + + /** + * Given a data schema returns the expected schema for a min or max statistics column. This means + * 1) replace logical names with physical names 2) set nullable=true 3) only keep stats eligible + * fields (i.e. don't include fields with isSkippingEligibleDataType=false). In case when + * isCollatedSkipping is true, only `StringType` fields are eligible. + */ + private static StructType getMinMaxStatsSchema( + StructType dataSchema, boolean isCollatedSkipping) { List fields = new ArrayList<>(); for (StructField field : dataSchema.fields()) { - if (isSkippingEligibleDataType(field.getDataType(), false)) { + if (isSkippingEligibleDataType(field.getDataType(), isCollatedSkipping)) { fields.add(new StructField(getPhysicalName(field), field.getDataType(), true)); } else if (field.getDataType() instanceof StructType) { fields.add( new StructField( getPhysicalName(field), - getMinMaxStatsSchema((StructType) field.getDataType()), + getMinMaxStatsSchema((StructType) field.getDataType(), isCollatedSkipping), true)); } } return new StructType(fields); } + /** + * Given a data schema and a set of collation identifiers returns the expected schema for + * collation-aware statistics columns. This means 1) replace logical names with physical names 2) + * set nullable=true 3) only keep collated-stats eligible fields (`StringType` fields) + */ + private static StructType getCollatedStatsSchema( + StructType dataSchema, Set collationIdentifiers) { + StructType statsWithCollation = new StructType(); + StructType minMaxSchemaForCollationAwareFields = + getMinMaxStatsSchema(dataSchema, /* isCollatedSkipping */ true); + if (minMaxSchemaForCollationAwareFields.length() > 0) { + for (CollationIdentifier collationIdentifier : collationIdentifiers) { + statsWithCollation = + statsWithCollation.add( + collationIdentifier.toString(), + new StructType() + .add(MIN, minMaxSchemaForCollationAwareFields, true) + .add(MAX, minMaxSchemaForCollationAwareFields, true), + true); + } + } + return statsWithCollation; + } + /** * Given a data schema returns the expected schema for a null_count statistics column. This means * 1) replace logical names with physical names 2) set nullable=true 3) use LongType for all diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java b/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java index 97f8288c1b6..b3f9c669b78 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java @@ -380,7 +380,8 @@ private void validateLiteralType(StructField field, Literal literal) { // Variant stats in JSON are Z85 encoded strings, all other stats should match the field type DataType expectedLiteralType = field.getDataType() instanceof VariantType ? StringType.STRING : field.getDataType(); - if (literal.getDataType() == null || !literal.getDataType().equals(expectedLiteralType)) { + if (literal.getDataType() == null + || !literal.getDataType().equivalentIgnoreCollations(expectedLiteralType)) { throw DeltaErrors.statsTypeMismatch( field.getName(), expectedLiteralType, literal.getDataType()); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java index 173a85f4882..e5ca63ecab9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java @@ -55,6 +55,26 @@ public boolean equivalent(DataType dataType) { && ((ArrayType) dataType).getElementType().equivalent(getElementType()); } + /** + * Are the data types same? The collations could be different. + * + * @param dataType + * @return + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + ArrayType arrayType = (ArrayType) dataType; + return (elementField == null && arrayType.elementField == null) + || (elementField != null + && elementField.equivalentIgnoreCollations(arrayType.elementField)); + } + @Override public boolean isNested() { return true; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java index 6c10fd3fbb8..49411ae05b9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java @@ -36,6 +36,16 @@ public boolean equivalent(DataType dataType) { return equals(dataType); } + /** + * Are the data types same? The collations could be different. + * + * @param dataType + * @return + */ + public boolean equivalentIgnoreCollations(DataType dataType) { + return equals(dataType); + } + /** * Returns true iff this data is a nested data type (it logically parameterized by other types). * diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java index b4d24fe51dc..3dbd45a8345 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java @@ -70,6 +70,27 @@ public boolean equivalent(DataType dataType) { && ((MapType) dataType).isValueContainsNull() == isValueContainsNull(); } + /** + * Are the data types same? The collations could be different. + * + * @param dataType + * @return + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + MapType mapType = (MapType) dataType; + return ((keyField == null && mapType.keyField == null) + || (keyField != null && keyField.equivalentIgnoreCollations(mapType.keyField))) + && ((valueField == null && mapType.valueField == null) + || (valueField != null && valueField.equivalentIgnoreCollations(mapType.valueField))); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java index d9dc911a1da..b46ad9e5c4d 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java @@ -52,6 +52,11 @@ public CollationIdentifier getCollationIdentifier() { return collationIdentifier; } + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + return dataType instanceof StringType; + } + @Override public boolean equals(Object o) { if (!(o instanceof StringType)) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java index b64d166b83e..06859de0c03 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java @@ -215,6 +215,21 @@ public boolean equals(Object o) { && Objects.equals(typeChanges, that.typeChanges); } + /** @return whether the struct fields are equal, ignoring collations */ + public boolean equivalentIgnoreCollations(StructField other) { + if (this == other) { + return true; + } + if (other == null) { + return false; + } + return nullable == other.nullable + && name.equals(other.name) + && dataType.equivalentIgnoreCollations(other.dataType) + && metadata.equals(other.metadata) + && Objects.equals(typeChanges, other.typeChanges); + } + @Override public int hashCode() { return Objects.hash(name, dataType, nullable, metadata, typeChanges); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java index 3a5a34f6be3..c524ae0e8f9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java @@ -175,6 +175,28 @@ public boolean equivalent(DataType dataType) { .allMatch(result -> result); } + /** + * Are the data types same? The collations could be different. + * + * @param dataType + * @return + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + StructType structType = (StructType) dataType; + return this.length() == structType.length() + && fieldNames.equals(structType.fieldNames) + && IntStream.range(0, this.length()) + .mapToObj(i -> this.at(i).equivalentIgnoreCollations(structType.at(i))) + .allMatch(result -> result); + } + @Override public boolean isNested() { return true; diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala index f405af8cdbc..b5d89d5df44 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/DataSkippingUtilsSuite.scala @@ -188,6 +188,119 @@ class DataSkippingUtilsSuite extends AnyFunSuite with TestUtils { new StructType()) } + test("pruneStatsSchema - collated min/max columns") { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + val testSchema = new StructType() + .add( + MIN, + new StructType() + .add("s1", StringType.STRING) + .add("i1", INTEGER) + .add("i2", INTEGER) + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + MAX, + new StructType() + .add("s1", StringType.STRING) + .add("i1", INTEGER) + .add("i2", INTEGER) + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add( + MIN, + new StructType() + .add("s1", StringType.STRING) + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + MAX, + new StructType() + .add("s1", StringType.STRING) + .add("nested", new StructType().add("s2", StringType.STRING)))) + .add( + unicode.toString, + new StructType() + .add( + MIN, + new StructType() + .add("s1", StringType.STRING) + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + MAX, + new StructType() + .add("s1", StringType.STRING) + .add("nested", new StructType().add("s2", StringType.STRING))))) + + val testCases = Seq( + ( + Set(nestedCol(s"$MIN.nested.s2"), nestedCol(s"$MAX.i1")), + new StructType() + .add( + MIN, + new StructType() + .add("i1", INTEGER) + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + MAX, + new StructType() + .add("i1", INTEGER))), + ( + Set( + collatedStatsCol(utf8Lcase, MIN, "s1"), + collatedStatsCol(unicode, MAX, "nested.s2")), + new StructType() + .add( + STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add( + MIN, + new StructType().add("s1", StringType.STRING))) + .add( + unicode.toString, + new StructType() + .add( + MAX, + new StructType().add( + "nested", + new StructType().add("s2", StringType.STRING)))))), + ( + Set( + nestedCol(s"$MIN.i2"), + collatedStatsCol(utf8Lcase, MAX, "nested.s2"), + collatedStatsCol(utf8Lcase, MIN, "nested.s2")), + new StructType() + .add( + MIN, + new StructType() + .add("i2", INTEGER)) + .add( + STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add( + MIN, + new StructType() + .add("nested", new StructType().add("s2", StringType.STRING))) + .add( + MAX, + new StructType() + .add("nested", new StructType().add("s2", StringType.STRING))))))) + + testCases.foreach { case (referencedCols, expectedSchema) => + checkPruneStatsSchema(testSchema, referencedCols, expectedSchema) + } + } + // TODO: add tests for remaining operators test("check constructDataSkippingFilter") { val testCases = Seq( diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala new file mode 100644 index 00000000000..4232ff54fda --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/skipping/StatsSchemaHelperSuite.scala @@ -0,0 +1,432 @@ +/* + * Copyright (2025) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.internal.skipping + +import scala.collection.JavaConverters.setAsJavaSetConverter + +import io.delta.kernel.types.{ArrayType, BinaryType, BooleanType, ByteType, CollationIdentifier, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} + +import org.scalatest.funsuite.AnyFunSuite + +class StatsSchemaHelperSuite extends AnyFunSuite { + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE") + + test("check getStatsSchema for supported data types") { + val testCases = Seq( + ( + new StructType().add("a", IntegerType.INTEGER), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("a", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("b", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("b", StringType.STRING, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("b", StringType.STRING, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("b", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("c", ByteType.BYTE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("c", ByteType.BYTE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("c", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("d", ShortType.SHORT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("d", ShortType.SHORT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("d", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("e", LongType.LONG), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("e", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("f", FloatType.FLOAT), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("f", FloatType.FLOAT, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("f", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("g", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("g", DoubleType.DOUBLE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("g", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("h", DateType.DATE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("h", DateType.DATE, true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("h", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("i", TimestampType.TIMESTAMP), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("i", TimestampType.TIMESTAMP, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("i", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("j", TimestampNTZType.TIMESTAMP_NTZ, true), + true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("j", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType().add("k", new DecimalType(20, 5)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("k", new DecimalType(20, 5), true), true) + .add(StatsSchemaHelper.NULL_COUNT, new StructType().add("k", LongType.LONG, true), true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = StatsSchemaHelper.getStatsSchema( + dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with mix of supported and unsupported data types") { + val testCases = Seq( + ( + new StructType() + .add("a", IntegerType.INTEGER) + .add("b", BinaryType.BINARY) + .add("c", new ArrayType(LongType.LONG, true)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add(StatsSchemaHelper.MIN, new StructType().add("a", IntegerType.INTEGER, true), true) + .add(StatsSchemaHelper.MAX, new StructType().add("a", IntegerType.INTEGER, true), true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + ( + new StructType() + .add( + "s", + new StructType() + .add("s1", StringType.STRING) + .add("s2", BooleanType.BOOLEAN)), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType().add("s1", StringType.STRING, true), true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("s1", LongType.LONG, true) + .add("s2", LongType.LONG, true), + true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Un-nested array/map alongside a supported type + ( + new StructType() + .add("arr", new ArrayType(IntegerType.INTEGER, true)) + .add("mp", new MapType(StringType.STRING, LongType.LONG, true)) + .add("z", DoubleType.DOUBLE), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("z", DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("z", DoubleType.DOUBLE, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true) + .add("z", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true)), + // Nested array/map inside a struct; empty struct preserved in min/max + ( + new StructType() + .add( + "s", + new StructType() + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(IntegerType.INTEGER, StringType.STRING, true))) + .add("k", StringType.STRING), + new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("s", new StructType(), true) + .add("k", StringType.STRING, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add("k", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true))) + + testCases.foreach { case (dataSchema, expectedStatsSchema) => + val statsSchema = StatsSchemaHelper.getStatsSchema( + dataSchema, + Set.empty[CollationIdentifier].asJava) + assert( + statsSchema == expectedStatsSchema, + s"Stats schema mismatch for data schema: $dataSchema") + } + } + + test("check getStatsSchema with collations - un-nested mix") { + val dataSchema = new StructType() + .add("a", StringType.STRING) + .add("b", IntegerType.INTEGER) + .add("c", BinaryType.BINARY) + + val collations = Set(utf8Lcase) + + val expectedCollatedMinMax = new StructType().add("a", StringType.STRING, true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add("a", StringType.STRING, true) + .add("b", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedMinMax, true) + .add(StatsSchemaHelper.MAX, expectedCollatedMinMax, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with collations - nested mix and multiple collations") { + val dataSchema = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING) + .add("y", IntegerType.INTEGER) + .add("z", new StructType().add("p", StringType.STRING).add("q", DoubleType.DOUBLE))) + .add("arr", new ArrayType(StringType.STRING, true)) + .add("mp", new MapType(StringType.STRING, StringType.STRING, true)) + + val collations = Set(utf8Lcase, CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedCollatedNested = new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("z", new StructType().add("p", StringType.STRING, true), true), + true) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType() + .add( + "s", + new StructType() + .add("x", StringType.STRING, true) + .add("y", IntegerType.INTEGER, true) + .add( + "z", + new StructType() + .add("p", StringType.STRING, true) + .add("q", DoubleType.DOUBLE, true), + true), + true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add( + "s", + new StructType() + .add("x", LongType.LONG, true) + .add("y", LongType.LONG, true) + .add( + "z", + new StructType() + .add("p", LongType.LONG, true) + .add("q", LongType.LONG, true), + true), + true) + .add("arr", LongType.LONG, true) + .add("mp", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + .add( + StatsSchemaHelper.STATS_WITH_COLLATION, + new StructType() + .add( + utf8Lcase.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true) + .add( + CollationIdentifier.SPARK_UTF8_BINARY.toString, + new StructType() + .add(StatsSchemaHelper.MIN, expectedCollatedNested, true) + .add(StatsSchemaHelper.MAX, expectedCollatedNested, true), + true), + true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } + + test("check getStatsSchema with collations - no eligible string columns") { + val dataSchema = new StructType() + .add("a", IntegerType.INTEGER) + .add("b", new ArrayType(StringType.STRING, true)) + .add("c", new MapType(StringType.STRING, LongType.LONG, true)) + + val collations = Set(utf8Lcase, unicode, CollationIdentifier.SPARK_UTF8_BINARY) + + val expectedStatsSchema = new StructType() + .add(StatsSchemaHelper.NUM_RECORDS, LongType.LONG, true) + .add( + StatsSchemaHelper.MIN, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.MAX, + new StructType().add("a", IntegerType.INTEGER, true), + true) + .add( + StatsSchemaHelper.NULL_COUNT, + new StructType() + .add("a", LongType.LONG, true) + .add("b", LongType.LONG, true) + .add("c", LongType.LONG, true), + true) + .add(StatsSchemaHelper.TIGHT_BOUNDS, BooleanType.BOOLEAN, true) + + val statsSchema = StatsSchemaHelper.getStatsSchema(dataSchema, collations.asJava) + assert(statsSchema == expectedStatsSchema) + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala index 41877204023..5d8f274282e 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala @@ -25,7 +25,10 @@ import scala.collection.immutable.Seq import io.delta.golden.GoldenTableUtils.goldenTablePath import io.delta.kernel._ import io.delta.kernel.Operation.{CREATE_TABLE, MANUAL_UPDATE, WRITE} -import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch, Row} +import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row} +import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch +import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector +import io.delta.kernel.defaults.internal.data.vector.DefaultStructVector import io.delta.kernel.defaults.internal.parquet.ParquetSuiteBase import io.delta.kernel.defaults.utils.{AbstractWriteUtils, TestRow, WriteUtils} import io.delta.kernel.engine.Engine @@ -35,11 +38,13 @@ import io.delta.kernel.expressions.Literal._ import io.delta.kernel.internal.{ScanImpl, SnapshotImpl, TableConfig} import io.delta.kernel.internal.checkpoints.CheckpointerSuite.selectSingleElement import io.delta.kernel.internal.table.SnapshotBuilderImpl +import io.delta.kernel.internal.types.DataTypeJsonSerDe import io.delta.kernel.internal.util.{Clock, JsonUtils} import io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames import io.delta.kernel.transaction.DataLayoutSpec import io.delta.kernel.types._ import io.delta.kernel.types.ByteType.BYTE +import io.delta.kernel.types.CollationIdentifier import io.delta.kernel.types.DateType.DATE import io.delta.kernel.types.DecimalType import io.delta.kernel.types.DoubleType.DOUBLE @@ -413,6 +418,212 @@ abstract class AbstractDeltaTableWritesSuite extends AnyFunSuite with AbstractWr } } + /////////////////////////////////////////////////////////////////////////// + // Collation write tests + /////////////////////////////////////////////////////////////////////////// + + test("insert into table - simple collated string column") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Binary = new StringType(CollationIdentifier.SPARK_UTF8_BINARY) + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val serbianWithVersion = new StringType("ICU.SR_CYRL_SRB.75.1") + val serbianWithoutVersion = new StringType("ICU.SR_CYRL_SRB") + + val commonSchema = new StructType() + .add("c1", IntegerType.INTEGER) + .add("c2", StringType.STRING) + .add("c3", utf8Binary) + .add("c4", utf8Lcase) + .add("c5", unicode) + val schemaWithVersion = commonSchema.add("c6", serbianWithVersion) + val schemaWithoutVersion = commonSchema.add("c6", serbianWithoutVersion) + + // First append + val data1 = + generateData(schemaWithVersion, Seq.empty, Map.empty, batchSize = 10, numBatches = 1) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schemaWithVersion, + data = Seq(Map.empty[String, Literal] -> data1)) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + // we use schemaWithoutVersion to verify since the version info is not stored in the + // schema serialization + verifyWrittenContent(tblPath, schemaWithoutVersion, data1.flatMap(_.toTestRows)) + + // Second append + val data2 = + generateData(schemaWithVersion, Seq.empty, Map.empty, batchSize = 5, numBatches = 1) + + val commitResult1 = appendData( + engine, + tblPath, + data = Seq(Map.empty[String, Literal] -> data2)) + verifyCommitResult(commitResult1, expVersion = 1, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 1, partitionCols = null) + verifyWrittenContent( + tblPath, + schemaWithoutVersion, + (data1 ++ data2).flatMap(_.toTestRows)) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schemaWithoutVersion) + } + } + + test("insert into table - nested struct with collated string field") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Binary = new StringType(CollationIdentifier.SPARK_UTF8_BINARY) + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val nested = new StructType() + .add("c21", utf8Lcase) + .add("c22", IntegerType.INTEGER) + .add("c23", unicode) + .add("c24", utf8Binary) + val schema = new StructType() + .add("c1", LongType.LONG) + .add("c2", nested) + + val data = generateData(schema, Seq.empty, Map.empty, batchSize = 8, numBatches = 2) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schema, + data = Seq(Map.empty[String, Literal] -> data)) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + verifyWrittenContent(tblPath, schema, data.flatMap(_.toTestRows)) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schema) + } + } + + test("insert into table - complex types with collated strings in nested fields") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Binary = new StringType(CollationIdentifier.SPARK_UTF8_BINARY) + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + + val nested = new StructType() + .add("c1", unicode) + .add("c2", IntegerType.INTEGER) + .add("c3", utf8Binary) + + val schema = new StructType() + .add("c1", IntegerType.INTEGER) + .add("c2", nested) + .add("c3", new ArrayType(utf8Lcase, true)) + .add("c4", new MapType(utf8Lcase, unicode, true)) + + // Build vectors + val batchSize = 5 + val c1Vector = testColumnVector(batchSize, IntegerType.INTEGER) + + val nestedVectors = Array[ColumnVector]( + testColumnVector(batchSize, unicode), + testColumnVector(batchSize, IntegerType.INTEGER), + testColumnVector(batchSize, utf8Binary)) + val c2Vector = new DefaultStructVector(batchSize, nested, Optional.empty(), nestedVectors) + + val c3Values: Seq[Seq[AnyRef]] = (0 until batchSize).map { i => + Seq(s"t$i", s"x$i").map(_.asInstanceOf[AnyRef]) + } + val c3Vector = buildArrayVector(c3Values, utf8Lcase, containsNull = true) + + val c4Type = new MapType(utf8Lcase, unicode, true) + val c4Values: Seq[Map[AnyRef, AnyRef]] = (0 until batchSize).map { i => + Map[AnyRef, AnyRef](s"k$i" -> s"v$i") + } + val c4Vector = buildMapVector(c4Values, c4Type) + + val vectors = Array[ColumnVector](c1Vector, c2Vector, c3Vector, c4Vector) + val batch = new DefaultColumnarBatch(batchSize, schema, vectors) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schema, + data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schema) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + val expectedRows = Seq(fcb).flatMap(_.toTestRows) + verifyWrittenContent(tblPath, schema, expectedRows) + } + } + + test("stats: default engine writes binary stats for collated string columns") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Binary = new StringType(CollationIdentifier.SPARK_UTF8_BINARY) + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + + val schema = new StructType() + .add("c1", utf8Binary) + .add("c2", utf8Lcase) + .add("c3", unicode) + + val txn = getCreateTxn(engine, tblPath, schema) + commitTransaction(txn, engine, emptyIterable()) + + val batchSize = 4 + val values = Array("b", "A", "B", "a").map(_.asInstanceOf[AnyRef]) + val c1 = DefaultGenericVector.fromArray(utf8Binary, values) + val c2 = DefaultGenericVector.fromArray(utf8Lcase, values) + val c3 = DefaultGenericVector.fromArray(unicode, values) + val batch = new DefaultColumnarBatch(batchSize, schema, Array[ColumnVector](c1, c2, c3)) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commit = appendData(engine, tblPath, data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + verifyCommitResult(commit, expVersion = 1, expIsReadyForCheckpoint = false) + + // Read stats + val snapshot = Table.forPath(engine, tblPath).getLatestSnapshot(engine) + val scan = snapshot.getScanBuilder().build() + val scanFiles = scan.asInstanceOf[ScanImpl].getScanFiles(engine, true).toSeq + .flatMap(_.getRows.toSeq) + val statsJson = scanFiles.headOption.flatMap { row => + val add = row.getStruct(row.getSchema.indexOf("add")) + val idx = add.getSchema.indexOf("stats") + if (idx >= 0 && !add.isNullAt(idx)) Some(add.getString(idx)) else None + }.getOrElse(fail("Stats JSON not found")) + + // Default engine computes just non-collated stats; verify min/max values + val mapper = JsonUtils.mapper() + val statsNode = mapper.readTree(statsJson) + val minValues = statsNode.get("minValues") + val maxValues = statsNode.get("maxValues") + + // All columns: [b, A, B, a] -> min "A", max "b" + assert(minValues.get("c1").asText() == "A") + assert(maxValues.get("c1").asText() == "b") + + assert(minValues.get("c2").asText() == "A") + assert(maxValues.get("c2").asText() == "b") + + assert(minValues.get("c3").asText() == "A") + assert(maxValues.get("c3").asText() == "b") + } + } + /////////////////////////////////////////////////////////////////////////// // Create table and insert data tests (CTAS & INSERT) /////////////////////////////////////////////////////////////////////////// diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index 65152015e80..015f798e394 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -19,6 +19,7 @@ import java.math.{BigDecimal => JBigDecimal} import java.sql.Date import java.time.{Instant, OffsetDateTime} import java.time.temporal.ChronoUnit +import java.util import java.util.Optional import scala.collection.JavaConverters._ @@ -1735,6 +1736,420 @@ class ScanSuite extends AnyFunSuite with TestUtils } } + test("data skipping - predicates with SPARK.UTF8_BINARY on data column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + // Create three files with values on non-partitioned STRING columns (c1, c2) + // Files: ("a","x"), ("c","y"), ("e","z") + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder.build()).length + + if (createCheckpoint) { + // Create a checkpoint for the table + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val filterToFileNumber = Map( + new Predicate( + "<", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + ofString("d"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new Predicate( + "=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 1, + new Predicate( + "<=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + "<", + ofString("e"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY) -> 0, + new And( + new Predicate( + ">=", + col("c1"), + ofString("b"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("e"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate( + "=", + col("c1"), + ofString("x"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + ">=", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles, + new And( + new Predicate( + "<=", + ofString("b"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">=", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + ">=", + ofString("c"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new Or( + new Predicate( + "=", + ofString("a"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "=", + ofString("z"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate( + "<", + ofString("d"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new And( + new Predicate( + "<", + ofString("e"), + col("c1"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">", + ofString("y"), + col("c2"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 0) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("data skipping - collated predicates not or partially convertible to skipping filter") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + val filterToFileNumber = Map( + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("a"), + utf8Lcase) -> totalFiles, + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("z"), + unicode) -> totalFiles, + new In( + col("c1"), + java.util.Arrays.asList(ofString("a"), ofString("z")), + CollationIdentifier.SPARK_UTF8_BINARY) -> totalFiles, + new In( + col("c2"), + java.util.Arrays.asList(ofString("x"), ofString("zz")), + utf8Lcase) -> totalFiles, + new And( + new Predicate( + "<", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY), + new In( + col("c2"), + java.util.Arrays.asList(ofString("x"), ofString("zz")), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new Or( + new Predicate("STARTS_WITH", col("c1"), ofString("a"), utf8Lcase), + new In( + col("c2"), + util.Arrays.asList(ofString("x"), ofString("y")), + unicode)) -> totalFiles, + new And( + new Predicate( + "STARTS_WITH", + col("c1"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate("STARTS_WITH", col("c2"), ofString("x"), unicode)) -> totalFiles) + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("data skipping - evaluation fails with non default collation on data column") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x")).toDF("c1", "c2").write.format("delta") + .save(tempDir.getCanonicalPath) + Seq(("c", "y")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z")).toDF("c1", "c2").write.format("delta").mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + val failingPredicates = Seq( + new Predicate("<", col("c1"), ofString("a"), utf8Lcase), + new Predicate("=", ofString("d"), col("c1"), unicode), + new And( + new Predicate(">=", col("c1"), ofString("b"), utf8Lcase), + new Predicate("<=", col("c1"), ofString("e"), unicode)), + new Or( + new Predicate("<", col("c1"), ofString("b"), utf8Lcase), + new Predicate(">", col("c1"), ofString("a"), CollationIdentifier.SPARK_UTF8_BINARY)), + new And( + new Predicate(">=", col("c1"), ofString("a"), utf8Lcase), + new Predicate("<=", col("c1"), ofString("z"), unicode)), + new Predicate("=", col("c1"), ofString("a"), utf8Lcase)) + + failingPredicates.foreach { predicate => + val ex = intercept[KernelEngineException] { + collectScanFileRows(snapshot.getScanBuilder.withFilter(predicate).build()) + } + assert(ex.getMessage.contains("Unsupported collation")) + assert(ex.getMessage.contains(CollationIdentifier.SPARK_UTF8_BINARY.toString)) + assert(ex.getCause.isInstanceOf[UnsupportedOperationException]) + } + } + } + } + + test("partition and data skipping - combined pruning on partition and data columns") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x", "u")).toDF("p", "c1", "c2") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + Seq(("c", "y", "v")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z", "w")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + val totalFiles = collectScanFileRows(snapshot.getScanBuilder().build()).length + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val filterToFileNumber: Map[Predicate, Int] = Map( + new And( + new Predicate( + "<=", + col("c1"), + ofString("y"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + ">=", + col("p"), + ofString("b"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 1, + new And( + new Predicate( + "<=", + col("p"), + ofString("c"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 2, + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<", + col("c1"), + ofString("d"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> 0, + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)) -> totalFiles) + + checkSkipping(tempDir.getCanonicalPath, filterToFileNumber) + } + } + } + + test("partition and data skipping - evaluation fails with non default collation on " + + "combined filter") { + Seq(true, false).foreach { createCheckpoint => + withTempDir { tempDir => + Seq(("a", "x", "u")).toDF("p", "c1", "c2") + .write + .format("delta") + .partitionBy("p") + .save(tempDir.getCanonicalPath) + Seq(("c", "y", "v")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + Seq(("e", "z", "w")).toDF("p", "c1", "c2") + .write + .format("delta") + .mode("append") + .save(tempDir.getCanonicalPath) + + val snapshot = latestSnapshot(tempDir.getCanonicalPath) + + if (createCheckpoint) { + val version = latestSnapshot(tempDir.getCanonicalPath).getVersion + Table.forPath(defaultEngine, tempDir.getCanonicalPath).checkpoint(defaultEngine, version) + } + + val utf8Lcase = CollationIdentifier.fromString("SPARK.UTF8_LCASE") + val unicode = CollationIdentifier.fromString("ICU.UNICODE.75.1") + + val failingPredicates = Seq( + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + CollationIdentifier.SPARK_UTF8_BINARY), + new Predicate( + "<=", + col("c1"), + ofString("z"), + utf8Lcase)), + new And( + new Predicate( + ">=", + col("p"), + ofString("a"), + unicode), + new Predicate( + "<=", + col("c1"), + ofString("z"), + CollationIdentifier.SPARK_UTF8_BINARY)), + new And( + new Predicate( + "<", + col("p"), + ofString("z"), + utf8Lcase), + new Predicate( + ">", + col("c1"), + ofString("a"), + unicode))) + + failingPredicates.foreach { predicate => + val ex = intercept[KernelEngineException] { + collectScanFileRows(snapshot.getScanBuilder.withFilter(predicate).build()) + } + assert(ex.getMessage.contains("Unsupported collation")) + assert(ex.getMessage.contains(CollationIdentifier.SPARK_UTF8_BINARY.toString)) + assert(ex.getCause.isInstanceOf[UnsupportedOperationException]) + } + } + } + } + Seq( "spark-variant-checkpoint", "spark-variant-stable-feature-checkpoint", diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index c7b9a35fd93..b50851ce3c8 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -626,6 +626,25 @@ trait AbstractTestUtils DefaultGenericVector.fromArray(dataType, mapValues.map(getMapValue).toArray) } + /** + * Builds an ArrayType ColumnVector from a sequence of per-row element sequences. + */ + def buildArrayVector( + valuesPerRow: Seq[Seq[AnyRef]], + elementType: DataType, + containsNull: Boolean): ColumnVector = { + val arrayType = new ArrayType(elementType, containsNull) + val arrayValues: Array[ArrayValue] = valuesPerRow.map { elems => + if (elems == null) null + else new ArrayValue { + override def getSize: Int = elems.size + override def getElements: ColumnVector = + DefaultGenericVector.fromArray(elementType, elems.toArray) + } + }.toArray + DefaultGenericVector.fromArray(arrayType, arrayValues.asInstanceOf[Array[AnyRef]]) + } + /** * Utility method to generate a [[dataType]] column vector of given size. * The nullability of rows is determined by the [[testIsNullValue(dataType, rowId)]]. @@ -699,7 +718,7 @@ trait AbstractTestUtils case LongType.LONG => rowId % 25 == 0 case FloatType.FLOAT => rowId % 5 == 0 case DoubleType.DOUBLE => rowId % 10 == 0 - case StringType.STRING => rowId % 2 == 0 + case _: StringType => rowId % 2 == 0 case BinaryType.BINARY => rowId % 3 == 0 case DateType.DATE => rowId % 5 == 0 case TimestampType.TIMESTAMP => rowId % 3 == 0 @@ -720,7 +739,7 @@ trait AbstractTestUtils case LongType.LONG => rowId * 287623L / 91 case FloatType.FLOAT => rowId * 7651.2323f / 91 case DoubleType.DOUBLE => rowId * 23423.23d / 17 - case StringType.STRING => (rowId % 19).toString + case _: StringType => (rowId % 19).toString case BinaryType.BINARY => Array[Byte]((rowId % 21).toByte, (rowId % 7 - 1).toByte) case DateType.DATE => (rowId * 28234) % 2876 case TimestampType.TIMESTAMP => (rowId * 2342342L) % 23 @@ -801,7 +820,7 @@ trait AbstractTestUtils case LongType.LONG => sparktypes.DataTypes.LongType case FloatType.FLOAT => sparktypes.DataTypes.FloatType case DoubleType.DOUBLE => sparktypes.DataTypes.DoubleType - case StringType.STRING => sparktypes.DataTypes.StringType + case _: StringType => sparktypes.DataTypes.StringType case BinaryType.BINARY => sparktypes.DataTypes.BinaryType case DateType.DATE => sparktypes.DataTypes.DateType case TimestampType.TIMESTAMP => sparktypes.DataTypes.TimestampType