Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ case class CometExecRule(session: SparkSession)

private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]

private def producesArrowOutput(plan: SparkPlan): Boolean = plan match {
case _: CometNativeExec => true
case u: CometUnionExec => u.children.forall(producesArrowOutput)
case c: CometCoalesceExec => producesArrowOutput(c.child)
case _ => false
}

// spotless:off

/**
Expand Down Expand Up @@ -670,16 +677,17 @@ case class CometExecRule(session: SparkSession)
private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
if (isOperatorEnabled(serde, op)) {
// For operators that require native children (like writes), check if all data-producing
// children are CometNativeExec. This prevents runtime failures when the native operator
// expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
// Operators with requiresNativeChildren (like the native parquet writer) consume Arrow
// batches from the JNI plan. Only CometNativeExec and pass-through sinks that forward
// such batches unchanged (CometUnionExec, CometCoalesceExec) are safe; other CometExec
// subclasses (CometLocalTableScanExec, CometCollectLimitExec, CometTakeOrderedAndProjectExec)
// produce row-format ColumnarBatches and would crash the native operator at runtime.
if (serde.requiresNativeChildren && op.children.nonEmpty) {
// Get the actual data-producing children (unwrap WriteFilesExec if present)
val dataProducingChildren = op.children.flatMap {
case writeFiles: WriteFilesExec => Seq(writeFiles.child)
case other => Seq(other)
}
if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
if (!dataProducingChildren.forall(producesArrowOutput)) {
withInfo(op, "Cannot perform native operation because input is not in Arrow format")
return None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,176 @@ class CometParquetWriterSuite extends CometTestBase {
}
}

// Test for issue #3429: CTAS with UNION fails in Spark 4.x with native writer
test("parquet write with union - CTAS style") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Create a DataFrame using UNION - simulating CTAS with UNION pattern
val df1 = spark.range(1, 5).toDF("id")
val df2 = spark.range(10, 15).toDF("id")
val unionDf = df1.union(df2)

// Write using parquet - this is similar to CTAS
val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)

// Verify the write completed and data is correct
val result = spark.read.parquet(outputPath)
assert(result.count() == 9, "Expected 9 rows (4 + 5)")

// Verify native write was used
assertHasCometNativeWriteExec(plan)
}
}
}

// Corner case: UNION with multiple (3+) DataFrames
test("parquet write with multiple unions") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

val df1 = spark.range(1, 4).toDF("id")
val df2 = spark.range(10, 13).toDF("id")
val df3 = spark.range(20, 23).toDF("id")
val df4 = spark.range(30, 33).toDF("id")
val unionDf = df1.union(df2).union(df3).union(df4)

val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)

val result = spark.read.parquet(outputPath)
assert(result.count() == 12, "Expected 12 rows (3 + 3 + 3 + 3)")

assertHasCometNativeWriteExec(plan)
}
}
}

// Corner case: UNION followed by coalesce
test("parquet write with union and coalesce") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

val df1 = spark.range(1, 50).toDF("id")
val df2 = spark.range(100, 149).toDF("id")
val unionDf = df1.union(df2).coalesce(2)

val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)

val result = spark.read.parquet(outputPath)
assert(result.count() == 98, "Expected 98 rows (49 + 49)")

assertHasCometNativeWriteExec(plan)
}
}
}

// Corner case: UNION with filter
test("parquet write with union and filter") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

val df1 = spark.range(1, 10).toDF("id")
val df2 = spark.range(20, 30).toDF("id")
val unionDf = df1.union(df2).filter("id % 2 = 0")

val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)

val result = spark.read.parquet(outputPath)
// Even numbers: 2,4,6,8 from df1, 20,22,24,26,28 from df2 = 9 rows
assert(result.count() == 9, "Expected 9 even rows")

assertHasCometNativeWriteExec(plan)
}
}
}

// Corner case: UNION with complex types (struct)
test("parquet write with union of structs") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Use parquet files as source so Comet can convert the scans to native operators.
// SQL literals produce RDDScanExec(OneRowRelation) which Comet cannot convert,
// causing the native writer to not engage.
withTempPath { srcDir =>
val src1 = new File(srcDir, "src1.parquet").getAbsolutePath
val src2 = new File(srcDir, "src2.parquet").getAbsolutePath
Seq((1, ("Alice", 30))).toDF("id", "person").write.parquet(src1)
Seq((2, ("Bob", 25))).toDF("id", "person").write.parquet(src2)

val df1 = spark.read.parquet(src1)
val df2 = spark.read.parquet(src2)
val df = df1.union(df2)

val plan = captureWritePlan(path => df.write.parquet(path), outputPath)

val result = spark.read.parquet(outputPath)
assert(result.count() == 2)

assertHasCometNativeWriteExec(plan)
}
}
}
}

// Corner case: Nested UNION (UNION inside subquery)
test("parquet write with nested union in SQL") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath

withSQLConf(
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
CometConf.COMET_EXEC_ENABLED.key -> "true") {

// Use parquet files as source instead of SQL literals to ensure Comet
// can convert the scans to native operators.
withTempPath { srcDir =>
val src1 = new File(srcDir, "src1.parquet").getAbsolutePath
val src2 = new File(srcDir, "src2.parquet").getAbsolutePath
Seq(1, 2).toDF("id").write.parquet(src1)
Seq(3, 4).toDF("id").write.parquet(src2)

val inner1 = spark.read.parquet(src1)
val inner2 = spark.read.parquet(src2)
val df = inner1.union(inner2)

val plan = captureWritePlan(path => df.write.parquet(path), outputPath)

val result = spark.read.parquet(outputPath)
assert(result.count() == 4)

assertHasCometNativeWriteExec(plan)
}
}
}
}

test("parquet write with map type") {
withTempPath { dir =>
val outputPath = new File(dir, "output.parquet").getAbsolutePath
Expand Down Expand Up @@ -542,4 +712,32 @@ class CometParquetWriterSuite extends CometTestBase {
rows
}

test("native writer rejects non-Arrow CometExec children (regression for #3524)") {
withTempDir { dir =>
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
CometConf.COMET_OPERATOR_DATA_WRITING_COMMAND_ALLOW_INCOMPAT.key -> "true",
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {

val out = new File(dir, "literal_write").getAbsolutePath
val df = Seq((1, "a"), (2, "b")).toDF("id", "v")

val plan = captureWritePlan(p => df.write.parquet(p), out)

val hasNativeWrite = plan.exists {
case _: CometNativeWriteExec => true
case d: DataWritingCommandExec =>
d.child.exists(_.isInstanceOf[CometNativeWriteExec])
case _ => false
}
assert(
!hasNativeWrite,
s"CometNativeWriteExec must NOT wrap a CometLocalTableScanExec child:\n${plan.treeString}")

assert(spark.read.parquet(out).count() == 2)
}
}
}

}