Skip to content

Commit 5976fec

Browse files
authored
#80 Add enforceTypeOnNullTypeFields function to DataFrameImplicits (#81)
1 parent 8e2334f commit 5976fec

File tree

4 files changed

+269
-5
lines changed

4 files changed

+269
-5
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ _StructTypeImplicits_ provides implicit methods for working with StructType obje
332332
dataFrame.withColumnIfDoesNotExist(path)
333333
```
334334

335+
5. Casts all `NullType` fields of the DataFrame to their corresponding types in `targetSchema`.
336+
337+
```scala
338+
dataFrame.enforceTypeOnNullTypeFields(targetSchema)
339+
```
335340

336341

337342
### Spark Version Guard

spark-commons/src/main/scala/za/co/absa/spark/commons/implicits/DataFrameImplicits.scala

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616

1717
package za.co.absa.spark.commons.implicits
1818

19-
import java.io.ByteArrayOutputStream
20-
21-
import org.apache.spark.sql.types.StructType
19+
import org.apache.spark.sql.functions.{col, lit, struct}
20+
import org.apache.spark.sql.types.{ArrayType, NullType, StructType}
2221
import org.apache.spark.sql.{Column, DataFrame}
22+
import za.co.absa.spark.commons.adapters.TransformAdapter
2323
import za.co.absa.spark.commons.implicits.StructTypeImplicits.DataFrameSelector
2424

25+
import java.io.ByteArrayOutputStream
26+
2527
object DataFrameImplicits {
2628

27-
implicit class DataFrameEnhancements(val df: DataFrame) extends AnyVal {
29+
implicit class DataFrameEnhancements(val df: DataFrame) extends TransformAdapter {
2830

2931
private def gatherData(showFnc: () => Unit): String = {
3032
val outCapture = new ByteArrayOutputStream
@@ -83,6 +85,70 @@ object DataFrameImplicits {
8385
}
8486
}
8587

88+
/**
89+
* Casts NullType (aka VOID) fields to their target types as defined in `targetSchema`.
90+
*
91+
* Matching of fields is "by name", so order of fields in schema(s) doesn't matter.
92+
* Resulting DataFrame has the same fields order as original DataFrame.
93+
* All fields from original DataFrame are kept in resulting DataFrame, even those that are not in `targetSchema`.
94+
*
95+
* @param targetSchema definition of field types to which potential NullTypes will be casted to
96+
* @return DataFrame with fields of NullType casted to their type in `targetSchema`
97+
*/
98+
def enforceTypeOnNullTypeFields(targetSchema: StructType): DataFrame = {
99+
100+
def processArray(
101+
thisArrType: ArrayType, targetArrType: ArrayType, thisArrayColumn: Column
102+
): Column =
103+
(thisArrType.elementType, targetArrType.elementType) match {
104+
case (_: NullType, _: NullType) => thisArrayColumn
105+
case (_: NullType, _) =>
106+
transform(
107+
thisArrayColumn,
108+
_ => lit(null).cast(targetArrType.elementType)
109+
)
110+
case (thisNestedArrType: ArrayType, targetNestedArrType: ArrayType) =>
111+
transform(
112+
thisArrayColumn,
113+
processArray(thisNestedArrType, targetNestedArrType, _)
114+
)
115+
case (thisNestedStructType: StructType, targetNestedStructType: StructType) =>
116+
transform(
117+
thisArrayColumn,
118+
element => struct(processStruct(thisNestedStructType, targetNestedStructType, Some(element)): _*)
119+
)
120+
case _ => thisArrayColumn
121+
}
122+
123+
def processStruct(
124+
currentThisSchema: StructType, currentTargetSchema: StructType, parent: Option[Column]
125+
): List[Column] = {
126+
val currentTargetSchemaMap = currentTargetSchema.map(f => (f.name.toLowerCase, f)).toMap
127+
128+
currentThisSchema.foldRight(List.empty[Column])((field, acc) => {
129+
val currentColumn: Column = parent
130+
.map(_.getField(field.name))
131+
.getOrElse(col(field.name))
132+
val correspondingTargetType = currentTargetSchemaMap.get(field.name.toLowerCase).map(_.dataType)
133+
134+
val castedColumn = (field.dataType, correspondingTargetType) match {
135+
case (NullType, Some(NullType)) => currentColumn
136+
case (NullType, Some(targetType)) => currentColumn.cast(targetType)
137+
case (arrType: ArrayType, Some(targetArrType: ArrayType)) =>
138+
processArray(arrType, targetArrType, currentColumn)
139+
case (structType: StructType, Some(targetStructType: StructType)) =>
140+
struct(processStruct(structType, targetStructType, Some(currentColumn)): _*)
141+
case _ => currentColumn
142+
}
143+
castedColumn.as(field.name) :: acc
144+
})
145+
}
146+
147+
val thisSchema = df.schema
148+
val selector = processStruct(thisSchema, targetSchema, None)
149+
df.select(selector: _*)
150+
}
151+
86152
/**
87153
* Using utils selector aligns the utils of a DataFrame to the selector
88154
* for operations where utils order might be important (e.g. hashing the whole rows and using except)

spark-commons/src/test/scala/za/co/absa/spark/commons/implicits/DataFrameImplicitsTest.scala

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
package za.co.absa.spark.commons.implicits
1818

19-
import org.apache.spark.sql.functions.lit
19+
import org.apache.spark.sql.functions.{array, lit, struct}
20+
import org.apache.spark.sql.types._
2021
import org.apache.spark.sql.{AnalysisException, DataFrame}
2122
import org.scalatest.funsuite.AnyFunSuite
2223
import za.co.absa.spark.commons.test.SparkTestBase
@@ -260,6 +261,197 @@ class DataFrameImplicitsTest extends AnyFunSuite with SparkTestBase with JsonTes
260261
}
261262
}
262263

264+
test("cast NullTypes to corresponding types by enforceTypeOnNullTypeFields") {
  // Two plain rows (only `id`), decorated with NullType columns of various shapes plus
  // columns whose concrete types deliberately mismatch the target schema (to be ignored).
  val inputDf = spark.read.json(Seq(jsonF).toDS)
    .withColumn("nullShouldBeString", lit(null))
    .withColumn("nullShouldBeInteger", lit(null))
    .withColumn("nullShouldBeArrayOfIntegers", lit(null))
    .withColumn("nullShouldBeArrayOfArraysOfIntegers", lit(null))
    .withColumn("nullShouldBeArrayOfStructs", lit(null))
    .withColumn("nullShouldBeStruct", lit(null))
    .withColumn("shouldIgnoreNonNullTypeMismatch", lit("abc"))
    .withColumn("arrayOfNullShouldBeArrayOfIntegers", array(lit(null), lit(null)))
    .withColumn(
      "arrayOfArrayOfNullShouldBeArrayOfArrayOfStrings",
      array(array(lit(null), lit(null)))
    )
    .withColumn(
      "arrayOfStructs",
      array(
        struct(
          lit(null).as("nullShouldBeString"),
          lit(null).as("nullShouldBeInteger"),
          lit("abc").as("shouldIgnoreNonNullTypeMismatch"),
          array(lit(null), lit(null)).as("arrayOfNullShouldBeArrayOfIntegers")
        )
      )
    )
    .withColumn(
      "complexStruct",
      struct(
        lit(null).as("nullShouldBeString"),
        lit(null).as("nullShouldBeInteger"),
        lit("abc").as("shouldIgnoreNonNullTypeMismatch"),
        array(lit(1), lit(2), lit(3)).as("shouldIgnoreNonNullArrayTypeMismatch"),
        struct(
          lit(null).as("nullShouldBeString"),
          lit(null).as("nullShouldBeInteger"),
          lit("abc").as("shouldIgnoreNonNullTypeMismatch")
        ).as("nestedStruct"),
        array(lit(null), lit(null)).as("arrayOfNullShouldBeArrayOfIntegers")
      )
    )

  val targetSchema = StructType(
    Seq(
      StructField("id", LongType),
      // nullShouldBeInteger and nullShouldBeString are swapped compared to inputDf
      // to ensure enforceTypeOnNullTypeFields converts by name
      StructField("nullShouldBeInteger", IntegerType),
      StructField("nullShouldBeString", StringType),
      // checks case-insensitivity
      StructField("nullShouldBeArrayOfINTEGERS", ArrayType(IntegerType)),
      StructField("nullShouldBeArrayOfArraysOfIntegers", ArrayType(ArrayType(IntegerType))),
      StructField(
        "nullShouldBeArrayOfStructs",
        ArrayType(
          StructType(
            Seq(StructField("a", StringType), StructField("b", DecimalType(28, 8)))
          )
        )
      ),
      StructField(
        "nullShouldBeStruct",
        StructType(
          Seq(StructField("a", StringType), StructField("b", DecimalType(28, 8)))
        )
      ),
      StructField("shouldIgnoreNonNullTypeMismatch", IntegerType, false),
      StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false),
      StructField("arrayOfArrayOfNullShouldBeArrayOfArrayOfStrings", ArrayType(ArrayType(StringType), false), false),
      StructField(
        "arrayOfStructs",
        ArrayType(
          StructType(
            Seq(
              // nullShouldBeInteger and nullShouldBeString are swapped compared to inputDf
              // to ensure enforceTypeOnNullTypeFields converts by name
              StructField("nullShouldBeInteger", IntegerType),
              StructField("nullShouldBeString", StringType),
              StructField("shouldIgnoreNonNullTypeMismatch", IntegerType, false),
              StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false)
            )
          )
        ),
        false
      ),
      StructField(
        "complexStruct",
        StructType(
          Seq(
            // nullShouldBeInteger and nullShouldBeString are swapped compared to inputDf
            // to ensure enforceTypeOnNullTypeFields converts by name
            StructField("nullShouldBeInteger", IntegerType),
            StructField("nullShouldBeString", StringType),
            // checks case-insensitivity
            StructField("shouldIgnoreNonNullTypeMISMATCH", IntegerType, false),
            StructField("shouldIgnoreNonNullArrayTypeMismatch", ArrayType(StringType, false), false),
            StructField(
              "nestedStruct",
              StructType(
                Seq(
                  // nullShouldBeInteger and nullShouldBeString are swapped compared to inputDf
                  // to ensure enforceTypeOnNullTypeFields converts by name
                  StructField("nullShouldBeInteger", IntegerType),
                  // checks case-insensitivity
                  StructField("nullSHOULDBeString", StringType),
                  StructField("shouldIgnoreNonNullTypeMismatch", IntegerType, false)
                )
              ),
              false
            ),
            StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false)
          )
        ),
        false
      )
    )
  )

  val enforced = inputDf.enforceTypeOnNullTypeFields(targetSchema)

  // Enforcing types must not change the number of rows.
  assert(enforced.count() === inputDf.count())

  // Expected result: original field order kept, NullTypes cast per targetSchema (matched
  // case-insensitively), non-null type mismatches and unmatched fields left untouched.
  val expectedSchema = StructType(
    Seq(
      StructField("id", LongType),
      StructField("nullShouldBeString", StringType),
      StructField("nullShouldBeInteger", IntegerType),
      StructField("nullShouldBeArrayOfIntegers", ArrayType(IntegerType)),
      StructField("nullShouldBeArrayOfArraysOfIntegers", ArrayType(ArrayType(IntegerType))),
      StructField(
        "nullShouldBeArrayOfStructs",
        ArrayType(
          StructType(
            Seq(StructField("a", StringType), StructField("b", DecimalType(28, 8)))
          )
        )
      ),
      StructField(
        "nullShouldBeStruct",
        StructType(
          Seq(StructField("a", StringType), StructField("b", DecimalType(28, 8)))
        )
      ),
      StructField("shouldIgnoreNonNullTypeMismatch", StringType, false),
      StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false),
      StructField("arrayOfArrayOfNullShouldBeArrayOfArrayOfStrings", ArrayType(ArrayType(StringType), false), false),
      StructField(
        "arrayOfStructs",
        ArrayType(
          StructType(
            Seq(
              StructField("nullShouldBeString", StringType),
              StructField("nullShouldBeInteger", IntegerType),
              StructField("shouldIgnoreNonNullTypeMismatch", StringType, false),
              StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false)
            )
          ),
          false
        ),
        false
      ),
      StructField(
        "complexStruct",
        StructType(
          Seq(
            StructField("nullShouldBeString", StringType),
            StructField("nullShouldBeInteger", IntegerType),
            StructField("shouldIgnoreNonNullTypeMismatch", StringType, false),
            StructField("shouldIgnoreNonNullArrayTypeMismatch", ArrayType(IntegerType, false), false),
            StructField(
              "nestedStruct",
              StructType(
                Seq(
                  StructField("nullShouldBeString", StringType),
                  StructField("nullShouldBeInteger", IntegerType),
                  StructField("shouldIgnoreNonNullTypeMismatch", StringType, false)
                )
              ),
              false
            ),
            StructField("arrayOfNullShouldBeArrayOfIntegers", ArrayType(IntegerType), false)
          )
        ),
        false
      )
    )
  )

  assert(enforced.schema === expectedSchema)
}
454+
263455
test("Check that cacheIfNotCachedYet caches the data") {
264456
//Verify check test procedure
265457
val dfA = spark.read.json(Seq(jsonA).toDS)

spark-commons/src/test/scala/za/co/absa/spark/commons/implicits/JsonTestData.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ trait JsonTestData {
2525
protected val jsonC = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": "2"}, "alfa": "1"} }]"""
2626
protected val jsonD = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1} }]"""
2727
protected val jsonE = """[{"legs":[{"legid":100,"conditions":[{"amount":100,"checks":[{"checkNums":["1","2","3b","4","5c","6"]}]}]}],"id":1, "key" : {"beta": {"beta2": 2}, "alfa": 1}, "extra" : "a"}]"""
28+
// Minimal two-row dataset (only an `id` field) used by the enforceTypeOnNullTypeFields test
protected val jsonF = """[{"id":1}, {"id":2}]"""
2829

2930
protected val sample =
3031
"""{"id":1,"legs":[{"legid":100,"conditions":[{"checks":[{"checkNums":["1","2","3b","4","5c","6"]}],"amount":100}]}]}""" ::

0 commit comments

Comments
 (0)