diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java index 131eea0f5de..7316d9f8132 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Transaction.java @@ -197,7 +197,7 @@ static CloseableIterator transformLogicalData( } ColumnarBatch data = filteredBatch.getData(); - if (!data.getSchema().equals(tableSchema)) { + if (!data.getSchema().equivalentIgnoreCollations(tableSchema)) { throw dataSchemaMismatch(tablePath, tableSchema, data.getSchema()); } for (String partitionColName : partitionColNames) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java b/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java index 97f8288c1b6..b3f9c669b78 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/statistics/DataFileStatistics.java @@ -380,7 +380,8 @@ private void validateLiteralType(StructField field, Literal literal) { // Variant stats in JSON are Z85 encoded strings, all other stats should match the field type DataType expectedLiteralType = field.getDataType() instanceof VariantType ? 
StringType.STRING : field.getDataType(); - if (literal.getDataType() == null || !literal.getDataType().equals(expectedLiteralType)) { + if (literal.getDataType() == null + || !literal.getDataType().equivalentIgnoreCollations(expectedLiteralType)) { throw DeltaErrors.statsTypeMismatch( field.getName(), expectedLiteralType, literal.getDataType()); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java index 173a85f4882..e5ca63ecab9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/ArrayType.java @@ -55,6 +55,26 @@ public boolean equivalent(DataType dataType) { && ((ArrayType) dataType).getElementType().equivalent(getElementType()); } + /** + * Checks whether the given data type is the same as this array type, ignoring collations. + * + * @param dataType the data type to compare against this array type; may be null + * @return true if the types are the same apart from collations, false otherwise + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + ArrayType arrayType = (ArrayType) dataType; + return (elementField == null && arrayType.elementField == null) + || (elementField != null + && elementField.equivalentIgnoreCollations(arrayType.elementField)); + } + @Override public boolean isNested() { return true; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java index 6c10fd3fbb8..553a2894d83 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/DataType.java @@ -27,7 +27,7 @@ public abstract class DataType { /** - * Are the data types same? The metadata or column names could be different. + * Are the data types same? The metadata, collations or column names could be different. 
* * @param dataType * @return @@ -36,6 +36,16 @@ public boolean equivalent(DataType dataType) { return equals(dataType); } + /** + * Are the data types the same? The collations could be different. + * + * @param dataType the data type to compare against this one + * @return whether the types match ignoring collations; defaults to {@code equals(dataType)} + */ + public boolean equivalentIgnoreCollations(DataType dataType) { + return equals(dataType); + } + /** * Returns true iff this data is a nested data type (it logically parameterized by other types). * diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java index b4d24fe51dc..3dbd45a8345 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/MapType.java @@ -70,6 +70,27 @@ public boolean equivalent(DataType dataType) { && ((MapType) dataType).isValueContainsNull() == isValueContainsNull(); } + /** + * Checks whether the given data type is the same as this map type, ignoring collations. + * + * @param dataType the data type to compare against this map type; may be null + * @return true if the key and value fields match apart from collations, false otherwise + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + MapType mapType = (MapType) dataType; + return ((keyField == null && mapType.keyField == null) + || (keyField != null && keyField.equivalentIgnoreCollations(mapType.keyField))) + && ((valueField == null && mapType.valueField == null) + || (valueField != null && valueField.equivalentIgnoreCollations(mapType.valueField))); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java index d9dc911a1da..3415ad51e89 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StringType.java @@ -52,6 +52,27 @@ public 
CollationIdentifier getCollationIdentifier() { return collationIdentifier; } + /** + * Are the data types the same? The metadata, collations or column names could be different. + * + * @param dataType the data type to compare against this string type + * @return true if the given data type is a {@code StringType}, whatever its collation + */ + public boolean equivalent(DataType dataType) { + return dataType instanceof StringType; + } + + /** + * Are the data types the same? The collations could be different. + * + * @param dataType the data type to compare against this string type + * @return true if the given data type is a {@code StringType}, whatever its collation + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + return dataType instanceof StringType; + } + @Override public boolean equals(Object o) { if (!(o instanceof StringType)) { diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java index b64d166b83e..d55a00119c4 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructField.java @@ -215,6 +215,30 @@ public boolean equals(Object o) { && Objects.equals(typeChanges, that.typeChanges); } + /** @return whether the struct fields are equal, ignoring collations and collation metadata */ + public boolean equivalentIgnoreCollations(StructField other) { + if (this == other) { + return true; + } + if (other == null) { + return false; + } + // Compare metadata while ignoring collation metadata differences + FieldMetadata metadataWithoutCollations = + new FieldMetadata.Builder().fromMetadata(metadata).remove(COLLATIONS_METADATA_KEY).build(); + FieldMetadata otherMetadataWithoutCollations = + new FieldMetadata.Builder() + .fromMetadata(other.metadata) + .remove(COLLATIONS_METADATA_KEY) + .build(); + + return nullable == other.nullable + && name.equals(other.name) + && dataType.equivalentIgnoreCollations(other.dataType) + && metadataWithoutCollations.equals(otherMetadataWithoutCollations) + && Objects.equals(typeChanges, other.typeChanges); + } + @Override public int hashCode() { return Objects.hash(name, dataType, nullable, 
metadata, typeChanges); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java index 3a5a34f6be3..c524ae0e8f9 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/StructType.java @@ -175,6 +175,28 @@ public boolean equivalent(DataType dataType) { .allMatch(result -> result); } + /** + * Checks whether the given data type is the same as this struct type, ignoring collations. + * + * @param dataType the data type to compare against this struct type; may be null + * @return true if the field names and every field match apart from collations; false otherwise + */ + @Override + public boolean equivalentIgnoreCollations(DataType dataType) { + if (this == dataType) { + return true; + } + if (dataType == null || getClass() != dataType.getClass()) { + return false; + } + StructType structType = (StructType) dataType; + return this.length() == structType.length() + && fieldNames.equals(structType.fieldNames) + && IntStream.range(0, this.length()) + .mapToObj(i -> this.at(i).equivalentIgnoreCollations(structType.at(i))) + .allMatch(result -> result); + } + @Override public boolean isNested() { return true; diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala new file mode 100644 index 00000000000..6c2f5e1e1dc --- /dev/null +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/types/DataTypeSuite.scala @@ -0,0 +1,148 @@ +/* + * Copyright (2025) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.types + +import org.scalatest.funsuite.AnyFunSuite + +class DataTypeSuite extends AnyFunSuite { + + test("test equivalentIgnoreCollations") { + val utf8LcaseString = new StringType("SPARK.UTF8_LCASE") + val unicodeString = new StringType("ICU.UNICODE") + + val testCases = Seq( + (StringType.STRING, StringType.STRING, true), + (StringType.STRING, utf8LcaseString, true), + (IntegerType.INTEGER, StringType.STRING, false), + (utf8LcaseString, unicodeString, true), + ( + new ArrayType(StringType.STRING, true), + new ArrayType(utf8LcaseString, true), + true), + ( + new ArrayType(unicodeString, false), + new ArrayType(StringType.STRING, false), + true), + ( + new ArrayType(StringType.STRING, true), + new ArrayType(utf8LcaseString, false), + false), + ( + new MapType(StringType.STRING, IntegerType.INTEGER, false), + new MapType(utf8LcaseString, IntegerType.INTEGER, false), + true), + ( + new MapType(unicodeString, IntegerType.INTEGER, false), + new MapType(utf8LcaseString, IntegerType.INTEGER, false), + true), + ( + new MapType(unicodeString, IntegerType.INTEGER, false), + new MapType(utf8LcaseString, IntegerType.INTEGER, true), + false), + ( + new StructType() + .add("name", StringType.STRING) + .add("age", IntegerType.INTEGER), + new StructType() + .add("name", utf8LcaseString) + .add("age", IntegerType.INTEGER), + true), + ( + new StructType() + .add("name", StringType.STRING) + .add("details", new StructType().add("address", StringType.STRING)), + new StructType() + .add("name", unicodeString) + .add("details", new StructType().add("address", utf8LcaseString)), + true), + ( + new StructType() + .add("c1", new ArrayType(unicodeString, true)) + .add("c2", new MapType(StringType.STRING, utf8LcaseString, false)), + new StructType() + .add("c1", new ArrayType(StringType.STRING, true)) + .add("c2", new MapType(utf8LcaseString, 
unicodeString, false)), + true), + ( + new StructType() + .add("c1", new ArrayType(unicodeString, false)) + .add("c2", new MapType(StringType.STRING, utf8LcaseString, false)), + new StructType() + .add("c1", new ArrayType(StringType.STRING, true)) + .add("c2", new MapType(utf8LcaseString, unicodeString, false)), + false), + ( + new StructType() + .add("c1", new ArrayType(IntegerType.INTEGER, true)) + .add("c2", new MapType(StringType.STRING, utf8LcaseString, false)), + new StructType() + .add("c1", new ArrayType(StringType.STRING, true)) + .add("c2", new MapType(utf8LcaseString, unicodeString, false)), + false), + ( + new ArrayType( + new StructType().add("c1", new MapType(utf8LcaseString, StringType.STRING, true), true), + true), + new ArrayType( + new StructType().add("c1", new MapType(StringType.STRING, utf8LcaseString, true), true), + true), + true), + ( + new ArrayType( + new StructType().add("c1", new MapType(utf8LcaseString, StringType.STRING, true), true), + true), + new ArrayType( + new StructType().add("c2", new MapType(StringType.STRING, utf8LcaseString, true), true), + true), + false), + ( + new ArrayType( + new StructType().add("c1", new MapType(utf8LcaseString, StringType.STRING, true), false), + true), + new ArrayType( + new StructType().add("c1", new MapType(StringType.STRING, utf8LcaseString, true), true), + true), + false), + ( + new MapType( + new StructType().add("c1", utf8LcaseString), + new ArrayType(utf8LcaseString, false), + true), + new MapType( + new StructType().add("c1", StringType.STRING), + new ArrayType(utf8LcaseString, false), + true), + true), + ( + new MapType( + new StructType().add("c1", utf8LcaseString), + new ArrayType(utf8LcaseString, false), + false), + new MapType( + new StructType().add("c1", StringType.STRING), + new ArrayType(utf8LcaseString, false), + true), + false), + ( + new MapType(new StructType().add("c1", utf8LcaseString), StringType.STRING, false), + new MapType(new StructType().add("c1", StringType.STRING), 
utf8LcaseString, true), + false)) + + testCases.foreach { case (dt1, dt2, expected) => + assert(dt1.equivalentIgnoreCollations(dt2) == expected) + } + } +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala index 4f419a39258..093df59e440 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala @@ -25,7 +25,10 @@ import scala.collection.immutable.Seq import io.delta.golden.GoldenTableUtils.goldenTablePath import io.delta.kernel._ import io.delta.kernel.Operation.{CREATE_TABLE, MANUAL_UPDATE, WRITE} -import io.delta.kernel.data.{ColumnarBatch, FilteredColumnarBatch, Row} +import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row} +import io.delta.kernel.defaults.internal.data.DefaultColumnarBatch +import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector +import io.delta.kernel.defaults.internal.data.vector.DefaultStructVector import io.delta.kernel.defaults.internal.parquet.ParquetSuiteBase import io.delta.kernel.defaults.utils.{AbstractWriteUtils, TestRow, WriteUtils} import io.delta.kernel.engine.Engine @@ -35,11 +38,13 @@ import io.delta.kernel.expressions.Literal._ import io.delta.kernel.internal.{ScanImpl, SnapshotImpl, TableConfig} import io.delta.kernel.internal.checkpoints.CheckpointerSuite.selectSingleElement import io.delta.kernel.internal.table.SnapshotBuilderImpl +import io.delta.kernel.internal.types.DataTypeJsonSerDe import io.delta.kernel.internal.util.{Clock, JsonUtils} import io.delta.kernel.internal.util.SchemaUtils.casePreservingPartitionColNames import io.delta.kernel.transaction.DataLayoutSpec import io.delta.kernel.types._ import io.delta.kernel.types.ByteType.BYTE +import io.delta.kernel.types.CollationIdentifier 
import io.delta.kernel.types.DateType.DATE import io.delta.kernel.types.DecimalType import io.delta.kernel.types.DoubleType.DOUBLE @@ -413,6 +418,580 @@ abstract class AbstractDeltaTableWritesSuite extends AnyFunSuite with AbstractWr } } + /////////////////////////////////////////////////////////////////////////// + // Collation write tests + /////////////////////////////////////////////////////////////////////////// + + test("insert into table - simple collated string column") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val serbianWithVersion = new StringType("ICU.SR_CYRL_SRB.75.1") + val serbianWithoutVersion = new StringType("ICU.SR_CYRL_SRB") + + val commonSchema = new StructType() + .add("c1", IntegerType.INTEGER) + .add("c2", StringType.STRING) + .add("c3", STRING) + .add("c4", utf8Lcase) + .add("c5", unicode) + val schemaWithVersion = commonSchema.add("c6", serbianWithVersion) + val schemaWithoutVersion = commonSchema.add("c6", serbianWithoutVersion) + + // First append + val data1 = + generateData(schemaWithVersion, Seq.empty, Map.empty, batchSize = 10, numBatches = 1) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schemaWithVersion, + data = Seq(Map.empty[String, Literal] -> data1)) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + // we use schemaWithoutVersion to verify since the version info is not stored in the + // schema serialization + verifyWrittenContent(tblPath, schemaWithoutVersion, data1.flatMap(_.toTestRows)) + + // Second append + val data2 = + generateData(schemaWithVersion, Seq.empty, Map.empty, batchSize = 5, numBatches = 1) + + val commitResult1 = appendData( + engine, + tblPath, + data = Seq(Map.empty[String, Literal] -> data2)) + verifyCommitResult(commitResult1, expVersion = 1, expIsReadyForCheckpoint = false) + 
verifyCommitInfo(tblPath, version = 1, partitionCols = null) + verifyWrittenContent( + tblPath, + schemaWithoutVersion, + (data1 ++ data2).flatMap(_.toTestRows)) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schemaWithoutVersion) + } + } + + test("insert into table - complex types with collated strings in nested/array/map") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val unicodeWithVersion = new StringType("ICU.UNICODE.74") + + val commonNested = new StructType() + .add("s1", utf8Lcase) + .add("n", INTEGER) + + val nestedWithVersion = commonNested.add("s2", unicodeWithVersion) + val nestedWithoutVersion = commonNested.add("s2", unicode) + + val schemaWithVersion = new StructType() + .add("nested", nestedWithVersion) + .add("arr", new ArrayType(utf8Lcase, true)) + .add("map", new MapType(utf8Lcase, unicode, true)) + val schemaWithoutVersion = new StructType() + .add("nested", nestedWithoutVersion) + .add("arr", new ArrayType(utf8Lcase, true)) + .add("map", new MapType(utf8Lcase, unicode, true)) + + val batchSize = 4 + + def buildBatch(seed: String): FilteredColumnarBatch = { + val nestedVectors = Array[ColumnVector]( + testColumnVector(batchSize, utf8Lcase), + testColumnVector(batchSize, INTEGER), + testColumnVector(batchSize, unicode)) + val nestedVector = new DefaultStructVector( + batchSize, + nestedWithVersion, + Optional.empty(), + nestedVectors) + + val arrValues: Seq[Seq[AnyRef]] = (0 until batchSize).map { i => + Seq(s"${seed}t$i", s"${seed}x$i").map(_.asInstanceOf[AnyRef]) + } + val arrVector = buildArrayVector(arrValues, utf8Lcase, containsNull = true) + + val mapType = new MapType(utf8Lcase, unicode, true) + val mapValues: Seq[Map[AnyRef, AnyRef]] = (0 until batchSize).map { i => + Map[AnyRef, AnyRef](s"${seed}k$i" -> s"${seed}v$i") + } + val 
mapVector = buildMapVector(mapValues, mapType) + + val vectors = Array[ColumnVector](nestedVector, arrVector, mapVector) + val batch = new DefaultColumnarBatch(batchSize, schemaWithVersion, vectors) + new FilteredColumnarBatch(batch, Optional.empty()) + } + + val fcb1 = buildBatch("a-") + val fcb2 = buildBatch("b-") + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schemaWithVersion, + data = Seq(Map.empty[String, Literal] -> Seq(fcb1))) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + + val commitResult1 = appendData( + engine, + tblPath, + data = Seq(Map.empty[String, Literal] -> Seq(fcb2))) + + verifyCommitResult(commitResult1, expVersion = 1, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 1, partitionCols = null) + + val expectedRows = Seq(fcb1, fcb2).flatMap(_.toTestRows) + verifyWrittenContent(tblPath, schemaWithVersion, expectedRows) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schemaWithoutVersion) + } + } + + test("insert into table - nested struct with collated string field") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val nested = new StructType() + .add("c21", utf8Lcase) + .add("c22", IntegerType.INTEGER) + .add("c23", unicode) + .add("c24", STRING) + val schema = new StructType() + .add("c1", LongType.LONG) + .add("c2", nested) + + val data = generateData(schema, Seq.empty, Map.empty, batchSize = 8, numBatches = 2) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schema, + data = Seq(Map.empty[String, Literal] -> data)) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + verifyWrittenContent(tblPath, 
schema, data.flatMap(_.toTestRows)) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString()) + assert(parsed === schema) + } + } + + test("insert into table - complex types with collated strings in nested fields") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + + val nested = new StructType() + .add("c1", unicode) + .add("c2", IntegerType.INTEGER) + .add("c3", STRING) + + val schema = new StructType() + .add("c1", IntegerType.INTEGER) + .add("c2", nested) + .add("c3", new ArrayType(utf8Lcase, true)) + .add("c4", new MapType(utf8Lcase, unicode, true)) + + // Build vectors + val batchSize = 5 + val c1Vector = testColumnVector(batchSize, IntegerType.INTEGER) + + val nestedVectors = Array[ColumnVector]( + testColumnVector(batchSize, unicode), + testColumnVector(batchSize, IntegerType.INTEGER), + testColumnVector(batchSize, STRING)) + val c2Vector = new DefaultStructVector(batchSize, nested, Optional.empty(), nestedVectors) + + val c3Values: Seq[Seq[AnyRef]] = (0 until batchSize).map { i => + Seq(s"t$i", s"x$i").map(_.asInstanceOf[AnyRef]) + } + val c3Vector = buildArrayVector(c3Values, utf8Lcase, containsNull = true) + + val c4Type = new MapType(utf8Lcase, unicode, true) + val c4Values: Seq[Map[AnyRef, AnyRef]] = (0 until batchSize).map { i => + Map[AnyRef, AnyRef](s"k$i" -> s"v$i") + } + val c4Vector = buildMapVector(c4Values, c4Type) + + val vectors = Array[ColumnVector](c1Vector, c2Vector, c3Vector, c4Vector) + val batch = new DefaultColumnarBatch(batchSize, schema, vectors) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schema, + data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + + val metadata = getMetadata(engine, tblPath) + val parsed = 
DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString) + assert(parsed === schema) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 0) + val expectedRows = Seq(fcb).flatMap(_.toTestRows) + verifyWrittenContent(tblPath, schema, expectedRows) + } + } + + test("insert into partitioned table - collated string partition columns") { + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE.75.1") + val serbian = new StringType("ICU.SR_CYRL_SRB") + Seq( + // (p1BatchType, p2BatchType, vBatchType) + (utf8Lcase, unicode, serbian), + (serbian, serbian, utf8Lcase), + (utf8Lcase, serbian, unicode), + (unicode, serbian, STRING), + (STRING, serbian, STRING), + (STRING, STRING, STRING), + (utf8Lcase, STRING, utf8Lcase)).foreach { case (p1BatchType, p2BatchType, vBatchType) => + withTempDirAndEngine { (tblPath, engine) => + val schema = new StructType() + .add("id", INTEGER) + .add("p1", utf8Lcase) // partition column + .add("p2", unicode) // partition column + .add("v", serbian) + + val schemaWithoutVersion = new StructType() + .add("id", INTEGER) + .add("p1", utf8Lcase) + .add("p2", new StringType("ICU.UNICODE")) + .add("v", serbian) + + val dataSchema = new StructType() + .add("id", INTEGER) + .add("p1", p1BatchType) + .add("p2", p2BatchType) + .add("v", vBatchType) + + val partCols = Seq("p1", "p2") + + val v0Part = Map("p1" -> ofString("a"), "p2" -> ofString("alpha")) + val v0Data = generateData(dataSchema, partCols, v0Part, batchSize = 8, numBatches = 1) + + val v1Part = Map("p1" -> ofString("B"), "p2" -> ofString("beta")) + val v1Data = generateData(dataSchema, partCols, v1Part, batchSize = 5, numBatches = 1) + + val commitResult0 = appendData( + engine, + tblPath, + isNewTable = true, + schema, + partCols, + data = Seq(v0Part -> v0Data, v1Part -> v1Data)) + + verifyCommitResult(commitResult0, expVersion = 0, expIsReadyForCheckpoint = false) + // 
Expect partition columns in the same case as in the schema + verifyCommitInfo(tblPath, version = 0, partitionCols = partCols) + + val expectedRows0 = v0Data.flatMap(_.toTestRows) ++ v1Data.flatMap(_.toTestRows) + verifyWrittenContent(tblPath, schema, expectedRows0) + + val v2Part = Map("p1" -> ofString("c"), "p2" -> ofString("gamma")) + val v2Data = generateData(dataSchema, partCols, v2Part, batchSize = 4, numBatches = 3) + + val commitResult1 = appendData( + engine, + tblPath, + data = Seq(v2Part -> v2Data)) + + verifyCommitResult(commitResult1, expVersion = 1, expIsReadyForCheckpoint = false) + // For subsequent commits, partitionBy is not recorded in commit info + verifyCommitInfo(tblPath, version = 1, partitionCols = null) + + val expectedRows1 = expectedRows0 ++ v2Data.flatMap(_.toTestRows) + verifyWrittenContent(tblPath, schema, expectedRows1) + + val metadata = getMetadata(engine, tblPath) + val parsed = DataTypeJsonSerDe.deserializeStructType(metadata.getSchemaString) + assert(parsed === schemaWithoutVersion) + } + } + } + + test("stats: default engine writes binary stats for collated string columns") { + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val serbian = new StringType("ICU.SR_CYRL_SRB.74") + Seq( + (STRING, utf8Lcase, unicode), + (serbian, serbian, serbian), + (STRING, serbian, unicode), + (STRING, STRING, STRING)).foreach { case (c1DataType, c2DataType, c3DataType) => + withTempDirAndEngine { (tblPath, engine) => + val schema = new StructType() + .add("c1", STRING) + .add("c2", utf8Lcase) + .add("c3", unicode) + + val txn = getCreateTxn(engine, tblPath, schema) + commitTransaction(txn, engine, emptyIterable()) + + val batchSize = 4 + val values = Array("b", "A", "B", "a").map(_.asInstanceOf[AnyRef]) + val c1 = DefaultGenericVector.fromArray(c1DataType, values) + val c2 = DefaultGenericVector.fromArray(c2DataType, values) + val c3 = DefaultGenericVector.fromArray(c3DataType, values) + val batch 
= new DefaultColumnarBatch(batchSize, schema, Array[ColumnVector](c1, c2, c3)) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commit = appendData(engine, tblPath, data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + verifyCommitResult(commit, expVersion = 1, expIsReadyForCheckpoint = false) + + // Read stats JSON + val snapshot = Table.forPath(engine, tblPath).getLatestSnapshot(engine) + val scan = snapshot.getScanBuilder().build() + val scanFiles = scan.asInstanceOf[ScanImpl].getScanFiles(engine, true).toSeq + .flatMap(_.getRows.toSeq) + val statsJson = scanFiles.headOption.flatMap { row => + val add = row.getStruct(row.getSchema.indexOf("add")) + val idx = add.getSchema.indexOf("stats") + if (idx >= 0 && !add.isNullAt(idx)) Some(add.getString(idx)) else None + }.getOrElse(fail("Stats JSON not found")) + + // Default engine computes just non-collated stats; verify min/max values + val mapper = JsonUtils.mapper() + val statsNode = mapper.readTree(statsJson) + val minValues = statsNode.get("minValues") + val maxValues = statsNode.get("maxValues") + + // All columns: [b, A, B, a] -> min "A", max "b" + assert(minValues.get("c1").asText() == "A") + assert(maxValues.get("c1").asText() == "b") + + assert(minValues.get("c2").asText() == "A") + assert(maxValues.get("c2").asText() == "b") + + assert(minValues.get("c3").asText() == "A") + assert(maxValues.get("c3").asText() == "b") + } + } + } + + test("stats: collated non-partition column in partitioned table") { + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val unicode = new StringType("ICU.UNICODE") + val serbian = new StringType("ICU.SR_CYRL_SRB.74") + Seq( + (utf8Lcase, unicode), + (serbian, serbian), + (utf8Lcase, serbian), + (STRING, STRING), + (STRING, utf8Lcase), + (unicode, STRING)).foreach { case (pBatchType, dBatchType) => + withTempDirAndEngine { (tblPath, engine) => + val schema = new StructType() + .add("p", utf8Lcase) // partition column + .add("c", serbian) // non-partition, 
collated + + val txn = getCreateTxn(engine, tblPath, schema, partCols = Seq("p")) + commitTransaction(txn, engine, emptyIterable()) + + // Commit 1: p = "north", c values [b, A, B, a] + val batchSize1 = 4 + val cValues1 = Array("b", "A", "B", "a").map(_.asInstanceOf[AnyRef]) + val pValues1 = Array.fill[AnyRef](batchSize1)("north") + val pVec1 = DefaultGenericVector.fromArray(pBatchType, pValues1) + val cVec1 = DefaultGenericVector.fromArray(dBatchType, cValues1) + val batch1 = new DefaultColumnarBatch(batchSize1, schema, Array[ColumnVector](pVec1, cVec1)) + val fcb1 = new FilteredColumnarBatch(batch1, Optional.empty()) + + val commit1 = + appendData(engine, tblPath, data = Seq(Map("p" -> ofString("north")) -> Seq(fcb1))) + verifyCommitResult(commit1, expVersion = 1, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 1, partitionCols = null) + + // Commit 2: p = "south", c values [d, C] + val batchSize2 = 2 + val cValues2 = Array("d", "C").map(_.asInstanceOf[AnyRef]) + val pValues2 = Array.fill[AnyRef](batchSize2)("south") + val pVec2 = DefaultGenericVector.fromArray(pBatchType, pValues2) + val cVec2 = DefaultGenericVector.fromArray(dBatchType, cValues2) + val batch2 = new DefaultColumnarBatch(batchSize2, schema, Array[ColumnVector](pVec2, cVec2)) + val fcb2 = new FilteredColumnarBatch(batch2, Optional.empty()) + + val commit2 = + appendData(engine, tblPath, data = Seq(Map("p" -> ofString("south")) -> Seq(fcb2))) + verifyCommitResult(commit2, expVersion = 2, expIsReadyForCheckpoint = false) + verifyCommitInfo(tblPath, version = 2, partitionCols = null) + + // Read stats JSON + val snapshot = Table.forPath(engine, tblPath).getLatestSnapshot(engine) + val scan = snapshot.getScanBuilder.build() + val scanFiles = scan.asInstanceOf[ScanImpl].getScanFiles(engine, true).toSeq + .flatMap(_.getRows.toSeq) + + val mapper = JsonUtils.mapper() + assert(scanFiles.nonEmpty) + scanFiles.foreach { row => + val add = 
row.getStruct(row.getSchema.indexOf("add")) + val path = add.getString(add.getSchema.indexOf("path")) + val statsIdx = add.getSchema.indexOf("stats") + assert(statsIdx >= 0 && !add.isNullAt(statsIdx)) + val statsJson = add.getString(statsIdx) + val statsNode = mapper.readTree(statsJson) + val minValues = statsNode.get("minValues") + val maxValues = statsNode.get("maxValues") + + val minC = minValues.get("c").asText() + val maxC = maxValues.get("c").asText() + + if (path.contains("p=north")) { + // For [b, A, B, a] -> min "A", max "b" + assert(minC == "A") + assert(maxC == "b") + } else if (path.contains("p=south")) { + // For [d, C] -> min "C", max "d" + assert(minC == "C") + assert(maxC == "d") + } else { + fail(s"Unexpected partition: $path") + } + } + } + } + } + + test("stats: collect min/max for collated nested struct fields") { + withTempDirAndEngine { (tblPath, engine) => + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val nested = new StructType() + .add("s1", utf8Lcase) + .add("i1", INTEGER) + val schema = new StructType() + .add("nested", nested) + + val txn = getCreateTxn(engine, tblPath, schema) + commitTransaction(txn, engine, emptyIterable()) + + val batchSize = 4 + val s1Values = Array("b", "A", "B", "a").map(_.asInstanceOf[AnyRef]) + val i1Values = Array[java.lang.Integer](3, -1, 10, 5) + val s1 = DefaultGenericVector.fromArray(utf8Lcase, s1Values) + val i1 = DefaultGenericVector.fromArray(INTEGER, i1Values.asInstanceOf[Array[AnyRef]]) + val nestedVector = new DefaultStructVector( + batchSize, + nested, + Optional.empty(), + Array[ColumnVector](s1, i1)) + val batch = new DefaultColumnarBatch( + batchSize, + schema, + Array[ColumnVector](nestedVector)) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commit = appendData(engine, tblPath, data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + verifyCommitResult(commit, expVersion = 1, expIsReadyForCheckpoint = false) + + // Read the stats JSON of the first add entry returned by the scan + val snapshot = 
Table.forPath(engine, tblPath).getLatestSnapshot(engine) + val scan = snapshot.getScanBuilder().build() + val scanFiles = scan.asInstanceOf[ScanImpl].getScanFiles(engine, true).toSeq + .flatMap(_.getRows.toSeq) + val statsJson = scanFiles.headOption.flatMap { row => + val add = row.getStruct(row.getSchema.indexOf("add")) + val idx = add.getSchema.indexOf("stats") + if (idx >= 0 && !add.isNullAt(idx)) Some(add.getString(idx)) else None + }.getOrElse(fail("Stats JSON not found")) + + val mapper = JsonUtils.mapper() + val statsNode = mapper.readTree(statsJson) + val minValues = statsNode.get("minValues") + val maxValues = statsNode.get("maxValues") + + val minNested = minValues.get("nested") + val maxNested = maxValues.get("nested") + + // For s1: [b, A, B, a] -> min "A", max "b" (the code-point min and max) + assert(minNested.get("s1").asText() == "A") + assert(maxNested.get("s1").asText() == "b") + + // For i1: [3, -1, 10, 5] -> min -1, max 10 + assert(minNested.get("i1").asInt() == -1) + assert(maxNested.get("i1").asInt() == 10) + } + } + + test("stats: arrays and maps produce no stats; collated string field stats present") { + withTempDirAndEngine { (tblPath, engine) => + val unicode = new StringType("ICU.UNICODE") + val utf8Lcase = new StringType("SPARK.UTF8_LCASE") + val schema = new StructType() + .add("name", unicode) + .add("arr", new ArrayType(utf8Lcase, true)) + .add("map", new MapType(unicode, INTEGER, true)) + + val txn = getCreateTxn(engine, tblPath, schema) + commitTransaction(txn, engine, emptyIterable()) + + val batchSize = 4 + val nameValues = Array("b", "A", "B", "a").map(_.asInstanceOf[AnyRef]) + val nameVec = DefaultGenericVector.fromArray(unicode, nameValues) + + val arrValues: Seq[Seq[AnyRef]] = (0 until batchSize).map { i => + Seq(s"x$i").map(_.asInstanceOf[AnyRef]) + } + val arrVec = buildArrayVector(arrValues, utf8Lcase, containsNull = true) + + val mapType = new MapType(unicode, INTEGER, true) + val mapValues: Seq[Map[AnyRef, AnyRef]] = (0 until batchSize).map { i => + 
Map[AnyRef, AnyRef](s"k$i" -> java.lang.Integer.valueOf(i)) + } + val mapVec = buildMapVector(mapValues, mapType) + + val batch = new DefaultColumnarBatch( + batchSize, + schema, + Array[ColumnVector](nameVec, arrVec, mapVec)) + val fcb = new FilteredColumnarBatch(batch, Optional.empty()) + + val commit = appendData(engine, tblPath, data = Seq(Map.empty[String, Literal] -> Seq(fcb))) + verifyCommitResult(commit, expVersion = 1, expIsReadyForCheckpoint = false) + + // Read the stats JSON of the first add entry returned by the scan + val snapshot = Table.forPath(engine, tblPath).getLatestSnapshot(engine) + val scan = snapshot.getScanBuilder().build() + val scanFiles = scan.asInstanceOf[ScanImpl].getScanFiles(engine, true).toSeq + .flatMap(_.getRows.toSeq) + val statsJson = scanFiles.headOption.flatMap { row => + val add = row.getStruct(row.getSchema.indexOf("add")) + val idx = add.getSchema.indexOf("stats") + if (idx >= 0 && !add.isNullAt(idx)) Some(add.getString(idx)) else None + }.getOrElse(fail("Stats JSON not found")) + + val mapper = JsonUtils.mapper() + val statsNode = mapper.readTree(statsJson) + val minValues = statsNode.get("minValues") + val maxValues = statsNode.get("maxValues") + + // String column stats are present ("A"/"b" = code-point min/max of [b, A, B, a]) + assert(minValues.get("name").asText() == "A") + assert(maxValues.get("name").asText() == "b") + + // Array/Map columns should not have stats + assert(!minValues.has("arr")) + assert(!maxValues.has("arr")) + assert(!minValues.has("map")) + assert(!maxValues.has("map")) + } + } + /////////////////////////////////////////////////////////////////////////// // Create table and insert data tests (CTAS & INSERT) /////////////////////////////////////////////////////////////////////////// diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index 8d3601a8fbd..ecaa98d4424 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala 
+++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -626,6 +626,25 @@ trait AbstractTestUtils DefaultGenericVector.fromArray(dataType, mapValues.map(getMapValue).toArray) } + /** + * Builds an ArrayType ColumnVector with one array per entry of valuesPerRow (null stays null). + */ + def buildArrayVector( + valuesPerRow: Seq[Seq[AnyRef]], + elementType: DataType, + containsNull: Boolean): ColumnVector = { + val arrayType = new ArrayType(elementType, containsNull) + val rows: Array[ArrayValue] = valuesPerRow.map { + case null => null + case rowElems => new ArrayValue { + override def getSize: Int = rowElems.size + override def getElements: ColumnVector = + DefaultGenericVector.fromArray(elementType, rowElems.toArray) + } + }.toArray + DefaultGenericVector.fromArray(arrayType, rows.asInstanceOf[Array[AnyRef]]) + } + + /** * Utility method to generate a [[dataType]] column vector of given size. * The nullability of rows is determined by the [[testIsNullValue(dataType, rowId)]]. 
@@ -699,7 +718,7 @@ trait AbstractTestUtils case LongType.LONG => rowId % 25 == 0 case FloatType.FLOAT => rowId % 5 == 0 case DoubleType.DOUBLE => rowId % 10 == 0 - case StringType.STRING => rowId % 2 == 0 + case _: StringType => rowId % 2 == 0 case BinaryType.BINARY => rowId % 3 == 0 case DateType.DATE => rowId % 5 == 0 case TimestampType.TIMESTAMP => rowId % 3 == 0 @@ -720,7 +739,7 @@ trait AbstractTestUtils case LongType.LONG => rowId * 287623L / 91 case FloatType.FLOAT => rowId * 7651.2323f / 91 case DoubleType.DOUBLE => rowId * 23423.23d / 17 - case StringType.STRING => (rowId % 19).toString + case _: StringType => (rowId % 19).toString case BinaryType.BINARY => Array[Byte]((rowId % 21).toByte, (rowId % 7 - 1).toByte) case DateType.DATE => (rowId * 28234) % 2876 case TimestampType.TIMESTAMP => (rowId * 2342342L) % 23 @@ -801,7 +820,7 @@ trait AbstractTestUtils case LongType.LONG => sparktypes.DataTypes.LongType case FloatType.FLOAT => sparktypes.DataTypes.FloatType case DoubleType.DOUBLE => sparktypes.DataTypes.DoubleType - case StringType.STRING => sparktypes.DataTypes.StringType + case _: StringType => sparktypes.DataTypes.StringType case BinaryType.BINARY => sparktypes.DataTypes.BinaryType case DateType.DATE => sparktypes.DataTypes.DateType case TimestampType.TIMESTAMP => sparktypes.DataTypes.TimestampType diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/WriteUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/WriteUtils.scala index 2aef2a1d2dc..2814b244105 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/WriteUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/WriteUtils.scala @@ -508,7 +508,7 @@ trait AbstractWriteUtils extends TestUtils with TransactionBuilderSupport { expSchema: StructType, expData: Seq[TestRow]): Unit = { val actSchema = tableSchema(path) - assert(actSchema === expSchema) + 
assert(actSchema.equivalentIgnoreCollations(expSchema)) // verify data using Kernel reader checkTable(path, expData)