Skip to content

Commit cb0e669

Browse files
authored
fix case of f(dict_array, dict_array) invocation (#64)
1 parent 38caf97 commit cb0e669

File tree

3 files changed

+103
-47
lines changed

3 files changed

+103
-47
lines changed

src/common.rs

+52-8
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@ use std::sync::Arc;
33

44
use datafusion::arrow::array::{
55
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
6-
StringArray, StringViewArray, UInt64Array, UnionArray,
6+
StringArray, StringViewArray, UInt64Array,
77
};
88
use datafusion::arrow::compute::take;
99
use datafusion::arrow::datatypes::{
10-
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
10+
ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
1111
};
1212
use datafusion::arrow::downcast_dictionary_array;
1313
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
1414
use datafusion::logical_expr::ColumnarValue;
1515
use jiter::{Jiter, JiterError, Peek};
1616

17-
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};
17+
use crate::common_union::{
18+
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
19+
};
1820

1921
/// General implementation of `ScalarUDFImpl::return_type`.
2022
///
@@ -95,6 +97,7 @@ impl From<i64> for JsonPath<'_> {
9597
}
9698
}
9799

100+
#[derive(Debug)]
98101
enum JsonPathArgs<'a> {
99102
Array(&'a ArrayRef),
100103
Scalars(Vec<JsonPath<'a>>),
@@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
175178
) -> DataFusionResult<ArrayRef> {
176179
downcast_dictionary_array!(
177180
json_array => {
178-
let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
179-
post_process_dict(json_array, values, return_dict)
180-
}
181+
fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
182+
assert_eq!(original.keys().len(), new_values.len());
183+
let mut key = K::Native::ZERO;
184+
let key_range = std::iter::from_fn(move || {
185+
let next = key;
186+
key = key.add_checked(K::Native::ONE).expect("keys exhausted");
187+
Some(next)
188+
}).take(new_values.len());
189+
let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
190+
if is_json_union(new_values.data_type()) {
191+
// JSON union: post-process the array to set keys to null where the union member is null
192+
let type_ids = new_values.as_union().type_ids();
193+
keys = mask_dictionary_keys(&keys, type_ids);
194+
}
195+
DictionaryArray::<K>::new(keys, new_values)
196+
}
197+
198+
// TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
199+
// combinations and do less work, but this can be left as a future optimization
200+
let output = match json_array.values().data_type() {
201+
DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
202+
DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
203+
DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
204+
other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) {
205+
// Horrible case: dict containing union as input with array for paths, figure
206+
// out from the path type which union members we should access, repack the
207+
// dictionary and then recurse.
208+
//
209+
// Use direct return because if return_dict applies, the recursion will handle it.
210+
return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict)
211+
} else {
212+
exec_err!("unexpected json array type {:?}", other)
213+
}
214+
}?;
215+
216+
if return_dict {
217+
// ensure return is a dictionary to satisfy the declaration above in return_type_check
218+
Ok(Arc::new(wrap_as_dictionary(json_array, output)))
219+
} else {
220+
Ok(output)
221+
}
222+
},
181223
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
182224
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
183225
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
@@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
239281
to_array,
240282
jiter_find,
241283
)
284+
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
242285
.map(ColumnarValue::Array)
243286
}
244287

@@ -250,6 +293,7 @@ fn invoke_scalar_scalars<I>(
250293
) -> DataFusionResult<ColumnarValue> {
251294
let s = extract_json_scalar(scalar)?;
252295
let v = jiter_find(s, path).ok();
296+
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
253297
Ok(ColumnarValue::Scalar(to_scalar(v)))
254298
}
255299

@@ -321,7 +365,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
321365
if return_dict {
322366
if is_json_union(result_values.data_type()) {
323367
// JSON union: post-process the array to set keys to null where the union member is null
324-
let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
368+
let type_ids = result_values.as_union().type_ids();
325369
Ok(Arc::new(DictionaryArray::new(
326370
mask_dictionary_keys(dict_array.keys(), type_ids),
327371
result_values,
@@ -413,7 +457,7 @@ impl From<Utf8Error> for GetError {
413457
///
414458
/// That said, doing this might also be an optimization for cases like null-checking without needing
415459
/// to check the value union array.
416-
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
460+
fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
417461
let mut null_mask = vec![true; keys.len()];
418462
for (i, k) in keys.iter().enumerate() {
419463
match k {

src/common_union.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ pub fn is_json_union(data_type: &DataType) -> bool {
2222
/// * `object_lookup` - If `true`, extract from the "object" member of the union,
2323
/// otherwise extract from the "array" member
2424
pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
25+
nested_json_array_ref(array, object_lookup).map(AsArray::as_string)
26+
}
27+
28+
pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
2529
let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
2630
let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
27-
union_array.child(type_id).as_any().downcast_ref()
31+
Some(union_array.child(type_id))
2832
}
2933

3034
/// Extract a JSON string from a `JsonUnion` scalar

0 commit comments

Comments
 (0)