Skip to content

Commit d344ec0

Browse files
committed
feat: route date_format JVM fallback through codegen dispatcher
CometDateFormat keeps the native to_char path for UTC sessions with a format literal in the strftime-mappable whitelist, and now routes every other case through the Arrow-direct codegen dispatcher (CometScalaUDFCodegen) so that non-UTC sessions, non-literal formats, and formats outside the whitelist stay inside the Comet pipeline running Spark's own DateFormatClass.doGenCode. Refactor: extract the closure-serialize + JvmScalarUdf-proto emission from CometScalaUDF.convert into a reusable CometScalaUDF.emitJvmCodegenDispatch helper. Any serde that wants to fall back to a Spark built-in expression through the dispatcher can call it. Gated by COMET_SCALA_UDF_CODEGEN_ENABLED so the default remains a clean Spark fallback for those cases until the dispatcher graduates from experimental. Reasoning notes: - DateFormatClass already has a proper doGenCode (not CodegenFallback), NullIntolerant, and ResolveTimeZone stamps the timeZoneId on it during analysis. Closure-serializing the bound tree therefore reproduces Spark-identical behavior for every timezone. - The kernel cache key already encodes the literal format and timezone via the serialized expression bytes, so (format, tz) combinations get distinct cached kernels just like a bespoke (format, tz) -> formatter cache would. Saves an entire DateFormatUDF.scala class. Tests: - date_format - timestamp_ntz input: now runs checkSparkAnswerAndOperator for every timezone under the codegen flag instead of falling back for non-UTC. - Split each previous "falls back to Spark" Scala test into two: one asserting the codegen-on path stays in Comet, one asserting the codegen-off path falls back with the dispatcher flag as the reason. - date_format.sql now pins a non-UTC session timezone and enables the codegen flag at file scope; all queries are plain query and assert in-Comet execution.
1 parent 28fb854 commit d344ec0

4 files changed

Lines changed: 132 additions & 89 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.apache.comet.serde
2121

2222
import org.apache.spark.SparkEnv
23-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF}
23+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Expression, Literal, ScalaUDF}
2424
import org.apache.spark.sql.types.BinaryType
2525

2626
import org.apache.comet.CometConf
@@ -45,15 +45,35 @@ import org.apache.comet.udf.codegen.CometScalaUDFCodegen
4545
*
4646
* Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a
4747
* `ScalaUDF` fall back to Spark for the enclosing operator.
48+
*
49+
* [[emitJvmCodegenDispatch]] exposes the same closure-serialize + dispatcher-proto path to other
50+
* serdes that want to keep a built-in Spark expression inside the Comet pipeline when no native
51+
* lowering is viable. See [[CometDateFormat]] for an example.
4852
*/
4953
object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {
5054

51-
override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
55+
override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] =
56+
emitJvmCodegenDispatch(expr, inputs, binding)
57+
58+
/**
59+
* Bind `expr`, closure-serialize it, and emit a `JvmScalarUdf` proto routed through
60+
* [[CometScalaUDFCodegen]] so that native execution evaluates the expression inside the
61+
* Arrow-direct codegen dispatcher. The dispatcher will Janino-compile `expr.doGenCode` into a
62+
* batch kernel on first invocation per task.
63+
*
64+
* Returns `None` (with `withInfo` tagging the reason) when the dispatcher is disabled via
65+
* [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]] or when [[CometBatchKernelCodegen.canHandle]]
66+
* refuses the expression tree. Callers should treat `None` as a clean Spark-fallback signal.
67+
*/
68+
def emitJvmCodegenDispatch(
69+
expr: Expression,
70+
inputs: Seq[Attribute],
71+
binding: Boolean): Option[Expr] = {
5272
if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) {
5373
withInfo(
5474
expr,
55-
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " +
56-
"so the plan falls back to Spark")
75+
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; expression has no native " +
76+
"path so the plan falls back to Spark")
5777
return None
5878
}
5979

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 47 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
2727
import org.apache.spark.unsafe.types.UTF8String
2828

29+
import org.apache.comet.CometConf
2930
import org.apache.comet.CometSparkSessionExtensions.withInfo
3031
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3132
import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -593,17 +594,23 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
593594
}
594595

595596
/**
596-
* Converts Spark DateFormatClass expression to DataFusion's to_char function.
597+
* Converts Spark `DateFormatClass` to DataFusion's `to_char` when format and timezone are
598+
* mappable, otherwise routes the expression through the Arrow-direct codegen dispatcher so that
599+
* Spark's own `DateFormatClass.doGenCode` runs inside the Comet pipeline.
597600
*
598-
* Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This
599-
* implementation supports a whitelist of common format strings that can be reliably mapped
600-
* between the two systems.
601+
* Routing:
602+
* - format is a literal in `supportedFormats` AND timezone is UTC -> native `to_char`
603+
* - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression
604+
* `allowIncompatible` flag set -> native `to_char` (results may differ from Spark)
605+
* - all other cases -> JVM codegen dispatcher ([[CometScalaUDF.emitJvmCodegenDispatch]]), gated
606+
* by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator
607+
* falls back to Spark.
601608
*/
602609
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
603610

