diff --git a/src/common_union.rs b/src/common_union.rs index 93b3ead..74820ff 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -1,7 +1,8 @@ use std::sync::{Arc, OnceLock}; use datafusion::arrow::array::{ - Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray, + Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, ListArray, ListBuilder, NullArray, StringArray, + StringBuilder, UnionArray, }; use datafusion::arrow::buffer::{Buffer, ScalarBuffer}; use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; @@ -49,7 +50,7 @@ pub(crate) struct JsonUnion { ints: Vec>, floats: Vec>, strings: Vec>, - arrays: Vec>, + arrays: Vec>>, objects: Vec>, type_ids: Vec, index: usize, @@ -96,24 +97,6 @@ impl JsonUnion { } } -/// So we can do `collect::()` -impl FromIterator> for JsonUnion { - fn from_iter>>(iter: I) -> Self { - let inner = iter.into_iter(); - let (lower, upper) = inner.size_hint(); - let mut union = Self::new(upper.unwrap_or(lower)); - - for opt_field in inner { - if let Some(union_field) = opt_field { - union.push(union_field); - } else { - union.push_none(); - } - } - union - } -} - impl TryFrom for UnionArray { type Error = ArrowError; @@ -124,13 +107,42 @@ impl TryFrom for UnionArray { Arc::new(Int64Array::from(value.ints)), Arc::new(Float64Array::from(value.floats)), Arc::new(StringArray::from(value.strings)), - Arc::new(StringArray::from(value.arrays)), + Arc::new(StringArray::from( + value + .arrays + .into_iter() + .map(|r| r.map(|e| e.join(","))) + .collect::>(), + )), Arc::new(StringArray::from(value.objects)), ]; UnionArray::try_new(union_fields(), Buffer::from_vec(value.type_ids).into(), None, children) } } +impl TryFrom for ListArray { + type Error = ArrowError; + + fn try_from(value: JsonUnion) -> Result { + let string_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(string_builder); + + for row in value.arrays { + if let Some(row) = row { + for elem in row { + list_builder.values().append_value(elem); + } + + list_builder.append(true); + } else { + list_builder.append(false); + } + } + + Ok(list_builder.finish()) + } +} + #[derive(Debug)] pub(crate) enum JsonUnionField { JsonNull, @@ -138,7 +150,7 @@ pub(crate) enum JsonUnionField { Int(i64), Float(f64), Str(String), - Array(String), + Array(Vec), Object(String), } @@ -196,8 +208,63 @@ impl From for ScalarValue { JsonUnionField::Bool(b) => Self::Boolean(Some(b)), JsonUnionField::Int(i) => Self::Int64(Some(i)), JsonUnionField::Float(f) => Self::Float64(Some(f)), - JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), + JsonUnionField::Array(a) => Self::Utf8(Some(a.join(","))), + JsonUnionField::Str(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), + } + } +} + +/// So we can do `collect::()` +impl FromIterator> for JsonUnion { + fn from_iter>>(iter: I) -> Self { + let inner = iter.into_iter(); + let (lower, upper) = inner.size_hint(); + let mut union = Self::new(upper.unwrap_or(lower)); + + for opt_field in inner { + if let Some(union_field) = opt_field { + union.push(union_field); + } else { + union.push_none(); + } } + union + } +} + +#[derive(Debug)] +pub(crate) struct JsonArrayField(pub(crate) Vec); + +impl From for ScalarValue { + fn from(JsonArrayField(elems): JsonArrayField) -> Self { + Self::List(Self::new_list_nullable( + &elems.into_iter().map(|e| Self::Utf8(Some(e))).collect::>(), + &DataType::Utf8, + )) + } +} + +impl From for JsonUnionField { + fn from(JsonArrayField(elems): JsonArrayField) -> Self { + JsonUnionField::Array(elems) + } +} + +impl FromIterator> for JsonUnion { + fn from_iter>>(iter: T) -> Self { + let inner = iter.into_iter(); + let (lower, upper) = inner.size_hint(); + let mut union = Self::new(upper.unwrap_or(lower)); + + for opt_field in inner { + if let Some(array_field) = opt_field { + union.push(array_field.into()); + } else { + union.push_none(); + } + } + + union } } @@ -281,7 +348,7 @@ mod test { Some(JsonUnionField::Int(42)), Some(JsonUnionField::Float(42.0)), Some(JsonUnionField::Str("foo".to_string())), - Some(JsonUnionField::Array("[42]".to_string())), + Some(JsonUnionField::Array(vec!["[42]".to_string()])), Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())), None, ]); diff --git a/src/json_get.rs b/src/json_get.rs index 8db3d8c..bb17446 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -93,7 +93,7 @@ fn build_union(jiter: &mut Jiter, peek: Peek) -> Result { let start = jiter.current_index(); diff --git a/src/json_get_array.rs b/src/json_get_array.rs new file mode 100644 index 0000000..1f9e1b7 --- /dev/null +++ b/src/json_get_array.rs @@ -0,0 +1,103 @@ +use std::any::Any; +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, ListArray}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::error::Result as DatafusionResult; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion::scalar::ScalarValue; +use jiter::Peek; + +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common_macros::make_udf_function; +use crate::common_union::{JsonArrayField, JsonUnion}; + +make_udf_function!( + JsonGetArray, + json_get_array, + json_data path, + r#"Get an arrow array value from a JSON string by its "path""# +); + +#[derive(Debug)] +pub(super) struct JsonGetArray { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonGetArray { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_get_array".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonGetArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DatafusionResult { + return_type_check( + arg_types, + self.name(), + DataType::List(Field::new("item", DataType::Utf8, true).into()), + ) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DatafusionResult { + let to_array = |c: JsonUnion| { + let array: ListArray = c.try_into()?; + Ok(Arc::new(array) as ArrayRef) + }; + + invoke::( + args, + jiter_json_get_array, + to_array, + |i| i.map_or_else(|| ScalarValue::Null, Into::into), + true, + ) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { + match peek { + Peek::Array => { + let mut peek_opt = jiter.known_array()?; + let mut elements = Vec::new(); + + while let Some(peek) = peek_opt { + let start = jiter.current_index(); + jiter.known_skip(peek)?; + let object_slice = jiter.slice_to_current(start); + let object_string = std::str::from_utf8(object_slice)?; + + elements.push(object_string.to_owned()); + + peek_opt = jiter.array_step()?; + } + + Ok(JsonArrayField(elements)) + } + _ => get_err!(), + } + } else { + get_err!() + } +} diff --git a/src/lib.rs b/src/lib.rs index cb0f25a..118796c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ mod common_union; mod json_as_text; mod json_contains; mod json_get; +mod json_get_array; mod json_get_bool; mod json_get_float; mod json_get_int; @@ -26,6 +27,7 @@ pub mod functions { pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; pub use crate::json_get::json_get; + pub use crate::json_get_array::json_get_array; pub use crate::json_get_bool::json_get_bool; pub use crate::json_get_float::json_get_float; pub use crate::json_get_int::json_get_int; @@ -39,6 +41,7 @@ pub mod udfs { pub use crate::json_as_text::json_as_text_udf; pub use crate::json_contains::json_contains_udf; pub use crate::json_get::json_get_udf; + pub use crate::json_get_array::json_get_array_udf; pub use crate::json_get_bool::json_get_bool_udf; pub use crate::json_get_float::json_get_float_udf; pub use crate::json_get_int::json_get_int_udf; @@ -60,6 +63,7 @@ pub mod udfs { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ json_get::json_get_udf(), + json_get_array::json_get_array_udf(), json_get_bool::json_get_bool_udf(), json_get_float::json_get_float_udf(), json_get_int::json_get_int_udf(), diff --git a/tests/main.rs b/tests/main.rs index 567f8ce..250aae3 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -67,23 +67,23 @@ async fn test_json_get_union() { .unwrap(); let expected = [ - "+------------------+--------------------------------------+", - "| name | json_get(test.json_data,Utf8(\"foo\")) |", - "+------------------+--------------------------------------+", - "| object_foo | {str=abc} |", - "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+--------------------------------------+", + "+------------------+--------------------------------------------------------------+", + "| name | json_get(test.json_data,Utf8(\"foo\")) |", + "+------------------+--------------------------------------------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } #[tokio::test] -async fn test_json_get_array() { +async fn test_json_get_array_index() { let sql = "select json_get('[1, 2, 3]', 2)"; let batches = run_query(sql).await.unwrap(); let (value_type, value_repr) = display_val(batches).await; @@ -319,17 +319,17 @@ async fn test_json_get_json() { .unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", - "+------------------+-------------------------------------------+", - "| object_foo | \"abc\" |", - "| object_foo_array | [1] |", - "| object_foo_obj | {} |", - "| object_foo_null | null |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+-------------------------------------------+", + "+------------------+------------------------------------------------------+", + "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", + "+------------------+------------------------------------------------------+", + "| object_foo | \"abc\" |", + "| object_foo_array | [1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}] |", + "| object_foo_obj | {} |", + "| object_foo_null | null |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -481,6 +481,27 @@ async fn test_json_length_vec() { assert_batches_eq!(expected, &batches); } +#[tokio::test] +async fn test_json_get_array() { + let sql = r"select name, unnest(json_get_array(json_data, 'foo')) from test"; + let batches = run_query(sql).await.unwrap(); + + let expected = [ + "+------------------+----------------------------------------------------+", + "| name | UNNEST(json_get_array(test.json_data,Utf8(\"foo\"))) |", + "+------------------+----------------------------------------------------+", + "| object_foo_array | 1 |", + "| object_foo_array | true |", + "| object_foo_array | {\"nested_foo\": \"baz\", \"nested_bar\": null} |", + "+------------------+----------------------------------------------------+", + ]; + + assert_batches_eq!(expected, &batches); + + let batches = run_query_large(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_no_args() { let err = run_query(r"select json_len()").await.unwrap_err(); @@ -738,17 +759,17 @@ async fn test_arrow() { let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); let expected = [ - "+------------------+-------------------------------+", - "| name | test.json_data -> Utf8(\"foo\") |", - "+------------------+-------------------------------+", - "| object_foo | {str=abc} |", - "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+-------------------------------+", + "+------------------+--------------------------------------------------------------+", + "| name | test.json_data -> Utf8(\"foo\") |", + "+------------------+--------------------------------------------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -770,17 +791,17 @@ async fn test_long_arrow() { let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); let expected = [ - "+------------------+--------------------------------+", - "| name | test.json_data ->> Utf8(\"foo\") |", - "+------------------+--------------------------------+", - "| object_foo | abc |", - "| object_foo_array | [1] |", - "| object_foo_obj | {} |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+--------------------------------+", + "+------------------+------------------------------------------------------+", + "| name | test.json_data ->> Utf8(\"foo\") |", + "+------------------+------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | [1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}] |", + "| object_foo_obj | {} |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index dab3d4e..c220bc7 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -25,7 +25,10 @@ async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result