Skip to content
This repository was archived by the owner on Oct 23, 2024. It is now read-only.

Commit bd94cf7

Browse files
HyukjinKwon authored and cloud-fan committed
[SPARK-31227][SQL] Non-nullable null type in complex types should not coerce to nullable type
### What changes were proposed in this pull request? This PR targets for non-nullable null type not to coerce to nullable type in complex types. Non-nullable fields in struct, elements in an array and entries in map can mean empty array, struct and map. They are empty so it does not need to force the nullability when we find common types. This PR also reverts and supersedes apache@d7b97a1 ### Why are the changes needed? To make type coercion coherent and consistent. Currently, we correctly keep the nullability even between non-nullable fields: ```scala import org.apache.spark.sql.types._ import org.apache.spark.sql.functions._ spark.range(1).select(array(lit(1)).cast(ArrayType(IntegerType, false))).printSchema() spark.range(1).select(array(lit(1)).cast(ArrayType(DoubleType, false))).printSchema() ``` ```scala spark.range(1).selectExpr("concat(array(1), array(1)) as arr").printSchema() ``` ### Does this PR introduce any user-facing change? Yes. ```scala import org.apache.spark.sql.types._ import org.apache.spark.sql.functions._ spark.range(1).select(array().cast(ArrayType(IntegerType, false))).printSchema() ``` ```scala spark.range(1).selectExpr("concat(array(), array(1)) as arr").printSchema() ``` **Before:** ``` org.apache.spark.sql.AnalysisException: cannot resolve 'array()' due to data type mismatch: cannot cast array<null> to array<int>;; 'Project [cast(array() as array<int>) AS array()#68] +- Range (0, 1, step=1, splits=Some(12)) at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:149) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$$nestedInanonfun$checkAnalysis$1$2.applyOrElse(CheckAnalysis.scala:140) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:333) at 
org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:333) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237) ``` ``` root |-- arr: array (nullable = false) | |-- element: integer (containsNull = true) ``` **After:** ``` root |-- array(): array (nullable = false) | |-- element: integer (containsNull = false) ``` ``` root |-- arr: array (nullable = false) | |-- element: integer (containsNull = false) ``` ### How was this patch tested? Unittests were added and manually tested. Closes apache#27991 from HyukjinKwon/SPARK-31227. Authored-by: HyukjinKwon <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit 3bd10ce) Signed-off-by: Wenchen Fan <[email protected]>
1 parent 557623b commit bd94cf7

File tree

5 files changed

+73
-29
lines changed

5 files changed

+73
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ object TypeCoercion {
160160
}
161161
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
162162
findTypeFunc(kt1, kt2)
163-
.filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) }
163+
.filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) }
164164
.flatMap { kt =>
165165
findTypeFunc(vt1, vt2).map { vt =>
166166
MapType(kt, vt, valueContainsNull1 || valueContainsNull2 ||

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

+9-8
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ object Cast {
7777
resolvableNullability(fn || forceNullable(fromType, toType), tn)
7878

7979
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
80-
canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) &&
80+
canCast(fromKey, toKey) &&
81+
(!forceNullable(fromKey, toKey)) &&
8182
canCast(fromValue, toValue) &&
8283
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
8384

@@ -97,11 +98,6 @@ object Cast {
9798
case _ => false
9899
}
99100

100-
def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = {
101-
// If the original map key type is NullType, it's OK as the map must be empty.
102-
fromType == NullType || !forceNullable(fromType, toType)
103-
}
104-
105101
/**
106102
* Return true if we need to use the `timeZone` information casting `from` type to `to` type.
107103
* The patterns matched reflect the current implementation in the Cast node.
@@ -210,8 +206,13 @@ object Cast {
210206
case _ => false // overflow
211207
}
212208

209+
/**
210+
* Returns `true` if casting non-nullable values from `from` type to `to` type
211+
* may return null. Note that the caller side should take care of input nullability
212+
* first and only call this method if the input is not nullable.
213+
*/
213214
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
214-
case (NullType, _) => true
215+
case (NullType, _) => false // empty array or map case
215216
case (_, _) if from == to => false
216217

217218
case (StringType, BinaryType) => false
@@ -269,7 +270,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
269270
}
270271
}
271272

272-
override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
273+
override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType)
273274

274275
protected def ansiEnabled: Boolean
275276

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

+34-20
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types._
3030
import org.apache.spark.unsafe.types.CalendarInterval
3131

3232
class TypeCoercionSuite extends AnalysisTest {
33+
import TypeCoercionSuite._
3334

3435
// scalastyle:off line.size.limit
3536
// The following table shows all implicit data type conversions that are not visible to the user.
@@ -99,22 +100,6 @@ class TypeCoercionSuite extends AnalysisTest {
99100
case _ => Literal.create(null, dataType)
100101
}
101102

102-
val integralTypes: Seq[DataType] =
103-
Seq(ByteType, ShortType, IntegerType, LongType)
104-
val fractionalTypes: Seq[DataType] =
105-
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
106-
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
107-
val atomicTypes: Seq[DataType] =
108-
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
109-
val complexTypes: Seq[DataType] =
110-
Seq(ArrayType(IntegerType),
111-
ArrayType(StringType),
112-
MapType(StringType, StringType),
113-
new StructType().add("a1", StringType),
114-
new StructType().add("a1", StringType).add("a2", IntegerType))
115-
val allTypes: Seq[DataType] =
116-
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)
117-
118103
// Check whether the type `checkedType` can be cast to all the types in `castableTypes`,
119104
// but cannot be cast to the other types in `allTypes`.
120105
private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = {
@@ -497,6 +482,23 @@ class TypeCoercionSuite extends AnalysisTest {
497482
.add("null", IntegerType, nullable = false),
498483
Some(new StructType()
499484
.add("null", IntegerType, nullable = true)))
485+
486+
widenTest(
487+
ArrayType(NullType, containsNull = false),
488+
ArrayType(IntegerType, containsNull = false),
489+
Some(ArrayType(IntegerType, containsNull = false)))
490+
491+
widenTest(MapType(NullType, NullType, false),
492+
MapType(IntegerType, StringType, false),
493+
Some(MapType(IntegerType, StringType, false)))
494+
495+
widenTest(
496+
new StructType()
497+
.add("null", NullType, nullable = false),
498+
new StructType()
499+
.add("null", IntegerType, nullable = false),
500+
Some(new StructType()
501+
.add("null", IntegerType, nullable = false)))
500502
}
501503

502504
test("wider common type for decimal and array") {
@@ -728,8 +730,6 @@ class TypeCoercionSuite extends AnalysisTest {
728730
}
729731

730732
test("cast NullType for expressions that implement ExpectsInputTypes") {
731-
import TypeCoercionSuite._
732-
733733
ruleTest(TypeCoercion.ImplicitTypeCasts,
734734
AnyTypeUnaryExpression(Literal.create(null, NullType)),
735735
AnyTypeUnaryExpression(Literal.create(null, NullType)))
@@ -740,8 +740,6 @@ class TypeCoercionSuite extends AnalysisTest {
740740
}
741741

742742
test("cast NullType for binary operators") {
743-
import TypeCoercionSuite._
744-
745743
ruleTest(TypeCoercion.ImplicitTypeCasts,
746744
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
747745
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
@@ -1548,6 +1546,22 @@ class TypeCoercionSuite extends AnalysisTest {
15481546

15491547
object TypeCoercionSuite {
15501548

1549+
val integralTypes: Seq[DataType] =
1550+
Seq(ByteType, ShortType, IntegerType, LongType)
1551+
val fractionalTypes: Seq[DataType] =
1552+
Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2))
1553+
val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes
1554+
val atomicTypes: Seq[DataType] =
1555+
numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType)
1556+
val complexTypes: Seq[DataType] =
1557+
Seq(ArrayType(IntegerType),
1558+
ArrayType(StringType),
1559+
MapType(StringType, StringType),
1560+
new StructType().add("a1", StringType),
1561+
new StructType().add("a1", StringType).add("a2", IntegerType))
1562+
val allTypes: Seq[DataType] =
1563+
atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType)
1564+
15511565
case class AnyTypeUnaryExpression(child: Expression)
15521566
extends UnaryExpression with ExpectsInputTypes with Unevaluable {
15531567
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

+22
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite
2626
import org.apache.spark.sql.Row
2727
import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence
29+
import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite
2930
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
3031
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
3132
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -413,6 +414,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
413414
assert(ret.resolved)
414415
checkEvaluation(ret, Seq(null, true, false, null))
415416
}
417+
418+
{
419+
val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
420+
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
421+
assert(ret.resolved)
422+
checkEvaluation(ret, Seq.empty)
423+
}
424+
416425
{
417426
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
418427
assert(ret.resolved === false)
@@ -1158,6 +1167,19 @@ class CastSuite extends CastSuiteBase {
11581167
StructType(StructField("a", IntegerType, true) :: Nil)))
11591168
}
11601169

1170+
test("SPARK-31227: Non-nullable null type should not coerce to nullable type") {
1171+
TypeCoercionSuite.allTypes.foreach { t =>
1172+
assert(Cast.canCast(ArrayType(NullType, false), ArrayType(t, false)))
1173+
1174+
assert(Cast.canCast(
1175+
MapType(NullType, NullType, false), MapType(t, t, false)))
1176+
1177+
assert(Cast.canCast(
1178+
StructType(StructField("a", NullType, false) :: Nil),
1179+
StructType(StructField("a", t, false) :: Nil)))
1180+
}
1181+
}
1182+
11611183
test("Cast should output null for invalid strings when ANSI is not enabled.") {
11621184
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
11631185
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

+7
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
15331533
assert(e.getMessage.contains("string, binary or array"))
15341534
}
15351535

1536+
test("SPARK-31227: Non-nullable null type should not coerce to nullable type in concat") {
1537+
val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr")
1538+
val expected = spark.range(1).selectExpr("array(1) as arr")
1539+
checkAnswer(actual, expected)
1540+
assert(actual.schema === expected.schema)
1541+
}
1542+
15361543
test("flatten function") {
15371544
// Test cases with a primitive type
15381545
val intDF = Seq(

0 commit comments

Comments (0)