Skip to content

Commit 7eb85f3

Browse files
authored
Merge pull request #24 from DataDog/monjalet/resource-exhausted-fix
Cherry pick fixes for min/max on arrays and resource exhaustion
2 parents 5a3ffb6 + 04abbaf commit 7eb85f3

8 files changed

Lines changed: 1214 additions & 58 deletions

File tree

datafusion/common/src/scalar/mod.rs

Lines changed: 197 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ impl PartialOrd for ScalarValue {
506506
}
507507
(List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None,
508508
(Struct(struct_arr1), Struct(struct_arr2)) => {
509-
partial_cmp_struct(struct_arr1, struct_arr2)
509+
partial_cmp_struct(struct_arr1.as_ref(), struct_arr2.as_ref())
510510
}
511511
(Struct(_), _) => None,
512512
(Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2),
@@ -597,10 +597,28 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
597597
let arr1 = first_array_for_list(arr1);
598598
let arr2 = first_array_for_list(arr2);
599599

600-
let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
601-
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;
600+
let min_length = arr1.len().min(arr2.len());
601+
let arr1_trimmed = arr1.slice(0, min_length);
602+
let arr2_trimmed = arr2.slice(0, min_length);
603+
604+
let lt_res = arrow::compute::kernels::cmp::lt(&arr1_trimmed, &arr2_trimmed).ok()?;
605+
let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?;
602606

603607
for j in 0..lt_res.len() {
608+
// In Postgres, NULL values in lists are always considered to be greater than non-NULL values:
609+
//
610+
// $ SELECT ARRAY[NULL]::integer[] > ARRAY[1]
611+
// true
612+
//
613+
// These next two if statements are introduced for replicating Postgres behavior, as
614+
// arrow::compute does not account for this.
615+
if arr1_trimmed.is_null(j) && !arr2_trimmed.is_null(j) {
616+
return Some(Ordering::Greater);
617+
}
618+
if !arr1_trimmed.is_null(j) && arr2_trimmed.is_null(j) {
619+
return Some(Ordering::Less);
620+
}
621+
604622
if lt_res.is_valid(j) && lt_res.value(j) {
605623
return Some(Ordering::Less);
606624
}
@@ -609,10 +627,23 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
609627
}
610628
}
611629

612-
Some(Ordering::Equal)
630+
Some(arr1.len().cmp(&arr2.len()))
631+
}
632+
633+
fn flatten<'a>(array: &'a StructArray, columns: &mut Vec<&'a ArrayRef>) {
634+
for i in 0..array.num_columns() {
635+
let column = array.column(i);
636+
if let Some(nested_struct) = column.as_any().downcast_ref::<StructArray>() {
637+
// If it's a nested struct, recursively expand
638+
flatten(nested_struct, columns);
639+
} else {
640+
// If it's a primitive type, add directly
641+
columns.push(column);
642+
}
643+
}
613644
}
614645

