Skip to content

Commit efe0491

Browse files
authored
improve union behaviour with null values (#27)
1 parent a645a62 commit efe0491

File tree

3 files changed

+132
-29
lines changed

3 files changed

+132
-29
lines changed

src/common.rs

+3-6
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,8 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
7777
Some(ColumnarValue::Array(a)) => {
7878
if args.len() > 2 {
7979
// TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
80-
return exec_err!(
81-
"More than 1 path element is not supported when querying JSON using an array."
82-
);
83-
}
84-
if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
80+
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
81+
} else if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
8582
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
8683
zip_apply(json_array, paths, jiter_find, true)
8784
} else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() {
@@ -94,7 +91,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
9491
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
9592
zip_apply(json_array, paths, jiter_find, false)
9693
} else {
97-
return exec_err!("unexpected second argument type, expected string or int array");
94+
exec_err!("unexpected second argument type, expected string or int array")
9895
}
9996
}
10097
Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),

src/common_union.rs

+16-19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::{Arc, OnceLock};
22

3-
use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, UnionArray};
3+
use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray};
44
use arrow::buffer::Buffer;
55
use arrow_schema::{DataType, Field, UnionFields, UnionMode};
66
use datafusion_common::ScalarValue;
@@ -42,7 +42,6 @@ pub(crate) fn json_from_union_scalar<'a>(
4242