604611
/**
605612
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
606-
* are supported.
613+
* are supported by the native path.
607614
*/
608615
val supportedFormats: Map[String, String] = Map(
609616
// Full date formats
@@ -637,66 +644,50 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
637644
// ISO formats
638645
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")
639646

640-
override def getIncompatibleReasons(): Seq[String] = Seq(
641-
"Non-UTC timezones may produce different results than Spark")
647+
// Compatibility is decided inside `convert`: the native path covers a subset, and the codegen
648+
// dispatcher covers everything else when enabled. Plan-time tagging happens via `withInfo` on
649+
// the path that returns None.
650+
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()
642651

643-
override def getUnsupportedReasons(): Seq[String] = Seq(
644-
"Only the following formats are supported:" +
645-
supportedFormats.keys.toSeq.sorted
646-
.map(k => s"`$k`")
647-
.mkString("\n - ", "\n - ", ""))
648-
649-
override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
650-
// Check timezone - only UTC is fully compatible
651-
val timezone = expr.timeZoneId.getOrElse("UTC")
652-
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
653-
654-
expr.right match {
655-
case Literal(fmt: UTF8String, _) =>
656-
val format = fmt.toString
657-
if (supportedFormats.contains(format)) {
658-
if (isUtc) {
659-
Compatible()
660-
} else {
661-
Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results"))
662-
}
663-
} else {
664-
Unsupported(
665-
Some(
666-
s"Format '$format' is not supported. Supported formats: " +
667-
supportedFormats.keys.mkString(", ")))
668-
}
669-
case _ =>
670-
Unsupported(Some("Only literal format strings are supported"))
671-
}
672-
}
652+
override def getCompatibleNotes(): Seq[String] = Seq(
653+
"Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " +
654+
"sessions. Other format strings (including non-literal formats), as well as non-UTC " +
655+
"sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " +
656+
"codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " +
657+
"codegen dispatcher is disabled (default) the operator falls back to Spark in those " +
658+
"cases.")
673659

674660
override def convert(
675661
expr: DateFormatClass,
676662
inputs: Seq[Attribute],
677663
binding: Boolean): Option[ExprOuterClass.Expr] = {
678-
// Get the format string - must be a literal for us to map it
679-
val strftimeFormat = expr.right match {
680-
case Literal(fmt: UTF8String, _) =>
681-
supportedFormats.get(fmt.toString)
664+
val timezone = expr.timeZoneId.getOrElse("UTC")
665+
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
666+
667+
val nativeFormat: Option[String] = expr.right match {
668+
case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString)
682669
case _ => None
683670
}
684671

685-
strftimeFormat match {
686-
case Some(format) =>
687-
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
688-
val formatExpr = exprToProtoInternal(Literal(format), inputs, binding)
689-
690-
val optExpr = scalarFunctionExprToProtoWithReturnType(
691-
"to_char",
692-
StringType,
693-
false,
694-
childExpr,
695-
formatExpr)
696-
optExprWithInfo(optExpr, expr, expr.left, expr.right)
697-
case None =>
698-
withInfo(expr, expr.left, expr.right)
699-
None
672+
val canUseNative = nativeFormat.isDefined && {
673+
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
674+
}
675+
676+
if (canUseNative) {
677+
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
678+
val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding)
679+
val optExpr = scalarFunctionExprToProtoWithReturnType(
680+
"to_char",
681+
StringType,
682+
false,
683+
childExpr,
684+
formatExpr)
685+
optExprWithInfo(optExpr, expr, expr.left, expr.right)
686+
} else {
687+
// Hand the full `DateFormatClass` (with `timeZoneId` already stamped by `ResolveTimeZone`)
688+
// to the codegen dispatcher. It closure-serializes the bound tree, so non-UTC timezones
689+
// and non-whitelisted / non-literal format strings produce Spark-identical results.
690+
CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding)
700691
}
701692
}
702693
}

spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,27 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18+
-- Pin the session timezone so the test exercises the non-UTC path regardless of the JVM
19+
-- default. Enable the codegen dispatcher so non-UTC and non-whitelisted formats stay inside
20+
-- Comet via Spark's own DateFormatClass.doGenCode instead of falling back to Spark.
21+
-- Config: spark.sql.session.timeZone=America/Los_Angeles
22+
-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true
23+
1824
statement
1925
CREATE TABLE test_date_format(ts timestamp) USING parquet
2026

2127
statement
2228
INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL)
2329

24-
query expect_fallback(Non-UTC timezone)
30+
query
2531
SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format
2632

27-
query expect_fallback(Non-UTC timezone)
33+
query
2834
SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format
2935

30-
query expect_fallback(Non-UTC timezone)
36+
query
3137
SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format
3238

