Skip to content

Commit b1ca610

Browse files
Don't return dictionary encoded booleans (#42)
Co-authored-by: Adrian Garcia Badaracco <[email protected]>
1 parent 5738ff3 commit b1ca610

13 files changed

+282
-79
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
pass_filenames: false
2121
- id: clippy
2222
name: Clippy
23-
entry: cargo clippy
23+
entry: cargo clippy -- -D warnings
2424
types: [rust]
2525
language: system
2626
pass_filenames: false

src/common.rs

+50-34
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use datafusion::arrow::array::{
55
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
66
StringViewArray, UInt64Array, UnionArray,
77
};
8-
use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType};
8+
use datafusion::arrow::compute::take;
9+
use datafusion::arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType};
910
use datafusion::arrow::downcast_dictionary_array;
1011
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
1112
use datafusion::logical_expr::ColumnarValue;
@@ -21,7 +22,7 @@ use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_arr
2122
/// * `fn_name` - The name of the function
2223
/// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending
2324
/// on the first argument
24-
pub fn scalar_udf_return_type(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
25+
pub fn return_type_check(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
2526
let Some(first) = args.first() else {
2627
return plan_err!("The '{fn_name}' function requires one or more arguments.");
2728
};
@@ -105,6 +106,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
105106
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
106107
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
107108
to_scalar: impl Fn(Option<I>) -> ScalarValue,
109+
return_dict: bool,
108110
) -> DataFusionResult<ColumnarValue> {
109111
let Some(first_arg) = args.first() else {
110112
// I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe
@@ -118,13 +120,17 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
118120
// TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
119121
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
120122
} else {
121-
invoke_array(json_array, a, to_array, jiter_find)
123+
invoke_array(json_array, a, to_array, jiter_find, return_dict)
122124
}
123125
}
124-
Some(ColumnarValue::Scalar(_)) => {
125-
scalar_apply(json_array, &JsonPath::extract_path(args), to_array, jiter_find)
126-
}
127-
None => scalar_apply(json_array, &[], to_array, jiter_find),
126+
Some(ColumnarValue::Scalar(_)) => scalar_apply(
127+
json_array,
128+
&JsonPath::extract_path(args),
129+
to_array,
130+
jiter_find,
131+
return_dict,
132+
),
133+
None => scalar_apply(json_array, &[], to_array, jiter_find, return_dict),
128134
};
129135
array.map(ColumnarValue::from)
130136
}
@@ -137,24 +143,26 @@ fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
137143
needle_array: &ArrayRef,
138144
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
139145
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
146+
return_dict: bool,
140147
) -> DataFusionResult<ArrayRef> {
141148
if let Some(d) = needle_array.as_any_dictionary_opt() {
142-
invoke_array(json_array, d.values(), to_array, jiter_find)
149+
// this is the (very rare) case where the needle is a dictionary, it shouldn't affect what we return
150+
invoke_array(json_array, d.values(), to_array, jiter_find, return_dict)
143151
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringArray>() {
144152
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
145-
zip_apply(json_array, paths, to_array, jiter_find, true)
153+
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
146154
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<LargeStringArray>() {
147155
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
148-
zip_apply(json_array, paths, to_array, jiter_find, true)
156+
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
149157
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringViewArray>() {
150158
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
151-
zip_apply(json_array, paths, to_array, jiter_find, true)
159+
zip_apply(json_array, paths, to_array, jiter_find, true, return_dict)
152160
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<Int64Array>() {
153161
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
154-
zip_apply(json_array, paths, to_array, jiter_find, false)
162+
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
155163
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<UInt64Array>() {
156164
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
157-
zip_apply(json_array, paths, to_array, jiter_find, false)
165+
zip_apply(json_array, paths, to_array, jiter_find, false, return_dict)
158166
} else {
159167
exec_err!("unexpected second argument type, expected string or int array")
160168
}
@@ -166,22 +174,15 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
166174
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
167175
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
168176
object_lookup: bool,
177+
return_dict: bool,
169178
) -> DataFusionResult<ArrayRef> {
170179
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
171180
use datafusion::arrow::datatypes as arrow_schema;
172181

173182
let c = downcast_dictionary_array!(
174183
json_array => {
175-
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?;
176-
if !is_json_union(values.data_type()) {
177-
return Ok(Arc::new(json_array.with_values(values)));
178-
}
179-
// JSON union: post-process the array to set keys to null where the union member is null
180-
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
181-
return Ok(Arc::new(DictionaryArray::new(
182-
mask_dictionary_keys(json_array.keys(), type_ids),
183-
values,
184-
)));
184+
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup, false)?;
185+
return post_process_dict(json_array, values, return_dict);
185186
}
186187
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
187188
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
@@ -242,22 +243,15 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
242243
path: &[JsonPath],
243244
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
244245
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
246+
return_dict: bool,
245247
) -> DataFusionResult<ArrayRef> {
246248
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
247249
use datafusion::arrow::datatypes as arrow_schema;
248250

249251
let c = downcast_dictionary_array!(
250252
json_array => {
251-
let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?;
252-
if !is_json_union(values.data_type()) {
253-
return Ok(Arc::new(json_array.with_values(values)));
254-
}
255-
// JSON union: post-process the array to set keys to null where the union member is null
256-
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
257-
return Ok(Arc::new(DictionaryArray::new(
258-
mask_dictionary_keys(json_array.keys(), type_ids),
259-
values,
260-
)));
253+
let values = scalar_apply(json_array.values(), path, to_array, jiter_find, false)?;
254+
return post_process_dict(json_array, values, return_dict);
261255
}
262256
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
263257
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
@@ -268,10 +262,32 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
268262
return exec_err!("unexpected json array type {:?}", other);
269263
}
270264
);
271-
272265
to_array(c)
273266
}
274267

268+
/// Take a dictionary array of JSON data and an array of result values and combine them.
269+
fn post_process_dict<T: ArrowDictionaryKeyType>(
270+
dict_array: &DictionaryArray<T>,
271+
result_values: ArrayRef,
272+
return_dict: bool,
273+
) -> DataFusionResult<ArrayRef> {
274+
if return_dict {
275+
if is_json_union(result_values.data_type()) {
276+
// JSON union: post-process the array to set keys to null where the union member is null
277+
let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
278+
Ok(Arc::new(DictionaryArray::new(
279+
mask_dictionary_keys(dict_array.keys(), type_ids),
280+
result_values,
281+
)))
282+
} else {
283+
Ok(Arc::new(dict_array.with_values(result_values)))
284+
}
285+
} else {
286+
// this is what cast would do under the hood to unpack a dictionary into an array of its values
287+
Ok(take(&result_values, dict_array.keys(), None)?)
288+
}
289+
}
290+
275291
fn is_object_lookup(path: &[JsonPath]) -> bool {
276292
if let Some(first) = path.first() {
277293
matches!(first, JsonPath::Key(_))

src/json_as_text.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88
use jiter::Peek;
99

10-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1111
use crate::common_macros::make_udf_function;
1212

1313
make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText {
4646
}
4747

4848
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49-
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
49+
return_type_check(arg_types, self.name(), DataType::Utf8)
5050
}
5151

5252
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonAsText {
5555
jiter_json_as_text,
5656
|c| Ok(Arc::new(c) as ArrayRef),
5757
ScalarValue::Utf8,
58+
true,
5859
)
5960
}
6061

src/json_contains.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use datafusion::common::arrow::array::{ArrayRef, BooleanArray};
66
use datafusion::common::{plan_err, Result, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88

9-
use crate::common::{invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
9+
use crate::common::{invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1010
use crate::common_macros::make_udf_function;
1111

1212
make_udf_function!(
@@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains {
4848
if arg_types.len() < 2 {
4949
plan_err!("The 'json_contains' function requires two or more arguments.")
5050
} else {
51-
scalar_udf_return_type(arg_types, self.name(), DataType::Boolean)
51+
return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean)
5252
}
5353
}
5454

@@ -58,6 +58,7 @@ impl ScalarUDFImpl for JsonContains {
5858
jiter_json_contains,
5959
|c| Ok(Arc::new(c) as ArrayRef),
6060
ScalarValue::Boolean,
61+
false,
6162
)
6263
}
6364

src/json_get.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use datafusion::common::Result as DataFusionResult;
88
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
99
use jiter::{Jiter, NumberAny, NumberInt, Peek};
1010

11-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
11+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1212
use crate::common_macros::make_udf_function;
1313
use crate::common_union::{JsonUnion, JsonUnionField};
1414

@@ -50,15 +50,15 @@ impl ScalarUDFImpl for JsonGet {
5050
}
5151

5252
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
53-
scalar_udf_return_type(arg_types, self.name(), JsonUnion::data_type())
53+
return_type_check(arg_types, self.name(), JsonUnion::data_type())
5454
}
5555

5656
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
5757
let to_array = |c: JsonUnion| {
5858
let array: UnionArray = c.try_into()?;
5959
Ok(Arc::new(array) as ArrayRef)
6060
};
61-
invoke::<JsonUnion, JsonUnionField>(args, jiter_json_get_union, to_array, JsonUnionField::scalar_value)
61+
invoke::<JsonUnion, JsonUnionField>(args, jiter_json_get_union, to_array, JsonUnionField::scalar_value, true)
6262
}
6363

6464
fn aliases(&self) -> &[String] {

src/json_get_bool.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88
use jiter::Peek;
99

10-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1111
use crate::common_macros::make_udf_function;
1212

1313
make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool {
4646
}
4747

4848
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49-
scalar_udf_return_type(arg_types, self.name(), DataType::Boolean)
49+
return_type_check(arg_types, self.name(), DataType::Boolean).map(|_| DataType::Boolean)
5050
}
5151

5252
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetBool {
5555
jiter_json_get_bool,
5656
|c| Ok(Arc::new(c) as ArrayRef),
5757
ScalarValue::Boolean,
58+
false,
5859
)
5960
}
6061

src/json_get_float.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88
use jiter::{NumberAny, Peek};
99

10-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1111
use crate::common_macros::make_udf_function;
1212

1313
make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat {
4646
}
4747

4848
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49-
scalar_udf_return_type(arg_types, self.name(), DataType::Float64)
49+
return_type_check(arg_types, self.name(), DataType::Float64)
5050
}
5151

5252
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetFloat {
5555
jiter_json_get_float,
5656
|c| Ok(Arc::new(c) as ArrayRef),
5757
ScalarValue::Float64,
58+
true,
5859
)
5960
}
6061

src/json_get_int.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88
use jiter::{NumberInt, Peek};
99

10-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1111
use crate::common_macros::make_udf_function;
1212

1313
make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt {
4646
}
4747

4848
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49-
scalar_udf_return_type(arg_types, self.name(), DataType::Int64)
49+
return_type_check(arg_types, self.name(), DataType::Int64)
5050
}
5151

