Skip to content

Commit 0511026

Browse files
committed
fix: fix error reporting for string to timestamp cast
1 parent f49e4b6 commit 0511026

8 files changed

Lines changed: 220 additions & 59 deletions

File tree

native/spark-expr/src/conversion_funcs/string.rs

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,29 @@ macro_rules! cast_utf8_to_timestamp {
3737
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{
3838
let len = $array.len();
3939
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
40+
let mut cast_err: Option<SparkError> = None;
4041
for i in 0..len {
4142
if $array.is_null(i) {
4243
cast_array.append_null()
43-
} else if let Ok(Some(cast_value)) =
44-
$cast_method($array.value(i).trim(), $eval_mode, $tz)
45-
{
46-
cast_array.append_value(cast_value);
4744
} else {
48-
cast_array.append_null()
45+
match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
46+
Ok(Some(cast_value)) => cast_array.append_value(cast_value),
47+
Ok(None) => cast_array.append_null(),
48+
Err(e) => {
49+
if $eval_mode == EvalMode::Ansi {
50+
cast_err = Some(e);
51+
break;
52+
}
53+
cast_array.append_null()
54+
}
55+
}
4956
}
5057
}
51-
let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
52-
result
58+
if let Some(e) = cast_err {
59+
Err(e)
60+
} else {
61+
Ok(Arc::new(cast_array.finish()) as ArrayRef)
62+
}
5363
}};
5464
}
5565

@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
668678
let tz = &timezone::Tz::from_str(timezone_str).unwrap();
669679