3339
-- literal arguments
34-
query expect_fallback(Non-UTC timezone)
40+
query
3541
SELECT date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd')

spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -214,26 +214,21 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
214214
}
215215

216216
test("date_format - timestamp_ntz input") {
217-
// TimestampNTZ is timezone-independent, so date_format should produce the same
218-
// formatted string regardless of session timezone. Comet currently only runs this
219-
// natively for UTC; for non-UTC it falls back to Spark. We verify correctness
220-
// (matching Spark's output) in all cases.
217+
// TimestampNTZ is timezone-independent, so date_format must produce the same string
218+
// regardless of session timezone. With the codegen dispatcher enabled, non-UTC sessions
219+
// stay in Comet by running Spark's own `DateFormatClass.doGenCode` via the dispatcher.
221220
val r = new Random(42)
222221
val ntzSchema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true)))
223222
val ntzDF = FuzzDataGenerator.generateDataFrame(r, spark, ntzSchema, 100, DataGenOptions())
224223
ntzDF.createOrReplaceTempView("ntz_tbl")
225224
val supportedFormats =
226225
CometDateFormat.supportedFormats.keys.toSeq.filterNot(_.contains("'"))
227-
for (tz <- crossTimezones) {
228-
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
229-
for (format <- supportedFormats) {
230-
if (tz == "UTC") {
226+
withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") {
227+
for (tz <- crossTimezones) {
228+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
229+
for (format <- supportedFormats) {
231230
checkSparkAnswerAndOperator(
232231
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
233-
} else {
234-
// Non-UTC falls back to Spark but should still produce correct results
235-
checkSparkAnswer(
236-
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
237232
}
238233
}
239234
}
@@ -476,45 +471,76 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
476471
}
477472
}
478473

479-
test("date_format unsupported format falls back to Spark") {
474+
test("date_format unsupported format routes via codegen dispatcher") {
480475
createTimestampTestData.createOrReplaceTempView("tbl")
481476

482-
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
483-
// Unsupported format string
477+
withSQLConf(
478+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
479+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") {
480+
checkSparkAnswerAndOperator(
481+
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0")
482+
}
483+
}
484+
485+
test("date_format unsupported format falls back when codegen dispatcher disabled") {
486+
createTimestampTestData.createOrReplaceTempView("tbl")
487+
488+
withSQLConf(
489+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
490+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
484491
checkSparkAnswerAndFallbackReason(
485492
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0",
486-
"Format 'yyyy-MM-dd EEEE' is not supported")
493+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key)
487494
}
488495
}
489496

490-
test("date_format with non-UTC timezone falls back to Spark") {
497+
test("date_format with non-UTC timezone routes via codegen dispatcher") {
491498
createTimestampTestData.createOrReplaceTempView("tbl")
492499

493500
val nonUtcTimezones =
494501
Seq("America/New_York", "America/Los_Angeles", "Europe/London", "Asia/Tokyo")
495502

496503
for (tz <- nonUtcTimezones) {
497-
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
498-
// Non-UTC timezones should fall back to Spark as Incompatible
504+
withSQLConf(
505+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz,
506+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true") {
507+
checkSparkAnswerAndOperator(
508+
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0")
509+
}
510+
}
511+
}
512+
513+
test("date_format with non-UTC timezone falls back when codegen dispatcher disabled") {
514+
createTimestampTestData.createOrReplaceTempView("tbl")
515+
516+
val nonUtcTimezones = Seq("America/New_York", "Europe/London")
517+
518+
for (tz <- nonUtcTimezones) {
519+
withSQLConf(
520+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz,
521+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") {
499522
checkSparkAnswerAndFallbackReason(
500523
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0",
501-
s"Non-UTC timezone '$tz' may produce different results")
524+
CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key)
502525
}
503526
}
504527
}
505528

506-
test("date_format with non-UTC timezone works when allowIncompatible is enabled") {
529+
test("date_format with non-UTC timezone takes native path when allowIncompatible is enabled") {
507530
createTimestampTestData.createOrReplaceTempView("tbl")
508531

509532
val nonUtcTimezones = Seq("America/New_York", "Europe/London", "Asia/Tokyo")
510533

511534
for (tz <- nonUtcTimezones) {
512535
withSQLConf(
513536
SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz,
514-
"spark.comet.expr.DateFormatClass.allowIncompatible" -> "true") {
515-
// With allowIncompatible enabled, Comet will execute the expression
516-
// Results may differ from Spark but should not throw errors
517-
checkSparkAnswer("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl order by c0")
537+
"spark.comet.expression.DateFormatClass.allowIncompatible" -> "true") {
538+
// Native to_char results may diverge from Spark for non-UTC timezones (the reason the
539+
// JVM UDF is the default), so we only check that execution stays inside Comet. ORDER BY
540+
// is omitted to keep the plan free of AQEShuffleRead.
541+
val df = sql("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl")
542+
df.collect()
543+
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
518544
}
519545
}
520546
}

0 commit comments

Comments
 (0)