Skip to content

Commit 22b7bed

Browse files
authored
fix: re-enable tests skipped for Spark 4.1 (issue #4098) (#4253)
1 parent 9feb58c commit 22b7bed

8 files changed

Lines changed: 24 additions & 26 deletions

File tree

native/common/src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub enum SparkError {
7878
#[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
7979
DivideByZero,
8080

81-
#[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
81+
#[error("[REMAINDER_BY_ZERO] Remainder by zero. Use `try_mod` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
8282
RemainderByZero,
8383

8484
#[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")]

native/spark-expr/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,7 @@ pub(crate) fn decimal_sum_overflow_error(function_name: &str) -> SparkError {
132132
pub(crate) fn divide_by_zero_error() -> SparkError {
133133
SparkError::DivideByZero
134134
}
135+
136+
pub(crate) fn remainder_by_zero_error() -> SparkError {
137+
SparkError::RemainderByZero
138+
}

native/spark-expr/src/math_funcs/modulo_expr.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::{create_comet_physical_fun, IfExpr};
19-
use crate::{divide_by_zero_error, Cast, EvalMode, SparkCastOptions};
19+
use crate::{remainder_by_zero_error, Cast, EvalMode, SparkCastOptions};
2020
use arrow::compute::kernels::numeric::rem;
2121
use arrow::datatypes::*;
2222
use datafusion::common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
@@ -56,8 +56,8 @@ pub fn spark_modulo(args: &[ColumnarValue], fail_on_error: bool) -> Result<Colum
5656
match apply(lhs, rhs, rem) {
5757
Ok(result) => Ok(result),
5858
Err(e) if e.to_string().contains("Divide by zero") && fail_on_error => {
59-
// Return Spark-compliant divide by zero error.
60-
Err(divide_by_zero_error().into())
59+
// Return Spark-compliant remainder by zero error.
60+
Err(remainder_by_zero_error().into())
6161
}
6262
Err(e) => Err(e),
6363
}
@@ -245,7 +245,7 @@ mod tests {
245245
assert!(
246246
error
247247
.to_string()
248-
.contains("[DIVIDE_BY_ZERO] Division by zero"),
248+
.contains("[REMAINDER_BY_ZERO] Remainder by zero"),
249249
"Error message did not match. Actual message: {error}"
250250
);
251251
}

spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@ trait ShimSparkErrorConverter {
7878
Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
7979

8080
case "RemainderByZero" =>
81-
Some(
82-
new SparkException(
83-
errorClass = "REMAINDER_BY_ZERO",
84-
messageParameters = params.map { case (k, v) => (k, v.toString) },
85-
cause = null))
81+
Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
8682

8783
case "IntervalDividedByZero" =>
8884
Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context)))

spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,7 @@ trait ShimSparkErrorConverter {
7878
Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
7979

8080
case "RemainderByZero" =>
81-
Some(
82-
new SparkException(
83-
errorClass = "REMAINDER_BY_ZERO",
84-
messageParameters = params.map { case (k, v) => (k, v.toString) },
85-
cause = null))
81+
Some(QueryExecutionErrors.divideByZeroError(sqlCtx(context)))
8682

8783
case "IntervalDividedByZero" =>
8884
Some(QueryExecutionErrors.intervalDividedByZeroError(sqlCtx(context)))

spark/src/main/spark-4.x/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupporte
3030
import org.apache.spark.sql.types._
3131
import org.apache.spark.unsafe.types.UTF8String
3232

33+
import org.apache.comet.CometSparkSessionExtensions.isSpark41Plus
34+
3335
object ShimSparkErrorConverter {
3436
val ObjectLocationPattern: Regex = "Object at location (.+?) not found".r
3537
}
@@ -88,12 +90,15 @@ trait ShimSparkErrorConverter {
8890
Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull))
8991

9092
case "RemainderByZero" =>
91-
// SPARK 4.0 REMOVED remainderByZeroError so we use generic arithmetic exception
92-
Some(
93-
new SparkException(
94-
errorClass = "REMAINDER_BY_ZERO",
95-
messageParameters = params.map { case (k, v) => (k, v.toString) },
96-
cause = null))
93+
if (isSpark41Plus) {
94+
Some(
95+
new SparkException(
96+
errorClass = "REMAINDER_BY_ZERO",
97+
messageParameters = Map("config" -> "\"spark.sql.ansi.enabled\""),
98+
cause = null))
99+
} else {
100+
Some(QueryExecutionErrors.divideByZeroError(context.headOption.orNull))
101+
}
97102

98103
case "IntervalDividedByZero" =>
99104
Some(QueryExecutionErrors.intervalDividedByZeroError(context.headOption.orNull))

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.{col, monotonically_increasing_id}
3333
import org.apache.spark.sql.internal.SQLConf
3434
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType}
3535

36-
import org.apache.comet.CometSparkSessionExtensions.isSpark41Plus
3736
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3837
import org.apache.comet.rules.CometScanTypeChecker
3938
import org.apache.comet.serde.{Compatible, Incompatible}
@@ -525,7 +524,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
525524
}
526525

527526
test("cast FloatType to TimestampType") {
528-
assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
529527
representativeTimezones.foreach { tz =>
530528
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
531529
// Use useDFDiff to avoid collect() which fails on extreme timestamp values
@@ -591,7 +589,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
591589
}
592590

593591
test("cast DoubleType to TimestampType") {
594-
assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
595592
representativeTimezones.foreach { tz =>
596593
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
597594
// Use useDFDiff to avoid collect() which fails on extreme timestamp values
@@ -1568,7 +1565,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15681565
}
15691566

15701567
test("cast ArrayType to ArrayType") {
1571-
assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
15721568
val types = Seq(
15731569
BooleanType,
15741570
StringType,

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1982,7 +1982,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
19821982
}
19831983

19841984
test("remainder function") {
1985-
assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
19861985
def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
19871986
withSQLConf(
19881987
SQLConf.ANSI_ENABLED.key -> enabled.toString,
@@ -1992,6 +1991,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
19921991

19931992
def verifyResult(query: String): Unit = {
19941993
// Spark 4.1 introduced REMAINDER_BY_ZERO; older versions raise DIVIDE_BY_ZERO for `%`.
1994+
// Comet always raises REMAINDER_BY_ZERO natively but the JVM shim maps it to
1995+
// DIVIDE_BY_ZERO on Spark < 4.1 (where the REMAINDER_BY_ZERO error class does not exist).
19951996
val expectedError =
19961997
if (isSpark41Plus)
19971998
"[REMAINDER_BY_ZERO] Remainder by zero. Use `try_mod` to tolerate divisor being 0 and return NULL instead."

0 commit comments

Comments
 (0)