670680
let cast_array: ArrayRef = match to_type {
671-
DataType::Timestamp(_, _) => {
672-
cast_utf8_to_timestamp!(
673-
string_array,
674-
eval_mode,
675-
TimestampMicrosecondType,
676-
timestamp_parser,
677-
tz
678-
)
679-
}
681+
DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
682+
string_array,
683+
eval_mode,
684+
TimestampMicrosecondType,
685+
timestamp_parser,
686+
tz
687+
)?,
680688
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
681689
};
682690
Ok(cast_array)
@@ -1004,7 +1012,7 @@ fn get_timestamp_values<T: TimeZone>(
10041012
.with_second(second)
10051013
.with_microsecond(microsecond),
10061014
_ => {
1007-
return Err(SparkError::CastInvalidValue {
1015+
return Err(SparkError::InvalidInputInCastToDatetime {
10081016
value: value.to_string(),
10091017
from_type: "STRING".to_string(),
10101018
to_type: "TIMESTAMP".to_string(),
@@ -1095,31 +1103,31 @@ fn timestamp_parser<T: TimeZone>(
10951103
// Define regex patterns and corresponding parsing functions
10961104
let patterns = &[
10971105
(
1098-
Regex::new(r"^\d{4,5}$").unwrap(),
1106+
Regex::new(r"^\d{4,7}$").unwrap(),
10991107
parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult<Option<i64>>,
11001108
),
11011109
(
1102-
Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
1110+
Regex::new(r"^\d{4,7}-\d{2}$").unwrap(),
11031111
parse_str_to_month_timestamp,
11041112
),
11051113
(
1106-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
1114+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap(),
11071115
parse_str_to_day_timestamp,
11081116
),
11091117
(
1110-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
1118+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
11111119
parse_str_to_hour_timestamp,
11121120
),
11131121
(
1114-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
1122+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
11151123
parse_str_to_minute_timestamp,
11161124
),
11171125
(
1118-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
1126+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
11191127
parse_str_to_second_timestamp,
11201128
),
11211129
(
1122-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
1130+
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
11231131
parse_str_to_microsecond_timestamp,
11241132
),
11251133
(
@@ -1140,7 +1148,7 @@ fn timestamp_parser<T: TimeZone>(
11401148

11411149
if timestamp.is_none() {
11421150
return if eval_mode == EvalMode::Ansi {
1143-
Err(SparkError::CastInvalidValue {
1151+
Err(SparkError::InvalidInputInCastToDatetime {
11441152
value: value.to_string(),
11451153
from_type: "STRING".to_string(),
11461154
to_type: "TIMESTAMP".to_string(),
@@ -1204,15 +1212,15 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
12041212
fn is_valid_digits(segment: i32, digits: usize) -> bool {
12051213
// An integer is able to represent a date within [+-]5 million years.
12061214
let max_digits_year = 7;
1207-
//year (segment 0) can be between 4 to 7 digits,
1208-
//month and day (segment 1 and 2) can be between 1 to 2 digits
1215+
// year (segment 0) can be between 4 to 7 digits,
1216+
// month and day (segment 1 and 2) can be between 1 to 2 digits
12091217
(segment == 0 && digits >= 4 && digits <= max_digits_year)
12101218
|| (segment != 0 && digits > 0 && digits <= 2)
12111219
}
12121220

12131221
fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
12141222
if eval_mode == EvalMode::Ansi {
1215-
Err(SparkError::CastInvalidValue {
1223+
Err(SparkError::InvalidInputInCastToDatetime {
12161224
value: date_str.to_string(),
12171225
from_type: "STRING".to_string(),
12181226
to_type: "DATE".to_string(),
@@ -1341,7 +1349,8 @@ mod tests {
13411349
TimestampMicrosecondType,
13421350
timestamp_parser,
13431351
tz
1344-
);
1352+
)
1353+
.unwrap();
13451354

13461355
assert_eq!(
13471356
result.data_type(),
@@ -1350,6 +1359,33 @@ mod tests {
13501359
assert_eq!(result.len(), 4);
13511360
}
13521361

1362+
#[test]
1363+
fn test_cast_string_to_timestamp_ansi_error() {
1364+
// In ANSI mode, an invalid timestamp string must produce an error rather than null.
1365+
let array: ArrayRef = Arc::new(StringArray::from(vec![
1366+
Some("2020-01-01T12:34:56.123456"),
1367+
Some("not_a_timestamp"),
1368+
]));
1369+
let tz = &timezone::Tz::from_str("UTC").unwrap();
1370+
let string_array = array
1371+
.as_any()
1372+
.downcast_ref::<GenericStringArray<i32>>()
1373+
.expect("Expected a string array");
1374+
1375+
let eval_mode = EvalMode::Ansi;
1376+
let result = cast_utf8_to_timestamp!(
1377+
&string_array,
1378+
eval_mode,
1379+
TimestampMicrosecondType,
1380+
timestamp_parser,
1381+
tz
1382+
);
1383+
assert!(
1384+
result.is_err(),
1385+
"ANSI mode should return Err for an invalid timestamp string"
1386+
);
1387+
}
1388+
13531389
#[test]
13541390
fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
13551391
// prepare input data

native/spark-expr/src/error.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ pub enum SparkError {
3232
to_type: String,
3333
},
3434

35+
/// Like CastInvalidValue but maps to SparkDateTimeException instead of SparkNumberFormatException.
36+
/// Used for string → timestamp/date cast failures where Spark throws SparkDateTimeException
37+
/// with the CAST_INVALID_INPUT error class.
38+
#[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
39+
because it is malformed. Correct the value as per the syntax, or change its target type. \
40+
Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \
41+
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
42+
InvalidInputInCastToDatetime {
43+
value: String,
44+
from_type: String,
45+
to_type: String,
46+
},
47+
3548
#[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
3649
NumericValueOutOfRange {
3750
value: String,
@@ -199,6 +212,7 @@ impl SparkError {
199212
fn error_type_name(&self) -> &'static str {
200213
match self {
201214
SparkError::CastInvalidValue { .. } => "CastInvalidValue",
215+
SparkError::InvalidInputInCastToDatetime { .. } => "InvalidInputInCastToDatetime",
202216
SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange",
203217
SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
204218
SparkError::CastOverFlow { .. } => "CastOverFlow",
@@ -248,6 +262,11 @@ impl SparkError {
248262
value,
249263
from_type,
250264
to_type,
265+
}
266+
| SparkError::InvalidInputInCastToDatetime {
267+
value,
268+
from_type,
269+
to_type,
251270
} => {
252271
serde_json::json!({
253272
"value": value,
@@ -456,9 +475,12 @@ impl SparkError {
456475
// CastOverflow gets special handling with CastOverflowException
457476
SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException",
458477

459-
// NumberFormatException (for cast invalid input errors)
478+
// NumberFormatException (for cast invalid input errors on numeric types)
460479
SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException",
461480

481+
// DateTimeException (for cast invalid input errors on datetime types)
482+
SparkError::InvalidInputInCastToDatetime { .. } => "org/apache/spark/SparkDateTimeException",
483+
462484
// ArrayIndexOutOfBoundsException
463485
SparkError::InvalidArrayIndex { .. }
464486
| SparkError::InvalidElementAtIndex { .. }
@@ -497,6 +519,7 @@ impl SparkError {
497519
match self {
498520
// Cast errors
499521
SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
522+
SparkError::InvalidInputInCastToDatetime { .. } => Some("CAST_INVALID_INPUT"),
500523
SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
501524
SparkError::NumericValueOutOfRange { .. } => {
502525
Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")

spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
100100
case None => Array.empty[QueryContext] // No context
101101
}
102102

103-
val summary: String = errorJson.summary.orNull
103+
val summary: String = errorJson.summary.getOrElse("")
104104

105105
// Delegate to version-specific shim - let conversion exceptions propagate
106106
val optEx = convertErrorType(errorJson.errorType, errorClass, params, sparkContext, summary)

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
217217
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
218218
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
219219
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
220-
case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
221-
Incompatible(Some("ANSI mode not supported"))
222220
case DataTypes.TimestampType =>
223-
// https://github.com/apache/datafusion-comet/issues/328
224221
Incompatible(Some("Not all valid formats are supported"))
225222
case _ =>
226223
unsupported(DataTypes.StringType, toType)

spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
package org.apache.spark.sql.comet.shims
2121

22-
import org.apache.spark.{QueryContext, SparkException}
22+
import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
2323
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
2424
import org.apache.spark.sql.errors.QueryExecutionErrors
2525
import org.apache.spark.sql.types._
@@ -164,6 +164,22 @@ trait ShimSparkErrorConverter {
164164
QueryExecutionErrors
165165
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
166166

167+
case "InvalidInputInCastToDatetime" =>
168+
val expression =
169+
s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
170+
val sourceType = s""""${params("fromType").toString}""""
171+
val targetType = s""""${params("toType").toString}""""
172+
Some(
173+
new SparkDateTimeException(
174+
errorClass = "CAST_INVALID_INPUT",
175+
messageParameters = Map(
176+
"expression" -> expression,
177+
"sourceType" -> sourceType,
178+
"targetType" -> targetType,
179+
"ansiConfig" -> "\"spark.sql.ansi.enabled\""),
180+
context = context,
181+
summary = summary))
182+
167183
case "CastOverFlow" =>
168184
val fromType = getDataType(params("fromType").toString)
169185
val toType = getDataType(params("toType").toString)

spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
package org.apache.spark.sql.comet.shims
2121

22-
import org.apache.spark.{QueryContext, SparkException}
22+
import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
2323
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
2424
import org.apache.spark.sql.errors.QueryExecutionErrors
2525
import org.apache.spark.sql.types._
@@ -162,6 +162,22 @@ trait ShimSparkErrorConverter {
162162
QueryExecutionErrors
163163
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
164164

165+
case "InvalidInputInCastToDatetime" =>
166+
val expression =
167+
s"'${params("value").toString.replace("\\", "\\\\").replace("'", "\\'")}'"
168+
val sourceType = s""""${params("fromType").toString}""""
169+
val targetType = s""""${params("toType").toString}""""
170+
Some(
171+
new SparkDateTimeException(
172+
errorClass = "CAST_INVALID_INPUT",
173+
messageParameters = Map(
174+
"expression" -> expression,
175+
"sourceType" -> sourceType,
176+
"targetType" -> targetType,
177+
"ansiConfig" -> "\"spark.sql.ansi.enabled\""),
178+
context = context,
179+
summary = summary))
180+
165181
case "CastOverFlow" =>
166182
val fromType = getDataType(params("fromType").toString)
167183
val toType = getDataType(params("toType").toString)

spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,13 @@ trait ShimSparkErrorConverter {
175175
QueryExecutionErrors
176176
.invalidInputInCastToNumberError(targetType, str, context.headOption.orNull))
177177

178+
case "InvalidInputInCastToDatetime" =>
179+
val str = UTF8String.fromString(params("value").toString)
180+
val targetType = getDataType(params("toType").toString)
181+
Some(
182+
QueryExecutionErrors
183+
.invalidInputInCastToDatetimeError(str, targetType, context.headOption.orNull))
184+
178185
case "CastOverFlow" =>
179186
val fromType = getDataType(params("fromType").toString)
180187
val toType = getDataType(params("toType").toString)

0 commit comments

Comments
 (0)