615-
fn partial_cmp_struct(s1: &Arc<StructArray>, s2: &Arc<StructArray>) -> Option<Ordering> {
646+
pub fn partial_cmp_struct(s1: &StructArray, s2: &StructArray) -> Option<Ordering> {
616647
if s1.len() != s2.len() {
617648
return None;
618649
}
@@ -621,9 +652,15 @@ fn partial_cmp_struct(s1: &Arc<StructArray>, s2: &Arc<StructArray>) -> Option<Or
621652
return None;
622653
}
623654

624-
for col_index in 0..s1.num_columns() {
625-
let arr1 = s1.column(col_index);
626-
let arr2 = s2.column(col_index);
655+
let mut expanded_columns1 = Vec::with_capacity(s1.num_columns());
656+
let mut expanded_columns2 = Vec::with_capacity(s2.num_columns());
657+
658+
flatten(s1, &mut expanded_columns1);
659+
flatten(s2, &mut expanded_columns2);
660+
661+
for col_index in 0..expanded_columns1.len() {
662+
let arr1 = expanded_columns1[col_index];
663+
let arr2 = expanded_columns2[col_index];
627664

628665
let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?;
629666
let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?;
@@ -3420,6 +3457,89 @@ impl ScalarValue {
34203457
.map(|sv| sv.size() - size_of_val(sv))
34213458
.sum::<usize>()
34223459
}
3460+
3461+
/// Compacts the allocation referenced by `self` to the minimum, copying the data if
3462+
/// necessary.
3463+
///
3464+
/// This can be relevant when `self` is a list or contains a list as a nested value, as
3465+
/// a single list holds an Arc to its entire original array buffer.
3466+
pub fn compact(&mut self) {
3467+
match self {
3468+
ScalarValue::Null
3469+
| ScalarValue::Boolean(_)
3470+
| ScalarValue::Float16(_)
3471+
| ScalarValue::Float32(_)
3472+
| ScalarValue::Float64(_)
3473+
| ScalarValue::Decimal128(_, _, _)
3474+
| ScalarValue::Decimal256(_, _, _)
3475+
| ScalarValue::Int8(_)
3476+
| ScalarValue::Int16(_)
3477+
| ScalarValue::Int32(_)
3478+
| ScalarValue::Int64(_)
3479+
| ScalarValue::UInt8(_)
3480+
| ScalarValue::UInt16(_)
3481+
| ScalarValue::UInt32(_)
3482+
| ScalarValue::UInt64(_)
3483+
| ScalarValue::Date32(_)
3484+
| ScalarValue::Date64(_)
3485+
| ScalarValue::Time32Second(_)
3486+
| ScalarValue::Time32Millisecond(_)
3487+
| ScalarValue::Time64Microsecond(_)
3488+
| ScalarValue::Time64Nanosecond(_)
3489+
| ScalarValue::IntervalYearMonth(_)
3490+
| ScalarValue::IntervalDayTime(_)
3491+
| ScalarValue::IntervalMonthDayNano(_)
3492+
| ScalarValue::DurationSecond(_)
3493+
| ScalarValue::DurationMillisecond(_)
3494+
| ScalarValue::DurationMicrosecond(_)
3495+
| ScalarValue::DurationNanosecond(_)
3496+
| ScalarValue::Utf8(_)
3497+
| ScalarValue::LargeUtf8(_)
3498+
| ScalarValue::Utf8View(_)
3499+
| ScalarValue::TimestampSecond(_, _)
3500+
| ScalarValue::TimestampMillisecond(_, _)
3501+
| ScalarValue::TimestampMicrosecond(_, _)
3502+
| ScalarValue::TimestampNanosecond(_, _)
3503+
| ScalarValue::Binary(_)
3504+
| ScalarValue::FixedSizeBinary(_, _)
3505+
| ScalarValue::LargeBinary(_)
3506+
| ScalarValue::BinaryView(_) => (),
3507+
ScalarValue::FixedSizeList(arr) => {
3508+
let array = copy_array_data(&arr.to_data());
3509+
*Arc::make_mut(arr) = FixedSizeListArray::from(array);
3510+
}
3511+
ScalarValue::List(arr) => {
3512+
let array = copy_array_data(&arr.to_data());
3513+
*Arc::make_mut(arr) = ListArray::from(array);
3514+
}
3515+
ScalarValue::LargeList(arr) => {
3516+
let array = copy_array_data(&arr.to_data());
3517+
*Arc::make_mut(arr) = LargeListArray::from(array)
3518+
}
3519+
ScalarValue::Struct(arr) => {
3520+
let array = copy_array_data(&arr.to_data());
3521+
*Arc::make_mut(arr) = StructArray::from(array);
3522+
}
3523+
ScalarValue::Map(arr) => {
3524+
let array = copy_array_data(&arr.to_data());
3525+
*Arc::make_mut(arr) = MapArray::from(array);
3526+
}
3527+
ScalarValue::Union(val, _, _) => {
3528+
if let Some((_, value)) = val.as_mut() {
3529+
value.compact();
3530+
}
3531+
}
3532+
ScalarValue::Dictionary(_, value) => {
3533+
value.compact();
3534+
}
3535+
}
3536+
}
3537+
}
3538+
3539+
pub fn copy_array_data(data: &ArrayData) -> ArrayData {
3540+
let mut copy = MutableArrayData::new(vec![&data], true, data.len());
3541+
copy.extend(0, 0, data.len());
3542+
copy.freeze()
34233543
}
34243544

34253545
macro_rules! impl_scalar {
@@ -4761,6 +4881,75 @@ mod tests {
47614881
])]),
47624882
));
47634883
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
4884+
4885+
let a =
4886+
ScalarValue::List(Arc::new(
4887+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4888+
Some(1),
4889+
Some(2),
4890+
Some(3),
4891+
])]),
4892+
));
4893+
let b =
4894+
ScalarValue::List(Arc::new(
4895+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4896+
Some(2),
4897+
Some(3),
4898+
])]),
4899+
));
4900+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
4901+
4902+
let a =
4903+
ScalarValue::List(Arc::new(
4904+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4905+
Some(2),
4906+
Some(3),
4907+
Some(4),
4908+
])]),
4909+
));
4910+
let b =
4911+
ScalarValue::List(Arc::new(
4912+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4913+
Some(1),
4914+
Some(2),
4915+
])]),
4916+
));
4917+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
4918+
4919+
let a =
4920+
ScalarValue::List(Arc::new(
4921+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4922+
Some(1),
4923+
Some(2),
4924+
Some(3),
4925+
])]),
4926+
));
4927+
let b =
4928+
ScalarValue::List(Arc::new(
4929+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4930+
Some(1),
4931+
Some(2),
4932+
])]),
4933+
));
4934+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
4935+
4936+
let a =
4937+
ScalarValue::List(Arc::new(
4938+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4939+
None,
4940+
Some(2),
4941+
Some(3),
4942+
])]),
4943+
));
4944+
let b =
4945+
ScalarValue::List(Arc::new(
4946+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4947+
Some(1),
4948+
Some(2),
4949+
Some(3),
4950+
])]),
4951+
));
4952+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
47644953
}
47654954

47664955
#[test]

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use arrow::array::{
2121
Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray,
2222
BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray,
23-
StringViewArray,
23+
StringViewArray, StructArray,
2424
};
2525
use arrow::buffer::NullBuffer;
2626
use arrow::datatypes::DataType;
@@ -193,6 +193,18 @@ pub fn set_nulls_dyn(input: &dyn Array, nulls: Option<NullBuffer>) -> Result<Arr
193193
))
194194
}
195195
}
196+
DataType::Struct(_) => {
197+
let input = input.as_struct();
198+
// safety: values / offsets came from a valid struct array
199+
// and we checked nulls has the same length as values
200+
unsafe {
201+
Arc::new(StructArray::new_unchecked(
202+
input.fields().clone(),
203+
input.columns().to_vec(),
204+
nulls,
205+
))
206+
}
207+
}
196208
_ => {
197209
return not_impl_err!("Applying nulls {:?}", input.data_type());
198210
}

0 commit comments

Comments
 (0)