5252
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetInt {
5555
jiter_json_get_int,
5656
|c| Ok(Arc::new(c) as ArrayRef),
5757
ScalarValue::Int64,
58+
true,
5859
)
5960
}
6061

src/json_get_json.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use datafusion::arrow::datatypes::DataType;
66
use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88

9-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
9+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1010
use crate::common_macros::make_udf_function;
1111

1212
make_udf_function!(
@@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson {
4545
}
4646

4747
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
48-
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
48+
return_type_check(arg_types, self.name(), DataType::Utf8)
4949
}
5050

5151
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -54,6 +54,7 @@ impl ScalarUDFImpl for JsonGetJson {
5454
jiter_json_get_json,
5555
|c| Ok(Arc::new(c) as ArrayRef),
5656
ScalarValue::Utf8,
57+
true,
5758
)
5859
}
5960

src/json_get_str.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
77
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
88
use jiter::Peek;
99

10-
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
10+
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
1111
use crate::common_macros::make_udf_function;
1212

1313
make_udf_function!(
@@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr {
4646
}
4747

4848
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
49-
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
49+
return_type_check(arg_types, self.name(), DataType::Utf8)
5050
}
5151

5252
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
@@ -55,6 +55,7 @@ impl ScalarUDFImpl for JsonGetStr {
5555
jiter_json_get_str,
5656
|c| Ok(Arc::new(c) as ArrayRef),
5757
ScalarValue::Utf8,
58+
true,
5859
)
5960
}
6061

0 commit comments

Comments
 (0)