4343
#[derive(Debug)]
4444
pub(crate) struct JsonUnion {
45-
nulls: Vec<Option<bool>>,
4645
bools: Vec<Option<bool>>,
4746
ints: Vec<Option<i64>>,
4847
floats: Vec<Option<f64>>,
@@ -51,22 +50,21 @@ pub(crate) struct JsonUnion {
5150
objects: Vec<Option<String>>,
5251
type_ids: Vec<i8>,
5352
index: usize,
54-
capacity: usize,
53+
length: usize,
5554
}
5655

5756
impl JsonUnion {
58-
fn new(capacity: usize) -> Self {
57+
fn new(length: usize) -> Self {
5958
Self {
60-
nulls: vec![None; capacity],
61-
bools: vec![None; capacity],
62-
ints: vec![None; capacity],
63-
floats: vec![None; capacity],
64-
strings: vec![None; capacity],
65-
arrays: vec![None; capacity],
66-
objects: vec![None; capacity],
67-
type_ids: vec![0; capacity],
59+
bools: vec![None; length],
60+
ints: vec![None; length],
61+
floats: vec![None; length],
62+
strings: vec![None; length],
63+
arrays: vec![None; length],
64+
objects: vec![None; length],
65+
type_ids: vec![0; length],
6866
index: 0,
69-
capacity,
67+
length,
7068
}
7169
}
7270

@@ -77,7 +75,7 @@ impl JsonUnion {
7775
fn push(&mut self, field: JsonUnionField) {
7876
self.type_ids[self.index] = field.type_id();
7977
match field {
80-
JsonUnionField::JsonNull => self.nulls[self.index] = Some(true),
78+
JsonUnionField::JsonNull => (),
8179
JsonUnionField::Bool(value) => self.bools[self.index] = Some(value),
8280
JsonUnionField::Int(value) => self.ints[self.index] = Some(value),
8381
JsonUnionField::Float(value) => self.floats[self.index] = Some(value),
@@ -86,13 +84,12 @@ impl JsonUnion {
8684
JsonUnionField::Object(value) => self.objects[self.index] = Some(value),
8785
}
8886
self.index += 1;
89-
debug_assert!(self.index <= self.capacity);
87+
debug_assert!(self.index <= self.length);
9088
}
9189

9290
fn push_none(&mut self) {
93-
self.type_ids[self.index] = TYPE_ID_NULL;
9491
self.index += 1;
95-
debug_assert!(self.index <= self.capacity);
92+
debug_assert!(self.index <= self.length);
9693
}
9794
}
9895

@@ -119,7 +116,7 @@ impl TryFrom<JsonUnion> for UnionArray {
119116

120117
fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
121118
let children: Vec<Arc<dyn Array>> = vec![
122-
Arc::new(BooleanArray::from(value.nulls)),
119+
Arc::new(NullArray::new(value.length)),
123120
Arc::new(BooleanArray::from(value.bools)),
124121
Arc::new(Int64Array::from(value.ints)),
125122
Arc::new(Float64Array::from(value.floats)),
@@ -155,7 +152,7 @@ fn union_fields() -> UnionFields {
155152
FIELDS
156153
.get_or_init(|| {
157154
UnionFields::from_iter([
158-
(TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Boolean, true))),
155+
(TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))),
159156
(TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))),
160157
(TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))),
161158
(TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))),

tests/main.rs

+113-4
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async fn test_json_get_union() {
6868
"| object_foo | {str=abc} |",
6969
"| object_foo_array | {array=[1]} |",
7070
"| object_foo_obj | {object={}} |",
71-
"| object_foo_null | {null=true} |",
71+
"| object_foo_null | {null=} |",
7272
"| object_bar | {null=} |",
7373
"| list_foo | {null=} |",
7474
"| invalid_json | {null=} |",
@@ -675,7 +675,7 @@ async fn test_json_get_union_array_nested() {
675675
"+-------------+",
676676
"| {array=[0]} |",
677677
"| {null=} |",
678-
"| {null=true} |",
678+
"| {null=} |",
679679
"+-------------+",
680680
];
681681

@@ -725,7 +725,7 @@ async fn test_arrow() {
725725
"| object_foo | {str=abc} |",
726726
"| object_foo_array | {array=[1]} |",
727727
"| object_foo_obj | {object={}} |",
728-
"| object_foo_null | {null=true} |",
728+
"| object_foo_null | {null=} |",
729729
"| object_bar | {null=} |",
730730
"| list_foo | {null=} |",
731731
"| invalid_json | {null=} |",
@@ -903,7 +903,7 @@ async fn test_arrow_nested_columns() {
903903
"+-------------+",
904904
"| {array=[0]} |",
905905
"| {null=} |",
906-
"| {null=true} |",
906+
"| {null=} |",
907907
"+-------------+",
908908
];
909909

@@ -990,3 +990,112 @@ async fn test_question_filter() {
990990
];
991991
assert_batches_eq!(expected, &batches);
992992
}
993+
994+
#[tokio::test]
995+
async fn test_json_get_union_is_null() {
996+
let batches = run_query("select name, json_get(json_data, 'foo') is null from test")
997+
.await
998+
.unwrap();
999+
1000+
let expected = [
1001+
"+------------------+----------------------------------------------+",
1002+
"| name | json_get(test.json_data,Utf8(\"foo\")) IS NULL |",
1003+
"+------------------+----------------------------------------------+",
1004+
"| object_foo | false |",
1005+
"| object_foo_array | false |",
1006+
"| object_foo_obj | false |",
1007+
"| object_foo_null | true |",
1008+
"| object_bar | true |",
1009+
"| list_foo | true |",
1010+
"| invalid_json | true |",
1011+
"+------------------+----------------------------------------------+",
1012+
];
1013+
assert_batches_eq!(expected, &batches);
1014+
}
1015+
1016+
#[tokio::test]
1017+
async fn test_json_get_union_is_not_null() {
1018+
let batches = run_query("select name, json_get(json_data, 'foo') is not null from test")
1019+
.await
1020+
.unwrap();
1021+
1022+
let expected = [
1023+
"+------------------+--------------------------------------------------+",
1024+
"| name | json_get(test.json_data,Utf8(\"foo\")) IS NOT NULL |",
1025+
"+------------------+--------------------------------------------------+",
1026+
"| object_foo | true |",
1027+
"| object_foo_array | true |",
1028+
"| object_foo_obj | true |",
1029+
"| object_foo_null | false |",
1030+
"| object_bar | false |",
1031+
"| list_foo | false |",
1032+
"| invalid_json | false |",
1033+
"+------------------+--------------------------------------------------+",
1034+
];
1035+
assert_batches_eq!(expected, &batches);
1036+
}
1037+
1038+
#[tokio::test]
1039+
async fn test_arrow_union_is_null() {
1040+
let batches = run_query("select name, (json_data->'foo') is null from test")
1041+
.await
1042+
.unwrap();
1043+
1044+
let expected = [
1045+
"+------------------+----------------------------------+",
1046+
"| name | json_data -> Utf8(\"foo\") IS NULL |",
1047+
"+------------------+----------------------------------+",
1048+
"| object_foo | false |",
1049+
"| object_foo_array | false |",
1050+
"| object_foo_obj | false |",
1051+
"| object_foo_null | true |",
1052+
"| object_bar | true |",
1053+
"| list_foo | true |",
1054+
"| invalid_json | true |",
1055+
"+------------------+----------------------------------+",
1056+
];
1057+
assert_batches_eq!(expected, &batches);
1058+
}
1059+
1060+
#[tokio::test]
1061+
async fn test_arrow_union_is_not_null() {
1062+
let batches = run_query("select name, (json_data->'foo') is not null from test")
1063+
.await
1064+
.unwrap();
1065+
1066+
let expected = [
1067+
"+------------------+--------------------------------------+",
1068+
"| name | json_data -> Utf8(\"foo\") IS NOT NULL |",
1069+
"+------------------+--------------------------------------+",
1070+
"| object_foo | true |",
1071+
"| object_foo_array | true |",
1072+
"| object_foo_obj | true |",
1073+
"| object_foo_null | false |",
1074+
"| object_bar | false |",
1075+
"| list_foo | false |",
1076+
"| invalid_json | false |",
1077+
"+------------------+--------------------------------------+",
1078+
];
1079+
assert_batches_eq!(expected, &batches);
1080+
}
1081+
1082+
#[tokio::test]
1083+
async fn test_arrow_scalar_union_is_null() {
1084+
let batches = run_query(
1085+
r#"
1086+
select ('{"x": 1}'->'foo') is null as not_contains,
1087+
('{"foo": 1}'->'foo') is null as contains_num,
1088+
('{"foo": null}'->'foo') is null as contains_null"#,
1089+
)
1090+
.await
1091+
.unwrap();
1092+
1093+
let expected = [
1094+
"+--------------+--------------+---------------+",
1095+
"| not_contains | contains_num | contains_null |",
1096+
"+--------------+--------------+---------------+",
1097+
"| true | false | true |",
1098+
"+--------------+--------------+---------------+",
1099+
];
1100+
assert_batches_eq!(expected, &batches);
1101+
}

0 commit comments

Comments
 (0)