Skip to content

Commit 47186ca

Browse files
committed
fix: array to array cast
1 parent 28bd4bc commit 47186ca

3 files changed

Lines changed: 226 additions & 29 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ use arrow::array::builder::StringBuilder;
4343
use arrow::array::{
4444
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
4545
};
46-
use arrow::compute::can_cast_types;
4746
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4847
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4948
use arrow::error::ArrowError;
@@ -294,6 +293,9 @@ pub(crate) fn cast_array(
294293
};
295294

296295
let cast_result = match (&from_type, to_type) {
296+
// Null arrays carry no concrete values, so Arrow's native cast can change only the
297+
// logical type while preserving length and nullness.
298+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
297299
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
298300
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
299301
(Utf8, Timestamp(_, _)) => {
@@ -366,8 +368,19 @@ pub(crate) fn cast_array(
366368
cast_options,
367369
)?),
368370
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
369-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
370-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
371+
(List(_), List(to)) => {
372+
let list_array = array.as_list::<i32>();
373+
let casted_values = cast_array(
374+
Arc::clone(list_array.values()),
375+
to.data_type(),
376+
cast_options,
377+
)?;
378+
Ok(Arc::new(ListArray::new(
379+
Arc::clone(to),
380+
list_array.offsets().clone(),
381+
casted_values,
382+
list_array.nulls().cloned(),
383+
)) as ArrayRef)
371384
}
372385
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
373386
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -803,7 +816,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
803816
#[cfg(test)]
804817
mod tests {
805818
use super::*;
806-
use arrow::array::StringArray;
819+
use arrow::array::{ListArray, NullArray, StringArray};
820+
use arrow::buffer::OffsetBuffer;
807821
use arrow::datatypes::TimestampMicrosecondType;
808822
use arrow::datatypes::{Field, Fields};
809823
#[test]
@@ -929,8 +943,6 @@ mod tests {
929943

930944
#[test]
931945
fn test_cast_string_array_to_string() {
932-
use arrow::array::ListArray;
933-
use arrow::buffer::OffsetBuffer;
934946
let values_array =
935947
StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
936948
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
@@ -955,8 +967,6 @@ mod tests {
955967

956968
#[test]
957969
fn test_cast_i32_array_to_string() {
958-
use arrow::array::ListArray;
959-
use arrow::buffer::OffsetBuffer;
960970
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
961971
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
962972
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -977,4 +987,33 @@ mod tests {
977987
assert_eq!(r#"[null]"#, string_array.value(2));
978988
assert_eq!(r#"[]"#, string_array.value(3));
979989
}
990+
991+
#[test]
// Casting a `List<Null>` array to `List<Int32>` must keep the list layout intact and
// turn every child slot into a typed Int32 null.
fn test_cast_array_of_nulls_to_array() {
    // Three lists laid over a three-element Null child: [null, null], [null], [].
    let source: ArrayRef = Arc::new(ListArray::new(
        Arc::new(Field::new("item", DataType::Null, true)),
        OffsetBuffer::<i32>::new(vec![0, 2, 3, 3].into()),
        Arc::new(NullArray::new(3)),
        None,
    ));

    let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    let casted = cast_array(source, &target_type, &options).unwrap();

    // The outer list structure (row count and offsets) is carried over unchanged.
    let lists = casted.as_list::<i32>();
    assert_eq!(3, lists.len());
    assert_eq!(lists.value_offsets(), &[0, 2, 3, 3]);

    // The child array is now Int32 and still entirely null.
    let child = lists.values().as_primitive::<Int32Type>();
    assert_eq!(3, child.len());
    assert_eq!(3, child.null_count());
    assert!(child.iter().all(|slot| slot.is_none()));
}
9801019
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
141141

142142
(fromType, toType) match {
143143
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
144+
case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
145+
if toElementType != DataTypes.StringType &&
146+
toElementType != DataTypes.TimestampType &&
147+
toElementType != DataTypes.DateType =>
148+
unsupported(fromType, toType)
144149
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
145150
Incompatible()
146151
case (dt: ArrayType, DataTypes.StringType) =>

0 commit comments

Comments
 (0)